package flightp2p

import (
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/logthrottler"
	"math"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/flight_p2p/format"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/dtutil"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/logger"
	"a.yandex-team.ru/travel/proto/shared_flights/snapshots"
)

// INF is a sentinel "infinitely long" duration of 30 days. In this file it seeds the
// running minimum in getP2PCountAndDuration and marks "no duration found".
const INF = 30 * 24 * time.Hour

// GetFlightP2PSegmentInfo builds every (departure station, arrival station) pair from
// fromStations × toStations. For each pair it reports how many flights serve it and the
// shortest scheduled duration among them, rounded to whole minutes.
//
// Pairs whose two station ids coincide are skipped. A pair that appears more than once
// in the inputs is computed only once. When debug is true, each flight's per-flight
// minimum-duration breakdown is attached to its pair.
//
// The returned error is currently always nil. The order of items in the response
// follows Go map iteration and is therefore not deterministic.
func (service *FlightP2PServiceImpl) GetFlightP2PSegmentInfo(
	fromStations, toStations []*snapshots.TStationWithCodes,
	nationalVersion string,
	showBanned bool,
	debug bool,
) (format.SegmentInfoResponse, error) {
	infoByPair := make(map[format.PointPair]format.PointPairInfo)

	for _, from := range fromStations {
		for _, to := range toStations {
			fromID, toID := from.Station.Id, to.Station.Id
			if fromID == toID {
				continue
			}

			pair := format.PointPair{StationFrom: fromID, StationTo: toID}
			if _, seen := infoByPair[pair]; seen {
				continue
			}

			perFlight, shortest := service.getP2PCountAndDuration(from, to, nationalVersion, showBanned)

			// Stays nil unless debug output was requested.
			var debugInfo []format.FlightKeyDebugInfo
			if debug {
				for flightKey, d := range perFlight {
					breakdown := format.DurationString{
						Duration:       d.Duration.String(),
						DepartureTime:  d.DepartureTime.String(),
						DepartureShift: d.DepartureShift.String(),
						ArrivalTime:    d.ArrivalTime.String(),
						ArrivalShift:   d.ArrivalShift.String(),
						DepartureDate:  string(d.DepartureDate.StringDateDashed()),
						ArrivalDate:    string(d.ArrivalDate.StringDateDashed()),
					}
					debugInfo = append(debugInfo, format.FlightKeyDebugInfo{
						FlightKey:   flightKey,
						MinDuration: breakdown,
					})
				}
			}

			infoByPair[pair] = format.PointPairInfo{
				RoutesCount: len(perFlight),
				Debug:       debugInfo,
				MinDuration: uint64(math.Round(shortest.Minutes())),
			}
		}
	}

	result := make(format.SegmentInfoResponse, 0, len(infoByPair))
	for pair, info := range infoByPair {
		result = append(result, format.SegmentInfoResponseItem{PointPair: pair, PointPairInfo: info})
	}
	return result, nil
}

