package cache

import (
	"container/heap"
	"context"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/jonboulle/clockwork"

	"a.yandex-team.ru/library/go/core/log/zap"
	"a.yandex-team.ru/travel/library/go/metrics"
	tpb "a.yandex-team.ru/travel/proto"

	"a.yandex-team.ru/travel/buses/backend/internal/common/dict"
	"a.yandex-team.ru/travel/buses/backend/internal/common/logging"
	ipb "a.yandex-team.ru/travel/buses/backend/internal/common/proto"
	pb "a.yandex-team.ru/travel/buses/backend/proto"
)

// Config holds the tunables of SearchRecordStorage.
type Config struct {
	// StorageExpiration is the period of the background loop that evicts
	// expired records and publishes cache metrics.
	StorageExpiration time.Duration `config:"cache-storageexpiration,required"`
}

// DefaultConfig runs the expiration/metrics loop once per minute.
var DefaultConfig = Config{StorageExpiration: 1 * time.Minute}

// SearchKey identifies one cached search: a supplier, an origin point, a
// destination point and a calendar date, flattened into comparable fields so
// the value can be used directly as a map key.
type SearchKey struct {
	SupplierID uint32
	FromType   pb.EPointKeyType
	FromID     uint32
	ToType     pb.EPointKeyType
	ToID       uint32
	DateYear   int32
	DateMonth  int32
	DateDay    int32
}

// NewSearchKey builds a SearchKey from a supplier ID, the origin and destination
// point keys and a date. from, to and date are dereferenced directly, so passing
// a nil pointer panics.
func NewSearchKey(supplierID uint32, from *pb.TPointKey, to *pb.TPointKey, date *tpb.TDate) SearchKey {
	key := SearchKey{SupplierID: supplierID}
	key.FromType, key.FromID = from.Type, from.Id
	key.ToType, key.ToID = to.Type, to.Id
	key.DateYear, key.DateMonth, key.DateDay = date.Year, date.Month, date.Day
	return key
}

// SearchCacheQueue is a min-heap of cache records ordered by CreatedAt, to be
// driven through container/heap. Every method that moves records keeps each
// record's QueueIndex equal to its current slice position, which lets callers
// use heap.Fix / heap.Remove on a specific record.
type SearchCacheQueue []*ipb.TSearchCacheRecord

// Compile-time check that *SearchCacheQueue satisfies heap.Interface.
var _ heap.Interface = (*SearchCacheQueue)(nil)

// Len reports how many records are queued.
func (q *SearchCacheQueue) Len() int {
	items := *q
	return len(items)
}

// Less orders records by CreatedAt, so the smallest CreatedAt sits at the root.
func (q *SearchCacheQueue) Less(i, j int) bool {
	items := *q
	return items[i].CreatedAt < items[j].CreatedAt
}

// Swap exchanges two records and refreshes their stored positions.
func (q *SearchCacheQueue) Swap(i, j int) {
	items := *q
	items[i], items[j] = items[j], items[i]
	items[i].QueueIndex = int64(i)
	items[j].QueueIndex = int64(j)
}

// Push appends x, which must be a *ipb.TSearchCacheRecord, and records its
// position. It is meant to be called by heap.Push, not directly.
func (q *SearchCacheQueue) Push(x interface{}) {
	rec := x.(*ipb.TSearchCacheRecord)
	rec.QueueIndex = int64(len(*q))
	*q = append(*q, rec)
}

// Pop removes and returns the last record, marking it as no longer queued.
// It is meant to be called by heap.Pop / heap.Remove, not directly.
func (q *SearchCacheQueue) Pop() interface{} {
	old := *q
	last := len(old) - 1
	rec := old[last]
	old[last] = nil     // drop the slot's reference so the record can be collected
	rec.QueueIndex = -1 // the record is no longer in the queue
	*q = old[:last]
	return rec
}

// MetricsKey is the bucket a cache record is counted in for the per-status,
// per-supplier "statuses" gauge.
type MetricsKey struct {
	status     ipb.ECacheRecordStatus
	supplierID uint32
}

