package cache

import (
	"context"
	"fmt"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/jonboulle/clockwork"
	"github.com/opentracing/opentracing-go"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/units"
	travel "a.yandex-team.ru/travel/proto"
	api "a.yandex-team.ru/travel/trains/search_api/api/tariffs"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/date"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/searchprops"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/tariffs"
)

// tariffRouteKeyDate is a comparable calendar date (year, month, day) that
// forms part of a cache map key.
type tariffRouteKeyDate struct {
	year, month, day int32
}

// tariffRouteKey identifies one cached tariff: a departure/arrival pair of
// Express point IDs on a specific departure date. Every field is comparable,
// so the struct is used directly as a map key.
type tariffRouteKey struct {
	// Express IDs of the departure and arrival points.
	departurePointExpressID, arrivalPointExpressID int32

	// Calendar date of departure.
	departureDate tariffRouteKeyDate
}

// makeTariffRouteKey builds the cache key for the route departureID -> arrivalID
// on departureDate. departureDate must be non-nil: its Year, Month and Day
// fields are dereferenced directly.
func makeTariffRouteKey(departureID int32, arrivalID int32, departureDate *travel.TDate) tariffRouteKey {
	dateKey := tariffRouteKeyDate{
		year:  departureDate.Year,
		month: departureDate.Month,
		day:   departureDate.Day,
	}

	var key tariffRouteKey
	key.departurePointExpressID = departureID
	key.arrivalPointExpressID = arrivalID
	key.departureDate = dateKey
	return key
}

// TariffCache keeps DirectionTariffInfo messages in memory, keyed by departure
// point, arrival point and departure date. Values are stored as marshalled
// protobuf bytes; entries whose expiry moment has passed are evicted by the
// background loop started in Run.
type TariffCache struct {
	cfg         *Config                   // cache settings; Run reads RemoveExpiredPeriod
	logger      log.Logger                // operational logging
	lock        sync.RWMutex              // guards tariffs and expirations
	tariffs     map[tariffRouteKey][]byte // route key -> proto-marshalled DirectionTariffInfo
	expirations TariffExpirations         // pending evictions; presumably ordered by expireAt (only Head is inspected)
	metrics     *cacheMetrics             // cached-tariff count and error counters
	clock       clockwork.Clock           // time source; injectable (e.g. a fake clock)
}

// NewTariffCache creates an empty TariffCache that reads time from the real
// wall clock. Call Run to start background eviction.
func NewTariffCache(cfg *Config, logger log.Logger) *TariffCache {
	realClock := clockwork.NewRealClock()
	return NewTariffCacheWithClock(realClock, cfg, logger)
}

// NewTariffCacheWithClock creates an empty TariffCache that uses clock for
// every time check, so callers (e.g. tests) can supply a fake clock.
// Background eviction does not start until Run is called.
func NewTariffCacheWithClock(clock clockwork.Clock, cfg *Config, logger log.Logger) *TariffCache {
	cache := &TariffCache{
		clock:  clock,
		cfg:    cfg,
		logger: logger,
	}
	cache.tariffs = make(map[tariffRouteKey][]byte)
	cache.expirations = NewTariffExpirations()
	cache.metrics = newCacheMetrics()
	return cache
}

// Run starts a background goroutine that calls removeExpired every
// cfg.RemoveExpiredPeriod (measured on c.clock). The goroutine exits when ctx
// is cancelled and releases its ticker on the way out. Run itself returns
// immediately.
func (c *TariffCache) Run(ctx context.Context) {
	c.logger.Info("TariffCache runs")
	go func() {
		removeExpiredTicker := c.clock.NewTicker(c.cfg.RemoveExpiredPeriod)
		// Without Stop the ticker keeps its resources alive after shutdown.
		defer removeExpiredTicker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-removeExpiredTicker.Chan():
				c.removeExpired()
			}
		}
	}()
}