// getP2PCountAndDuration collects the flights (keyed by thread variant) that carry a
// passenger from fromStation to toStation. It also returns the shortest positive
// scheduled duration among them.
//
// For every thread variant in the flight cache, each leg combination produced by
// GenerateCombinations is examined. getRouteLegNumbers picks the departing and arriving
// legs, and getFlightSegmentDuration measures the segment between them. A flight enters
// the returned map only if at least one of its combinations yields a positive duration.
// When several combinations qualify, their durations are merged with format.Duration.Min.
//
// The second result is 0 when no positive duration was found. If the flight cache cannot
// be built, the error is logged and (nil, 0) is returned. The caller then reports the
// pair as having no flights.
func (service *FlightP2PServiceImpl) getP2PCountAndDuration(
	fromStation, toStation *snapshots.TStationWithCodes,
	nationalVersion string,
	showBanned bool,
) (map[format.FlightKey]format.Duration, time.Duration) {
	flightCache, err := service.BuildFlightsCache(
		[]*snapshots.TStationWithCodes{fromStation},
		[]*snapshots.TStationWithCodes{toStation},
		nationalVersion,
		showBanned,
		FlightCacheOptions{IgnoreCodeshare: true},
	)
	if err != nil {
		// Best effort: keep answering with "no flights", but do not hide the failure.
		logger.Logger().Error(
			"Cannot build flights cache for station pair",
			log.Error(err),
			log.Int32("fromStation", fromStation.Station.Id),
			log.Int32("toStation", toStation.Station.Id),
		)
		return nil, 0
	}

	flights := make(map[format.FlightKey]format.Duration, len(flightCache))
	minDuration := INF

	departureTimezone := service.Timezones().GetTimeZone(fromStation.Station)
	arrivalTimezone := service.Timezones().GetTimeZone(toStation.Station)
	offsetCache := make(timezoneOffsetCache)

	for threadVariantKey, fv := range flightCache {
		flightValue := ThreadVariants(fv)
		if len(flightValue) == 0 {
			logger.Logger().Warn(
				"Empty thread variants",
				log.Reflect("threadVariantKey", threadVariantKey),
				log.Int32("fromStation", fromStation.Station.Id),
				log.Int32("toStation", toStation.Station.Id),
			)
			continue
		}

		foundStationPair := false
		countCombinations := 0
		// The generator is drained completely (no early break), so nothing is left pending.
		for segmentIndices := range GenerateCombinations(flightValue) {
			legCombination := flightValue.GetCombination(segmentIndices)
			countCombinations++
			if len(segmentIndices) == 0 || len(legCombination) == 0 {
				logger.Logger().Warn(
					"Zero-length sequence generated",
					log.Reflect("segmentIndices", segmentIndices),
					log.Reflect("legCombination", legCombination),
					log.Reflect("flightValue", flightValue),
				)
				continue
			}

			departureSequenceNumber, arrivalSequenceNumber := getRouteLegNumbers(legCombination, fromStation, toStation)
			if departureSequenceNumber < 0 || arrivalSequenceNumber < 0 {
				continue
			}
			foundStationPair = true

			duration := service.getFlightSegmentDuration(
				legCombination, departureSequenceNumber, arrivalSequenceNumber,
				departureTimezone, arrivalTimezone,
				offsetCache,
			)
			// Non-positive durations are ignored. This includes the zero value that
			// getFlightSegmentDuration returns when it cannot compute a duration.
			if duration.Duration <= 0 {
				continue
			}
			if best, found := flights[threadVariantKey]; found {
				flights[threadVariantKey] = best.Min(duration)
			} else {
				flights[threadVariantKey] = duration
			}
			if duration.Duration < minDuration {
				minDuration = duration.Duration
			}
		}

		if !foundStationPair {
			logthrottler.LogWithThrottling(
				notFoundStationPairLogLimiterKey{
					flight: threadVariantKey,
					from:   fromStation.Station.Id,
					to:     toStation.Station.Id,
				},
				10*time.Second,
				logger.Logger().AddCallerSkip(1).Error,
				"Broken thread: cannot find departure and arrival station pair",
				log.Reflect("fromStation", fromStation),
				log.Reflect("toStation", toStation),
				log.Reflect("threadVariantKey", threadVariantKey),
				log.Int("combinations", countCombinations),
			)
		}
	}

	// INF still in place means no positive duration was seen.
	if minDuration == INF {
		minDuration = 0
	}
	return flights, minDuration
}

// notFoundStationPairLogLimiterKey is the throttling key passed to
// logthrottler.LogWithThrottling when a thread variant has no leg combination linking
// the requested stations. It combines the flight with the departure and arrival
// station ids.
type notFoundStationPairLogLimiterKey struct {
	flight   format.FlightKey
	from, to int32 // station ids (Station.Id)
}