// MakeMetricsKey returns the metrics bucket that record belongs to.
func MakeMetricsKey(record *ipb.TSearchCacheRecord) MetricsKey {
	var key MetricsKey
	key.status = record.Status
	key.supplierID = record.SupplierId
	return key
}

// SearchRecordStorage is an in-memory cache of search records keyed by
// SearchKey. Records older than ttl are evicted by a background loop started
// with Run, and a secondary index maps ride IDs to the searches containing them.
type SearchRecordStorage struct {
	// Settings and dependencies fixed at construction time.
	cfg        *Config
	logger     *zap.Logger
	appMetrics *metrics.AppMetrics
	ttl        time.Duration
	clock      clockwork.Clock

	// ctx is set by Run; the background loop exits when it is cancelled.
	ctx context.Context

	// After construction, the fields from here to the end of the struct are
	// modified only while mutex is held for writing.
	mutex         sync.RWMutex
	records       map[SearchKey]*ipb.TSearchCacheRecord
	expiringQueue SearchCacheQueue
	rideIDIndex   map[string]map[SearchKey]struct{}

	// Counters published as gauges by the background loop.
	metricsKeysCount      int
	metricsRidesCount     int
	metricsRecordStatuses map[MetricsKey]int
}

// NewSearchRecordStorage creates an empty storage that measures record ages with
// the real wall clock. ttl is the maximum record age, cfg.StorageExpiration is
// the period of the background loop started by Run.
func NewSearchRecordStorage(
	ttl time.Duration, appMetrics *metrics.AppMetrics, cfg *Config, logger *zap.Logger,
) *SearchRecordStorage {
	wallClock := clockwork.NewRealClock()
	return NewSearchRecordStorageWithClock(ttl, wallClock, appMetrics, cfg, logger)
}

// NewSearchRecordStorageWithClock creates an empty storage whose record ages and
// loop ticks are measured with clock, which lets callers inject a clock of their
// choice (e.g. a fake one). ttl is the maximum record age; cfg.StorageExpiration
// is the period of the background loop started by Run. logger is wrapped with
// logging.WithModuleContext under the "searchcache" module name.
func NewSearchRecordStorageWithClock(
	ttl time.Duration, clock clockwork.Clock, appMetrics *metrics.AppMetrics, cfg *Config, logger *zap.Logger,
) *SearchRecordStorage {
	// Counters start at zero and ctx stays nil until Run is called.
	return &SearchRecordStorage{
		cfg:                   cfg,
		logger:                logging.WithModuleContext(logger, "searchcache"),
		appMetrics:            appMetrics,
		ttl:                   ttl,
		clock:                 clock,
		records:               make(map[SearchKey]*ipb.TSearchCacheRecord),
		rideIDIndex:           make(map[string]map[SearchKey]struct{}),
		metricsRecordStatuses: make(map[MetricsKey]int),
	}
}

// Run remembers ctx and starts the background expiration/metrics loop in a new
// goroutine; the loop returns once ctx is cancelled. Each call starts another
// loop and overwrites the stored ctx.
func (srs *SearchRecordStorage) Run(ctx context.Context) {
	srs.ctx = ctx
	go srs.loop()
}

// expired reports whether record is strictly older than the storage TTL. The
// age is computed in whole seconds as clock.Now().Unix() minus record.CreatedAt,
// so CreatedAt is treated as a Unix timestamp in seconds.
func (srs *SearchRecordStorage) expired(record *ipb.TSearchCacheRecord) bool {
	ageSeconds := srs.clock.Now().Unix() - record.CreatedAt
	age := time.Duration(ageSeconds) * time.Second
	return age > srs.ttl
}

// Get returns a deep copy of the record stored under key and true. For an
// unknown key it returns a fresh empty record and false. Returning a clone
// keeps callers from mutating the cached record.
func (srs *SearchRecordStorage) Get(key SearchKey) (*ipb.TSearchCacheRecord, bool) {
	srs.mutex.RLock()
	defer srs.mutex.RUnlock()
	if stored, found := srs.records[key]; found {
		return proto.Clone(stored).(*ipb.TSearchCacheRecord), true
	}
	return &ipb.TSearchCacheRecord{}, false
}

