package popular

import (
	"context"
	"sort"
	"sync"
	"time"

	"a.yandex-team.ru/library/go/core/log/zap"
	"github.com/golang/protobuf/proto"

	"a.yandex-team.ru/travel/buses/backend/internal/common/logging"
	ipb "a.yandex-team.ru/travel/buses/backend/internal/common/proto"
	pb "a.yandex-team.ru/travel/buses/backend/proto"
)

// Config tunes the popular-directions tracker (Directions). The `config` struct
// tags presumably name the keys read by the project's config loader (not
// visible here); every key is marked required.
//
//   - DeflationPeriod: interval between deflate passes started by Run.
//   - DeflationMultiplier: factor every raw weight is multiplied by on each
//     deflate pass.
//   - RebuildPeriod: interval between rebuilds of the sorted views started by Run.
//   - TopSize: maximum number of directions kept in the global sorted view.
//   - TopSizeFrom: maximum number of directions kept per origin point.
//   - DropThreshold: a (from, to) pair whose deflated weight falls below this
//     value is removed from the raw weights.
type Config struct {
	DeflationPeriod     time.Duration `config:"popular-deflationperiod,required"`
	DeflationMultiplier float32       `config:"popular-deflationmultiplier,required"`
	RebuildPeriod       time.Duration `config:"popular-rebuildperiod,required"`
	TopSize             int           `config:"popular-topsize,required"`
	TopSizeFrom         int           `config:"popular-topsizefrom,required"`
	DropThreshold       float32       `config:"popular-dropthreshold,required"`
}

// DefaultConfig is a ready-to-use Config. The sorted views are rebuilt every
// 15 minutes, keeping the top 80000 directions overall and the top 20 per
// origin. Raw weights are multiplied by 0.99 every 3 hours, and pairs are
// dropped once their weight falls below 1.
var DefaultConfig = Config{
	RebuildPeriod:       15 * time.Minute,
	TopSize:             80000,
	TopSizeFrom:         20,
	DeflationPeriod:     3 * time.Hour,
	DeflationMultiplier: 0.99,
	DropThreshold:       1,
}

// PointKey is a comparable value copy of a pb.TPointKey (point type plus
// numeric ID). Directions uses it as the map key for both raw and per-origin
// sorted directions.
type PointKey struct {
	Type pb.EPointKeyType
	ID   uint32
}

// NewPointKey copies the type and ID of a protobuf point key into a PointKey.
// pbPointKey must be non-nil because its fields are read directly.
func NewPointKey(pbPointKey *pb.TPointKey) *PointKey {
	key := new(PointKey)
	key.Type = pbPointKey.Type
	key.ID = pbPointKey.Id
	return key
}

// ToProto converts the key back into a newly allocated protobuf point key.
func (pk *PointKey) ToProto() *pb.TPointKey {
	msg := new(pb.TPointKey)
	msg.Type, msg.Id = pk.Type, pk.ID
	return msg
}

// SortedDirections is a list of popular directions that Push keeps ordered by
// descending Weight and bounded by a caller-supplied size limit.
type SortedDirections []*ipb.TPopularDirection

// Push inserts direction into sd. It keeps sd ordered by descending Weight and
// holds at most limit elements.
//
// A new element goes after any existing elements of equal weight, so ties keep
// their insertion order. When sd is already full, the last element is evicted
// to make room. A direction that would land past the end of a full list is
// discarded. A limit of 0 leaves sd untouched.
func (sd *SortedDirections) Push(direction *ipb.TPopularDirection, limit int) {
	if limit == 0 {
		return
	}
	items := *sd
	size := len(items)
	pos := sort.Search(size, func(k int) bool { return items[k].Weight < direction.Weight })
	hasRoom := size < limit

	switch {
	case pos == size && hasRoom:
		*sd = append(items, direction)
	case pos == size:
		// Full, and the newcomer ranks last: nothing to do.
	default:
		if hasRoom {
			// Grow by one slot; the shift below overwrites it.
			items = append(items, nil)
		}
		copy(items[pos+1:], items[pos:])
		items[pos] = direction
		*sd = items
	}
}

