package popularityscores

import (
	"sort"
	"sync"

	"a.yandex-team.ru/travel/avia/shared_flights/api/pkg/structs"
)

// CarriersPopularityScores stores carrier popularity scores per national
// version and caches, per marketing flight number, the marketing carriers that
// have been seen with that flight number.
//
// Construct it with NewCarriersPopularityScores; the maps must be allocated
// before SetScore or the flight-number cache can write to them.
type CarriersPopularityScores struct {
	// scoresMap: carrier → national version → score.
	// NOTE(review): not guarded by any mutex; concurrent SetScore and
	// GetScore/GetCarriers/GetPopularityScores calls would race — confirm
	// callers only mutate scores before serving reads.
	scoresMap                 map[int32]map[string]int32
	// flightNumberToCarriersMap: marketing flight number → carriers seen with it.
	// Guarded by flightNumberMapMutex.
	flightNumberToCarriersMap map[string][]int32
	// flightNumberMapMutex guards flightNumberToCarriersMap only.
	flightNumberMapMutex      sync.RWMutex
}

// FlightPatternsProvider supplies the flight patterns that
// UpdateFlightNumbersCache scans to learn which marketing carriers use which
// marketing flight numbers.
type FlightPatternsProvider interface {
	// GetFlightPatterns returns the current flight patterns.
	// UpdateFlightNumbersCache uses only the map values (the int32 keys are
	// ignored) and dereferences them without a nil check, so values are
	// presumably never nil — confirm with implementations.
	GetFlightPatterns() map[int32]*structs.FlightPattern
}

// National-version keys used in the per-carrier score maps.
const (
	// DefaultNationalVersion is the fallback key GetScore consults when a
	// carrier has no score for the requested national version.
	DefaultNationalVersion = ""

	// PreferredNationalVersion is the national version whose score
	// UpdateDefaultScore copies into the default slot when none is set.
	PreferredNationalVersion = "ru"
)

// NewCarriersPopularityScores returns an empty instance with both internal
// maps allocated. The zero value is not usable: SetScore and the flight-number
// cache assign into these maps, which panics on a nil map.
func NewCarriersPopularityScores() *CarriersPopularityScores {
	scores := new(CarriersPopularityScores)
	// The RWMutex is ready to use as its zero value; only the maps need allocating.
	scores.scoresMap = make(map[int32]map[string]int32)
	scores.flightNumberToCarriersMap = make(map[string][]int32)
	return scores
}

// SetScore records score for carrier under nationalVersion, creating the
// carrier's per-version map on first use. Any previous score for the same
// (carrier, nationalVersion) pair is overwritten.
func (s *CarriersPopularityScores) SetScore(carrier int32, nationalVersion string, score int32) {
	byVersion, known := s.scoresMap[carrier]
	if !known {
		byVersion = make(map[string]int32)
		// Maps are references, so registering it now and filling it below is equivalent.
		s.scoresMap[carrier] = byVersion
	}
	byVersion[nationalVersion] = score
}

// UpdateDefaultScore fills in carrier's DefaultNationalVersion score from its
// PreferredNationalVersion score. It does nothing if the carrier is unknown,
// already has a default score, or has no preferred-version score.
func (s *CarriersPopularityScores) UpdateDefaultScore(carrier int32) {
	byVersion, known := s.scoresMap[carrier]
	if !known {
		return
	}
	if _, hasDefault := byVersion[DefaultNationalVersion]; hasDefault {
		return
	}
	// byVersion is the same map stored in scoresMap, so writing to it is enough.
	if preferred, hasPreferred := byVersion[PreferredNationalVersion]; hasPreferred {
		byVersion[DefaultNationalVersion] = preferred
	}
}

// GetScore returns carrier's score for nationalVersion. If there is none, it
// falls back to the carrier's DefaultNationalVersion score. It returns 0 when
// neither exists or the carrier is unknown.
func (s *CarriersPopularityScores) GetScore(carrier int32, nationalVersion string) int32 {
	// For an unknown carrier byVersion is nil; lookups in a nil map yield the
	// zero value, which gives the required 0 result.
	byVersion := s.scoresMap[carrier]
	if score, found := byVersion[nationalVersion]; found {
		return score
	}
	return byVersion[DefaultNationalVersion]
}