// getFlightSegmentDuration computes the scheduled travel time from the departure of leg
// combination[segmentFrom] to the arrival of leg combination[segmentTo].
//
// Both scheduled local times are corrected by timezone offsets looked up in offsetCache.
// The lookup uses the first date returned by each leg's Mask().GetDates(). The
// ArrivalDayShift values of every leg in [segmentFrom, segmentTo] are summed and added
// to the arrival time, in whole days.
//
// The zero format.Duration is returned when:
//   - either timezone is nil,
//   - either scheduled time is invalid, or
//   - either leg's mask yields no dates.
//
// Callers ignore non-positive durations.
//
// NOTE(review): only arrival day shifts are summed. If an intermediate leg departs on a
// later day than the previous leg arrives, that gap is not counted. Confirm that
// GetCombination output cannot contain such layovers.
func (service *FlightP2PServiceImpl) getFlightSegmentDuration(
	combination []*FlightLegValue,
	segmentFrom int,
	segmentTo int,
	departureTimezone *time.Location,
	arrivalTimezone *time.Location,
	offsetCache timezoneOffsetCache,
) format.Duration {
	if departureTimezone == nil || arrivalTimezone == nil {
		return format.Duration{}
	}

	departureLeg := combination[segmentFrom]
	arrivalLeg := combination[segmentTo]

	departureTime := dtutil.IntTime(departureLeg.FlightBase.DepartureTimeScheduled)
	arrivalTime := dtutil.IntTime(arrivalLeg.FlightBase.ArrivalTimeScheduled)
	if !departureTime.IsValid() || !arrivalTime.IsValid() {
		return format.Duration{}
	}

	// Guard against legs whose mask has no dates: indexing [0] would panic.
	departureDates := departureLeg.Mask().GetDates()
	arrivalDates := arrivalLeg.Mask().GetDates()
	if len(departureDates) == 0 || len(arrivalDates) == 0 {
		return format.Duration{}
	}
	departureDate := departureDates[0]
	arrivalDate := arrivalDates[0]

	departureOffset := offsetCache.Get(departureDate, departureTimezone)
	arrivalOffset := offsetCache.Get(arrivalDate, arrivalTimezone)

	var arrivalDaysShift time.Duration
	for i := segmentFrom; i <= segmentTo; i++ {
		arrivalDaysShift += time.Duration(combination[i].ArrivalDayShift())
	}

	// Shift both local times by their offsets onto a common reference, then subtract.
	duration := (arrivalTime.Duration() + arrivalDaysShift*24*time.Hour - arrivalOffset) - (departureTime.Duration() - departureOffset)

	return format.Duration{
		Duration:       duration,
		DepartureTime:  departureTime.Duration(),
		DepartureShift: departureOffset,
		ArrivalTime:    arrivalTime.Duration(),
		ArrivalShift:   arrivalOffset,
		DepartureDate:  departureDate,
		ArrivalDate:    arrivalDate,
	}
}

// getRouteLegNumbers finds the travelled part of a leg sequence. It returns:
//   - the index of the first leg whose departure station is fromStation, and
//   - the index of the first leg at or after that one whose arrival station is toStation.
//
// (-1, -1) is returned when no leg departs from fromStation. When a departure is found
// but no arrival follows, the departure index is returned together with -1.
func getRouteLegNumbers(thread []*FlightLegValue, fromStation *snapshots.TStationWithCodes, toStation *snapshots.TStationWithCodes) (int, int) {
	departure := -1
	for i, leg := range thread {
		if leg.FlightBase.DepartureStation == int64(fromStation.Station.Id) {
			departure = i
			break
		}
	}
	if departure == -1 {
		return -1, -1
	}

	// The departure leg itself may already end at toStation (direct segment).
	for i := departure; i < len(thread); i++ {
		if thread[i].FlightBase.ArrivalStation == int64(toStation.Station.Id) {
			return departure, i
		}
	}
	return departure, -1
}