// Directions accumulates a weight for each (from, to) pair of points reported
// via Register or Add. It periodically derives the most popular pairs from
// those weights.
//
// Fields:
//   - rawDirections: accumulated weight per origin and destination. It is
//     guarded by rawMutex.
//   - sortedDirections: global top of at most cfg.TopSize entries. It is
//     derived by rebuild, guarded by sortedMutex, and nil until the first
//     rebuild.
//   - sortedDirectionsFrom: per-origin tops of at most cfg.TopSizeFrom entries
//     each. It is derived by rebuild and guarded by sortedMutex.
type Directions struct {
	cfg                  Config
	logger               *zap.Logger
	sortedMutex          sync.RWMutex
	rawMutex             sync.RWMutex
	rawDirections        map[PointKey]map[PointKey]float32
	sortedDirections     *SortedDirections
	sortedDirectionsFrom map[PointKey]SortedDirections
}

// NewDirections creates an empty tracker configured by cfg. Its logger carries
// the "popular.Directions" module context. The sorted views stay empty until
// rebuild runs, which Run triggers immediately.
func NewDirections(cfg Config, logger *zap.Logger) *Directions {
	d := &Directions{cfg: cfg}
	d.logger = logging.WithModuleContext(logger, "popular.Directions")
	// Mutexes are usable as zero values; sortedDirections stays nil on purpose.
	d.rawDirections = make(map[PointKey]map[PointKey]float32)
	d.sortedDirectionsFrom = make(map[PointKey]SortedDirections)
	return d
}

// Run performs one rebuild synchronously and then starts a background
// goroutine that runs until ctx is cancelled. The goroutine:
//   - rebuilds the sorted views every cfg.RebuildPeriod;
//   - deflates raw weights every cfg.DeflationPeriod.
//
// Both periods must be positive, because time.NewTicker panics otherwise.
func (d *Directions) Run(ctx context.Context) {
	d.rebuild()
	go func() {
		// Stop the tickers on exit so they are released once ctx is done.
		rebuildTicker := time.NewTicker(d.cfg.RebuildPeriod)
		defer rebuildTicker.Stop()
		deflateTicker := time.NewTicker(d.cfg.DeflationPeriod)
		defer deflateTicker.Stop()
		for {
			select {
			case <-rebuildTicker.C:
				d.rebuild()
			case <-deflateTicker.C:
				d.deflate()
			case <-ctx.Done():
				return
			}
		}
	}()
}

// rebuild recomputes both sorted views from the raw weights:
//   - the global top, with at most cfg.TopSize directions;
//   - one top per origin, with at most cfg.TopSizeFrom directions each.
//
// Every view is ordered by descending weight. Both views are built off to the
// side and swapped in under a single sortedMutex critical section, so readers
// never see a mix of old and new data. Origins that no longer have raw
// directions (for example, ones emptied by deflate) drop out of the per-origin
// view instead of lingering with stale data.
func (d *Directions) rebuild() {
	sortedDirections, sortedDirectionsFrom := d.collectSorted()

	d.sortedMutex.Lock()
	d.sortedDirections = &sortedDirections
	d.sortedDirectionsFrom = sortedDirectionsFrom
	d.sortedMutex.Unlock()

	d.logger.Infof("Directions.rebuild: done with %d directions and %d fromDirections",
		len(sortedDirections), len(sortedDirectionsFrom))
}

// collectSorted builds the global and per-origin sorted views from
// d.rawDirections while holding rawMutex for reading.
func (d *Directions) collectSorted() (SortedDirections, map[PointKey]SortedDirections) {
	d.rawMutex.RLock()
	defer d.rawMutex.RUnlock()

	sortedDirections := make(SortedDirections, 0)
	sortedDirectionsFrom := make(map[PointKey]SortedDirections, len(d.rawDirections))
	for from, toWeight := range d.rawDirections {
		fromDirections := make(SortedDirections, 0)
		for to, weight := range toWeight {
			// Both views share the same message; the getters hand out clones.
			direction := &ipb.TPopularDirection{
				From:   from.ToProto(),
				To:     to.ToProto(),
				Weight: weight,
			}
			sortedDirections.Push(direction, d.cfg.TopSize)
			fromDirections.Push(direction, d.cfg.TopSizeFrom)
		}
		sortedDirectionsFrom[from] = fromDirections
	}
	return sortedDirections, sortedDirectionsFrom
}

// deflate multiplies every raw weight by cfg.DeflationMultiplier. A pair whose
// deflated weight is below cfg.DropThreshold is removed, and an origin left
// with no destinations is removed too. rawMutex is held exclusively for the
// whole pass.
func (d *Directions) deflate() {
	d.rawMutex.Lock()
	defer d.rawMutex.Unlock()

	multiplier, threshold := d.cfg.DeflationMultiplier, d.cfg.DropThreshold
	for origin, destinations := range d.rawDirections {
		for destination, current := range destinations {
			// Deleting or overwriting the current key while ranging is allowed in Go.
			deflated := current * multiplier
			if deflated < threshold {
				delete(destinations, destination)
				continue
			}
			destinations[destination] = deflated
		}
		if len(destinations) == 0 {
			delete(d.rawDirections, origin)
		}
	}
	d.logger.Info("Directions.deflate: done")
}