// removeExpired evicts every tariff whose expiry moment has been reached, one
// entry per lock acquisition. It then logs how many were removed and refreshes
// the cached-tariffs gauge.
func (c *TariffCache) removeExpired() {
	removed := 0
	for c.removeOneExpired() {
		removed++
		// The lock is released between evictions; yielding presumably lets
		// readers/writers waiting on it make progress during a large batch.
		runtime.Gosched()
	}

	// len on a map that Add writes concurrently must happen under the lock.
	c.lock.RLock()
	cachedCount := len(c.tariffs)
	c.lock.RUnlock()

	c.logger.Info("Expired tariffs removed", log.Int("count", removed))
	c.metrics.cachedTariffsCount.Set(float64(cachedCount))
}

// removeOneExpired performs a single eviction step under the write lock: if
// the expiration at the head of c.expirations is due (now is not before its
// expireAt), it is popped and its tariff deleted. The result reports whether
// something was removed, i.e. whether the caller should try another step.
func (c *TariffCache) removeOneExpired() bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.expirations.Empty() {
		return false
	}

	head := c.expirations.Head()
	notYetExpired := c.clock.Now().Before(head.expireAt)
	if notYetExpired {
		return false
	}

	c.expirations.Pop()
	delete(c.tariffs, head.routeKey)
	return true
}

// Select returns cached tariffs for every departure/arrival combination of
// departurePointExpressIDs x arrivalPointExpressIDs whose departure date lies
// between the dates of leftBorder and rightBorder. The
// returned error is always nil.
func (c *TariffCache) Select(
	ctx context.Context,
	departurePointExpressIDs []int32,
	arrivalPointExpressIDs []int32,
	leftBorder time.Time,
	rightBorder time.Time,
) ([]*api.DirectionTariffInfo, error) {
	// Continue with the span-carrying context so downstream work is attributed
	// to this span instead of its parent.
	span, ctx := opentracing.StartSpanFromContext(ctx, "trains.search_api.internal.pkg.tariffs.cache.Cache.Select")
	defer span.Finish()

	directions := make([]tariffs.TrainDirection, 0, len(departurePointExpressIDs)*len(arrivalPointExpressIDs))
	for _, departureID := range departurePointExpressIDs {
		for _, arrivalID := range arrivalPointExpressIDs {
			directions = append(directions, tariffs.TrainDirection{
				DeparturePointExpressID: departureID,
				ArrivalPointExpressID:   arrivalID,
			})
		}
	}
	return c.selectByDirections(ctx, directions, leftBorder, rightBorder), nil
}

// SelectByDirections returns cached tariffs for the given directions whose
// departure date lies between the dates of leftBorder and rightBorder.
// The returned error is always nil.
func (c *TariffCache) SelectByDirections(
	ctx context.Context,
	directions []tariffs.TrainDirection,
	leftBorder time.Time,
	rightBorder time.Time,
) ([]*api.DirectionTariffInfo, error) {
	// Continue with the span-carrying context so downstream work is attributed
	// to this span instead of its parent.
	span, ctx := opentracing.StartSpanFromContext(ctx, "trains.search_api.internal.pkg.tariffs.cache.Cache.SelectByDirections")
	defer span.Finish()

	return c.selectByDirections(ctx, directions, leftBorder, rightBorder), nil
}

// selectByDirections looks up cached tariffs for each direction (after
// filterGoodDirections) and each day produced by getDepartureDateRange for
// [leftBorder, rightBorder]. Results are ordered direction by direction, then
// by date; nil is returned when nothing matches. It also sets the
// "train_tariff_cache_missed" search prop ("1" when nothing was found, "0"
// otherwise) and logs a summary of the lookup.
func (c *TariffCache) selectByDirections(ctx context.Context, directions []tariffs.TrainDirection, leftBorder time.Time, rightBorder time.Time) []*api.DirectionTariffInfo {
	directions = filterGoodDirections(directions)

	// The date range is identical for every direction: compute it, and the
	// proto dates used to build keys, once instead of once per direction.
	departureDates := getDepartureDateRange(leftBorder, rightBorder)
	protoDates := make([]*travel.TDate, len(departureDates))
	for i, departureDate := range departureDates {
		protoDates[i] = date.GetProtoFromDate(departureDate)
	}

	// Gather the encoded values under a single read lock and decode after it
	// is released. Stored byte slices are only ever replaced (never mutated in
	// place), so reading them outside the lock is safe.
	var encoded [][]byte
	c.lock.RLock()
	for _, direction := range directions {
		for _, protoDate := range protoDates {
			key := makeTariffRouteKey(
				direction.DeparturePointExpressID,
				direction.ArrivalPointExpressID,
				protoDate,
			)
			if bytes, found := c.tariffs[key]; found {
				encoded = append(encoded, bytes)
			}
		}
	}
	c.lock.RUnlock()

	var result []*api.DirectionTariffInfo
	if len(encoded) > 0 {
		result = make([]*api.DirectionTariffInfo, 0, len(encoded))
	}
	for _, bytes := range encoded {
		result = append(result, decodeDirectionTariffInfo(bytes))
	}

	searchprops.Set(ctx, "train_tariff_cache_missed", "0")
	if len(result) == 0 {
		searchprops.Set(ctx, "train_tariff_cache_missed", "1")
	}

	c.logger.Info("selected tariffs by directions",
		log.Int("tariffs_count", len(result)),
		log.String("directions", formatDirections(directions)),
		log.Time("leftBorder", leftBorder),
		log.Time("rightBorder", rightBorder),
	)
	return result
}