// GetPopularityScores returns a newly allocated map holding every carrier
// present in the score table. Each carrier maps to its effective score for
// nationalVersion, as resolved by GetScore.
func (s *CarriersPopularityScores) GetPopularityScores(nationalVersion string) map[int32]int32 {
	effective := make(map[int32]int32, len(s.scoresMap))
	for carrierID := range s.scoresMap {
		effective[carrierID] = s.GetScore(carrierID, nationalVersion)
	}
	return effective
}

// UpdateFlightNumbersCache merges the marketing carriers found in the
// provider's flight patterns into the flight-number → carriers cache.
//
// Carriers already cached for a flight number keep their positions. Carriers
// seen for the first time are appended in ascending order. The order matters
// because GetCarriers sorts with a stable sort, so carriers with equal
// popularity scores keep their cached order. Appending them in map-iteration
// order would make that tie-break random from run to run.
//
// Cached carriers are never removed, even if they no longer appear in the
// patterns.
//
// NOTE(review): each flight number is read (getCarriersInternal) and written
// (updateCarriersMapInternal) under separate lock acquisitions. Two concurrent
// calls could therefore overwrite each other's additions for the same flight
// number; confirm this method is only called from a single goroutine.
//
// The returned error is always nil.
func (s *CarriersPopularityScores) UpdateFlightNumbersCache(segmentsProvider FlightPatternsProvider) error {
	// Collect the set of marketing carriers seen for each marketing flight number.
	carriersByFlight := make(map[string]map[int32]struct{})
	for _, segment := range segmentsProvider.GetFlightPatterns() {
		carriers, ok := carriersByFlight[segment.MarketingFlightNumber]
		if !ok {
			carriers = make(map[int32]struct{})
			carriersByFlight[segment.MarketingFlightNumber] = carriers
		}
		carriers[segment.MarketingCarrier] = struct{}{}
	}

	for flightNumber, seen := range carriersByFlight {
		// getCarriersInternal returns a private copy, so appending to it is safe.
		merged := s.getCarriersInternal(flightNumber)
		// Drop carriers that are already cached so they are not duplicated.
		for _, carrier := range merged {
			delete(seen, carrier)
		}
		if len(seen) == 0 {
			continue // nothing new; the cached list is already up to date
		}

		added := make([]int32, 0, len(seen))
		for carrier := range seen {
			added = append(added, carrier)
		}
		// Map iteration order is random; sort so the appended order is deterministic.
		sort.Slice(added, func(i, j int) bool { return added[i] < added[j] })

		s.updateCarriersMapInternal(flightNumber, append(merged, added...))
	}

	return nil
}

// updateCarriersMapInternal replaces the cached carrier list for flightNumber
// while holding the write lock. The slice is stored as given, not copied, so
// the caller must not modify it afterwards.
func (s *CarriersPopularityScores) updateCarriersMapInternal(flightNumber string, carriers []int32) {
	mu := &s.flightNumberMapMutex
	mu.Lock()
	defer mu.Unlock()

	s.flightNumberToCarriersMap[flightNumber] = carriers
}

// GetCarriers returns a copy of the cached carriers for flightNumber, ordered
// by descending popularity score for nationalVersion (see GetScore). The sort
// is stable, so carriers with equal scores keep their cached order. An unknown
// flight number yields an empty, non-nil slice.
func (s *CarriersPopularityScores) GetCarriers(flightNumber string, nationalVersion string) []int32 {
	ordered := s.getCarriersInternal(flightNumber)
	if len(ordered) < 2 {
		// Zero or one carrier: already in order.
		return ordered
	}

	// Resolve each carrier's score once instead of on every comparison.
	scoreOf := make(map[int32]int32, len(ordered))
	for _, carrier := range ordered {
		scoreOf[carrier] = s.GetScore(carrier, nationalVersion)
	}
	sort.SliceStable(ordered, func(a, b int) bool {
		return scoreOf[ordered[a]] > scoreOf[ordered[b]]
	})
	return ordered
}

// getCarriersInternal returns a private copy of the cached carriers for
// flightNumber, taken under the read lock, so callers may modify it freely.
// An unknown flight number yields an empty, non-nil slice. A cached empty
// list yields nil.
func (s *CarriersPopularityScores) getCarriersInternal(flightNumber string) []int32 {
	s.flightNumberMapMutex.RLock()
	defer s.flightNumberMapMutex.RUnlock()

	cached, found := s.flightNumberToCarriersMap[flightNumber]
	if !found {
		return []int32{}
	}

	var snapshot []int32
	if len(cached) > 0 {
		snapshot = make([]int32, len(cached))
		copy(snapshot, cached)
	}
	return snapshot
}