// GetSearchKeysByRideID returns the set of keys of cached searches that contain
// a ride with the given ID, and whether the ride ID is indexed at all.
//
// The returned set is a copy taken under the read lock. The index sets are
// mutated by Set, Delete and background expiration, so handing out the live
// map would let callers read (or even modify) it concurrently with those
// writers.
func (srs *SearchRecordStorage) GetSearchKeysByRideID(rideID string) (map[SearchKey]struct{}, bool) {
	srs.mutex.RLock()
	defer srs.mutex.RUnlock()
	keys, ok := srs.rideIDIndex[rideID]
	if !ok {
		return nil, false
	}
	snapshot := make(map[SearchKey]struct{}, len(keys))
	for key := range keys {
		snapshot[key] = struct{}{}
	}
	return snapshot, true
}

// GetFirstSearchKeyByRideID returns the key of one cached search containing a
// ride with the given ID. Which key is returned when several match is
// unspecified (map iteration order). The bool is false when the ride ID is
// not indexed.
//
// The index set is looked up and iterated entirely under the read lock, so it
// is never read while a writer (Set, Delete, expiry) may be modifying it.
func (srs *SearchRecordStorage) GetFirstSearchKeyByRideID(rideID string) (SearchKey, bool) {
	srs.mutex.RLock()
	keys, indexed := srs.rideIDIndex[rideID]
	var first SearchKey
	found := false
	for key := range keys {
		first, found = key, true
		break
	}
	srs.mutex.RUnlock()

	if !indexed {
		return SearchKey{}, false
	}
	if !found {
		// Empty sets are removed from the index when their last key is deleted,
		// so an indexed-but-empty set means that invariant was broken.
		srs.logger.Errorf("SearchRecordStorage.GetFirstSearchKeyByRideID: empty keys set was not deleted")
		return SearchKey{}, false
	}
	return first, true
}

// setKey overwrites record's supplier ID, origin, destination and date with the
// values from key, so that getKey(record) reproduces key afterwards.
// TODO: remove setKey logic
func setKey(record *ipb.TSearchCacheRecord, key SearchKey) {
	from := &pb.TPointKey{Type: key.FromType, Id: key.FromID}
	to := &pb.TPointKey{Type: key.ToType, Id: key.ToID}
	date := &tpb.TDate{Year: key.DateYear, Month: key.DateMonth, Day: key.DateDay}
	record.SupplierId = key.SupplierID
	record.From, record.To, record.Date = from, to, date
}

// getKey builds the SearchKey described by record's supplier ID, origin,
// destination and date fields.
func getKey(record *ipb.TSearchCacheRecord) SearchKey {
	supplierID, from, to, date := record.SupplierId, record.From, record.To, record.Date
	return NewSearchKey(supplierID, from, to, date)
}

// Set stores record under key, replacing any record previously stored there,
// and keeps the expiring queue, the ride-ID index and the metric counters in
// sync.
//
// The record pointer itself is stored (not cloned) and mutated: its key fields
// are overwritten from key and its QueueIndex is managed by the expiring queue.
// Callers should therefore not modify record after passing it in.
//
// Registering the new rides may evict other cached searches that contain a ride
// with the same ID but different data (see updateUnsafeRideIDIndex).
func (srs *SearchRecordStorage) Set(key SearchKey, record *ipb.TSearchCacheRecord) {
	srs.mutex.Lock()
	defer srs.mutex.Unlock()

	// Make the record's own key fields agree with key, so getKey(record) == key.
	setKey(record, key)

	previousRecord, ok := srs.records[key]
	if ok {
		// Take over the previous record's heap slot, then re-sift it because
		// CreatedAt may differ.
		queueIndex := previousRecord.QueueIndex
		record.QueueIndex = queueIndex
		srs.expiringQueue[queueIndex] = record
		heap.Fix(&srs.expiringQueue, int(queueIndex))

		// Forget the previous record's rides; the new ones are indexed below.
		srs.deleteUnsafeFromRideIDIndex(key, previousRecord)

		// Replacing does not change the key count, only status/ride counters.
		srs.metricsRecordStatuses[MakeMetricsKey(previousRecord)]--
		srs.metricsRidesCount -= len(previousRecord.Rides)
	} else {
		heap.Push(&srs.expiringQueue, record)
		srs.metricsKeysCount++
	}

	// May delete other, conflicting keys from the storage.
	srs.updateUnsafeRideIDIndex(key, record)

	srs.metricsRecordStatuses[MakeMetricsKey(record)]++
	srs.metricsRidesCount += len(record.Rides)
	srs.records[key] = record
}