// formatDirections renders directions for logging as
// "[dep->arr, dep->arr, ...]" (or "[]" when empty).
func formatDirections(directions []tariffs.TrainDirection) string {
	var sb strings.Builder
	sb.WriteByte('[')
	for i, d := range directions {
		if i > 0 {
			sb.WriteString(", ")
		}
		fmt.Fprintf(&sb, "%d->%d", d.DeparturePointExpressID, d.ArrivalPointExpressID)
	}
	sb.WriteByte(']')
	return sb.String()
}

// filterGoodDirections drops directions whose departure and arrival points
// coincide and removes duplicates. The first occurrence of each direction is
// kept in input order, so lookups and logs built from the result are
// deterministic (ranging over a map would randomize the order). Returns nil
// when no direction survives.
func filterGoodDirections(allDirections []tariffs.TrainDirection) []tariffs.TrainDirection {
	seen := make(map[tariffs.TrainDirection]struct{}, len(allDirections))
	var result []tariffs.TrainDirection
	for _, direction := range allDirections {
		if direction.DeparturePointExpressID == direction.ArrivalPointExpressID {
			continue
		}
		if _, duplicate := seen[direction]; duplicate {
			continue
		}
		seen[direction] = struct{}{}
		result = append(result, direction)
	}
	return result
}

// getDepartureDateRange lists the days from date.DateFromTime(leftBorder)
// through date.DateFromTime(rightBorder) inclusive, stepping by units.Day.
// It returns nil when the right border's date precedes the left border's.
func getDepartureDateRange(leftBorder, rightBorder time.Time) (result []time.Time) {
	day := date.DateFromTime(leftBorder)
	end := date.DateFromTime(rightBorder).Add(units.Day)
	for ; day.Before(end); day = day.Add(units.Day) {
		result = append(result, day)
	}
	return result
}

// Iter streams every cached tariff, decoded as *api.DirectionTariffInfo, on the
// returned channel (buffer of 5). The key set is snapshotted when Iter is
// called: keys evicted afterwards are skipped and keys added afterwards are
// not visited. The channel is closed when iteration finishes or ctx is done.
func (c *TariffCache) Iter(ctx context.Context) <-chan proto.Message {
	snapshot := c.getTariffKeys()
	out := make(chan proto.Message, 5)

	go func() {
		defer close(out)
		for _, key := range snapshot {
			c.lock.RLock()
			encoded, ok := c.tariffs[key]
			c.lock.RUnlock()

			if !ok {
				continue
			}
			select {
			case out <- decodeDirectionTariffInfo(encoded):
			case <-ctx.Done():
				return
			}
		}
	}()

	return out
}