// Register adds 1 to the raw weight of the from→to pair. It creates the
// origin's destination map on first use. Both keys must be non-nil.
func (d *Directions) Register(from *pb.TPointKey, to *pb.TPointKey) {
	fromKey, toKey := *NewPointKey(from), *NewPointKey(to)

	d.rawMutex.Lock()
	defer d.rawMutex.Unlock()

	destinations, found := d.rawDirections[fromKey]
	if !found {
		destinations = make(map[PointKey]float32)
		d.rawDirections[fromKey] = destinations
	}
	destinations[toKey]++
}

// GetDirections returns deep copies of the global top directions from the
// latest rebuild, ordered by descending weight. It returns nil if no rebuild
// has happened yet. Callers may modify the returned messages freely.
func (d *Directions) GetDirections() []*ipb.TPopularDirection {
	// The nil check must happen under the lock: rebuild writes
	// d.sortedDirections concurrently.
	d.sortedMutex.RLock()
	defer d.sortedMutex.RUnlock()
	if d.sortedDirections == nil {
		return nil
	}
	sorted := *d.sortedDirections
	popularDirections := make([]*ipb.TPopularDirection, len(sorted))
	for i, direction := range sorted {
		popularDirections[i] = proto.Clone(direction).(*ipb.TPopularDirection)
	}
	return popularDirections
}

// GetDirectionsFrom returns deep copies of the per-origin top directions that
// rebuild stored for from, ordered by descending weight. It returns nil when
// there is no entry for from. The from key must be non-nil.
func (d *Directions) GetDirectionsFrom(from *pb.TPointKey) []*ipb.TPopularDirection {
	key := *NewPointKey(from)

	d.sortedMutex.RLock()
	defer d.sortedMutex.RUnlock()

	top, found := d.sortedDirectionsFrom[key]
	if !found {
		return nil
	}
	result := make([]*ipb.TPopularDirection, 0, len(top))
	for _, direction := range top {
		result = append(result, proto.Clone(direction).(*ipb.TPopularDirection))
	}
	return result
}

// Iter streams every raw (from, to, weight) triple as an *ipb.TPopularDirection
// on the returned unbuffered channel. The channel is closed once all items
// have been sent or ctx is cancelled.
//
// The triples are copied under rawMutex, and the lock is released before
// streaming. As a result, a slow consumer cannot block Register, Add, deflate
// or rebuild for the duration of the stream. A consumer that stops reading must
// still cancel ctx, or the sending goroutine leaks.
func (d *Directions) Iter(ctx context.Context) <-chan proto.Message {
	ch := make(chan proto.Message)
	go func() {
		defer close(ch)
		for _, item := range d.snapshotRaw() {
			select {
			case ch <- item:
			case <-ctx.Done():
				return
			}
		}
	}()
	return ch
}

// snapshotRaw copies all raw directions into a slice while holding rawMutex
// for reading. Items with the same origin share one From message.
func (d *Directions) snapshotRaw() []*ipb.TPopularDirection {
	d.rawMutex.RLock()
	defer d.rawMutex.RUnlock()

	total := 0
	for _, toWeight := range d.rawDirections {
		total += len(toWeight)
	}
	items := make([]*ipb.TPopularDirection, 0, total)
	for from, toWeight := range d.rawDirections {
		pbFrom := from.ToProto()
		for to, weight := range toWeight {
			items = append(items, &ipb.TPopularDirection{
				From:   pbFrom,
				To:     to.ToProto(),
				Weight: weight,
			})
		}
	}
	return items
}

// Add stores one raw direction. message must be an *ipb.TPopularDirection (the
// type Iter emits); any other type panics on the type assertion.
//
// Unlike Register, Add sets the pair's weight to the message's Weight rather
// than incrementing it, overwriting any existing value.
func (d *Directions) Add(message proto.Message) {
	d.rawMutex.Lock()
	defer d.rawMutex.Unlock()

	direction := message.(*ipb.TPopularDirection)
	fromKey, toKey := *NewPointKey(direction.From), *NewPointKey(direction.To)
	destinations, found := d.rawDirections[fromKey]
	if !found {
		destinations = make(map[PointKey]float32)
		d.rawDirections[fromKey] = destinations
	}
	destinations[toKey] = direction.Weight
}