// deleteUnsafeFromRideIDIndex removes key from the index set of every ride in
// record, and drops index entries whose set becomes empty. It does not lock;
// its callers in this file hold srs.mutex for writing.
func (srs *SearchRecordStorage) deleteUnsafeFromRideIDIndex(key SearchKey, record *ipb.TSearchCacheRecord) {
	for _, r := range record.Rides {
		rideID := r.Id
		searchKeys, indexed := srs.rideIDIndex[rideID]
		if !indexed {
			continue
		}
		delete(searchKeys, key)
		if len(searchKeys) == 0 {
			delete(srs.rideIDIndex, rideID)
		}
	}
}

// updateUnsafeRideIDIndex registers key in the index set of every ride in
// record.
//
// Before registering a ride, every other cached search indexed under the same
// ride ID is checked. If it holds a ride with that ID whose contents differ
// (per proto.Equal), an error is logged and that whole search is deleted from
// the storage.
//
// It does not lock; its caller (Set) holds srs.mutex for writing.
func (srs *SearchRecordStorage) updateUnsafeRideIDIndex(key SearchKey, record *ipb.TSearchCacheRecord) {
	const errorMessage = "SearchRecordStorage.updateUnsafeRideIDIndex"
	for _, ride := range record.Rides {
		// Ranging over a missing (nil) set is a no-op. deleteUnsafe below may
		// delete entries from this very set while it is being ranged over,
		// which Go permits.
		for anotherKey := range srs.rideIDIndex[ride.Id] {
			if anotherKey == key {
				continue
			}
			anotherRecord, ok := srs.records[anotherKey]
			if !ok {
				continue
			}
			for _, anotherRide := range anotherRecord.Rides {
				if ride.Id == anotherRide.Id && !proto.Equal(ride, anotherRide) {
					srs.logger.Errorf(
						"%s: duplicate rideId=%s with different data. Deleting previous search for key=%v",
						errorMessage, ride.Id, anotherKey)
					srs.deleteUnsafe(anotherKey)
					break
				}
			}
		}

		// Deletions above can drop this ride ID's entry from the index entirely,
		// so look it up afresh and recreate it when needed.
		keys, ok := srs.rideIDIndex[ride.Id]
		if !ok {
			keys = make(map[SearchKey]struct{})
			srs.rideIDIndex[ride.Id] = keys
		}
		keys[key] = struct{}{}
	}
}

// deleteUnsafe removes the record stored under key from the ride-ID index, the
// expiring queue and the records map, and updates the metric counters. It does
// nothing for an unknown key. It does not lock; its callers in this file hold
// srs.mutex for writing.
func (srs *SearchRecordStorage) deleteUnsafe(key SearchKey) {
	record, present := srs.records[key]
	if !present {
		return
	}
	srs.deleteUnsafeFromRideIDIndex(key, record)
	heap.Remove(&srs.expiringQueue, int(record.QueueIndex))
	srs.metricsKeysCount--
	srs.metricsRecordStatuses[MakeMetricsKey(record)]--
	srs.metricsRidesCount -= len(record.Rides)
	delete(srs.records, key)
}

// Delete removes the record stored under key, if any, under the write lock.
// Always remove records through Delete (or deleteUnsafe while holding the lock)
// rather than a direct delete(srs.records, key), which would leave the
// expiring queue, ride-ID index and metrics out of sync.
func (srs *SearchRecordStorage) Delete(key SearchKey) {
	mu := &srs.mutex
	mu.Lock()
	defer mu.Unlock()
	srs.deleteUnsafe(key)
}

// Len returns the number of stored records.
//
// NOTE: Len does not acquire srs.mutex. Within this file it is called only
// while the mutex is already held, and locking here would deadlock those
// callers. Calling it elsewhere while other goroutines Set/Delete is a data race.
func (srs *SearchRecordStorage) Len() int {
	count := len(srs.records)
	return count
}