// Add stores a *api.DirectionTariffInfo message in the cache under its
// (departure point, arrival point, departure date) key.
//
// Messages of any other type, typed-nil messages and messages without a
// departure date are rejected with an error log. If the key is already cached
// with a strictly later UpdatedAt, the new message is ignored; otherwise it
// replaces the cached one and inherits its CreatedAt (note: this mutates the
// passed-in message). Tariffs whose expiry moment (see makeExpireAt) is
// already in the past are not stored.
func (c *TariffCache) Add(message proto.Message) {
	tariff, ok := message.(*api.DirectionTariffInfo)
	if !ok {
		c.logger.Error("unexpected proto message type")
		return
	}
	// A typed nil or a missing departure date would panic in makeTariffRouteKey.
	if tariff == nil || tariff.DepartureDate == nil {
		c.logger.Error("invalid tariff: nil message or missing departure date")
		c.metrics.errors.Inc()
		return
	}

	key := makeTariffRouteKey(
		tariff.DeparturePointExpressId,
		tariff.ArrivalPointExpressId,
		tariff.DepartureDate,
	)
	c.lock.Lock()
	defer c.lock.Unlock()

	if storedBytes, found := c.tariffs[key]; found {
		storedTariffInfo := decodeDirectionTariffInfo(storedBytes)
		if storedTariffInfo.UpdatedAt.AsTime().After(tariff.UpdatedAt.AsTime()) {
			return
		}
		tariff.CreatedAt = storedTariffInfo.CreatedAt
	}

	if c.makeExpireAt(tariff).Before(c.clock.Now()) {
		return
	}

	// Encode first so a failure does not leave an expiration without a tariff.
	encoded, err := encodeDirectionTariffInfo(tariff)
	if err != nil {
		c.logger.Error("tariff encoding fails", log.Error(err))
		c.metrics.errors.Inc()
		return
	}

	// Must precede the map write: an expiration is registered only for keys
	// that are not cached yet.
	c.saveTariffExpiration(key, tariff)
	c.tariffs[key] = encoded
	c.metrics.cachedTariffsCount.Set(float64(len(c.tariffs)))
}

// decodeDirectionTariffInfo unmarshals bytes into a new DirectionTariffInfo
// and always returns a non-nil message.
func decodeDirectionTariffInfo(bytes []byte) *api.DirectionTariffInfo {
	decoded := &api.DirectionTariffInfo{}
	// Error deliberately ignored: in this file every cached value comes from
	// encodeDirectionTariffInfo. On failure the (possibly partially filled)
	// message is returned as-is.
	_ = proto.Unmarshal(bytes, decoded)
	return decoded
}

// encodeDirectionTariffInfo marshals info into the bytes stored in the cache.
// A marshalling error is wrapped (with %w) and prefixed with this function's
// qualified name.
func encodeDirectionTariffInfo(info *api.DirectionTariffInfo) ([]byte, error) {
	// Previously named a non-existent "setDirectionTariffInfo", which made
	// error messages point at the wrong function.
	const funcName = "trains.search_api.internal.pkg.tariffs.cache.encodeDirectionTariffInfo"

	encoded, err := proto.Marshal(info)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", funcName, err)
	}
	return encoded, nil
}

// saveTariffExpiration schedules eviction of key at makeExpireAt(tariff), but
// only when key is not yet present in c.tariffs, so each cached key gets a
// single expiration entry. It touches c.tariffs and c.expirations without
// locking, so the caller must already hold c.lock for writing.
func (c *TariffCache) saveTariffExpiration(key tariffRouteKey, tariff *api.DirectionTariffInfo) {
	if _, alreadyCached := c.tariffs[key]; alreadyCached {
		return
	}

	c.expirations.Push(&tariffExpiration{
		routeKey: key,
		expireAt: c.makeExpireAt(tariff),
	})
}

// makeExpireAt returns the moment a tariff becomes evictable: units.Day after
// the time that date.GetDateFromProto yields for its departure date.
func (c *TariffCache) makeExpireAt(tariff *api.DirectionTariffInfo) time.Time {
	departureDay := date.GetDateFromProto(tariff.DepartureDate)
	return departureDay.Add(units.Day)
}

// getTariffKeys returns a snapshot of all keys currently in the cache, in
// unspecified (map iteration) order.
func (c *TariffCache) getTariffKeys() []tariffRouteKey {
	// Take the lock before reading len(c.tariffs); reading it earlier raced
	// with concurrent writers.
	c.lock.RLock()
	defer c.lock.RUnlock()

	keys := make([]tariffRouteKey, 0, len(c.tariffs))
	for key := range c.tariffs {
		keys = append(keys, key)
	}
	return keys
}