// loop is the background goroutine started by Run. Every cfg.StorageExpiration
// (as measured by srs.clock) it evicts expired records and publishes the cache
// gauges. It stops its ticker and returns once srs.ctx is cancelled.
func (srs *SearchRecordStorage) loop() {
	ticker := srs.clock.NewTicker(srs.cfg.StorageExpiration)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.Chan():
			statuses, keysCount, ridesCount := srs.evictExpiredAndSnapshotMetrics()
			srs.publishMetrics(statuses, keysCount, ridesCount)
		case <-srs.ctx.Done():
			return
		}
	}
}

// evictExpiredAndSnapshotMetrics deletes every expired record and returns a
// copy of the metric counters, all under the write lock. The counters, and the
// statuses map in particular, are mutated by Set/Delete. Reading them after
// unlocking is a data race, and ranging over the live map while Set writes to
// it can abort the process with a concurrent map iteration/write fault.
func (srs *SearchRecordStorage) evictExpiredAndSnapshotMetrics() (map[MetricsKey]int, int, int) {
	srs.mutex.Lock()
	defer srs.mutex.Unlock()

	// expiringQueue is a min-heap on CreatedAt, so the root is always the oldest
	// record. Evict from the root until the oldest remaining one is still fresh.
	for len(srs.expiringQueue) > 0 {
		oldest := srs.expiringQueue[0]
		if !srs.expired(oldest) {
			break
		}
		srs.deleteUnsafe(getKey(oldest))
	}

	statuses := make(map[MetricsKey]int, len(srs.metricsRecordStatuses))
	for metricsKey, cnt := range srs.metricsRecordStatuses {
		statuses[metricsKey] = cnt
	}
	return statuses, srs.metricsKeysCount, srs.metricsRidesCount
}

// publishMetrics exports a metrics snapshot as "cache" gauges. Status buckets
// whose supplier cannot be resolved by dict.GetSupplier are skipped. It reads
// only its arguments, not storage state, so it runs without the lock.
func (srs *SearchRecordStorage) publishMetrics(statuses map[MetricsKey]int, keysCount, ridesCount int) {
	for metricsKey, cnt := range statuses {
		supplier, err := dict.GetSupplier(metricsKey.supplierID)
		if err != nil {
			continue
		}
		srs.appMetrics.GetOrCreateGauge(
			"cache", map[string]string{
				"status":   metricsKey.status.String(),
				"supplier": supplier.Name,
			}, "statuses").Set(float64(cnt))
	}
	srs.appMetrics.GetOrCreateGauge(
		"cache", nil, "keys_count").Set(float64(keysCount))
	srs.appMetrics.GetOrCreateGauge(
		"cache", nil, "rides_count").Set(float64(ridesCount))
}

// Iter returns an unbuffered channel that yields a clone of every record whose
// key was present when iteration started, in map (unspecified) order. Keys
// deleted before their turn are skipped. The channel is closed after the last
// record, or early once ctx is cancelled.
func (srs *SearchRecordStorage) Iter(ctx context.Context) <-chan proto.Message {
	out := make(chan proto.Message)
	go func() {
		defer close(out)

		// Snapshot the keys under the read lock; each record is then fetched
		// (and cloned) individually via Get.
		srs.mutex.RLock()
		snapshot := make([]SearchKey, 0, srs.Len())
		for key := range srs.records {
			snapshot = append(snapshot, key)
		}
		srs.mutex.RUnlock()

		for _, key := range snapshot {
			record, found := srs.Get(key)
			if !found {
				continue
			}
			select {
			case out <- record:
			case <-ctx.Done():
				return
			}
		}
	}()
	return out
}

// Add stores message under the key derived from its own supplier, origin,
// destination and date fields. message must be a *ipb.TSearchCacheRecord; any
// other type makes the type assertion panic.
func (srs *SearchRecordStorage) Add(message proto.Message) {
	record := message.(*ipb.TSearchCacheRecord)
	key := getKey(record)
	srs.Set(key, record)
}
