package flightp2p

import (
	"sort"
	"time"

	"a.yandex-team.ru/library/go/core/xerrors"
	dto "a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/DTO"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/flight_p2p/format"
	"a.yandex-team.ru/travel/avia/shared_flights/api/pkg/structs"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/dtutil"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/math"
	"a.yandex-team.ru/travel/proto/shared_flights/snapshots"
)

// OperatesDaily is the operating-days value passed to DateMask.AddRange / RemoveRange when a date
// range applies to every day of the week (presumably one digit per weekday, 1..7 — confirm against
// dtutil). It is used to add or remove blacklisted date ranges regardless of weekday.
const OperatesDaily = 1234567

// CodeshareFlightKey identifies a marketing (codeshare) flight of a leg: the marketing carrier id
// and the marketing flight number. It keys FlightLegValue.codeshares.
type CodeshareFlightKey struct {
	marketingCarrier int32
	marketingFlight  string
}

// FlightValue maps every leg of the flight to its flightBase and schedule.
type FlightValue map[int]FlightLegValue

// FlightLegValue is one variant of a flight leg collected by BuildFlightsCache: the flight base
// together with the dates it operates under its own flight number and under each codeshare.
type FlightLegValue struct {
	FlightBase        structs.FlightBase
	// Dates of the operating (non-codeshare) flight pattern(s) of this leg.
	operating         FlightDateMask
	// Dates of every codeshare flight pattern of this leg, per marketing flight.
	codeshares        map[CodeshareFlightKey]FlightDateMask
	// Requested departure / arrival station ids of the query that first added this leg; CreateResponse
	// compares them with FlightBase stations to find where the requested segment starts and ends.
	DepartureStation  int32
	ArrivalStation    int32
	// Day shifts copied from the flight pattern that created the leg.
	arrivalDayShift   int
	departureDayShift int
	// Title of the operating flight pattern ("" if only codeshare patterns were seen).
	flightTitle       string
}

// Mask exposes the operating (non-codeshare) dates of the leg; it is nil until an operating
// pattern has been merged into the leg. Part of the ThreadSegment interface.
func (f *FlightLegValue) Mask() dtutil.DateMask { return f.operating.operates }

// ArrivalDayShift reports the arrival day shift recorded from the leg's flight pattern.
func (f *FlightLegValue) ArrivalDayShift() int { return f.arrivalDayShift }

// DepartureDayShift reports the departure day shift recorded from the leg's flight pattern.
func (f *FlightLegValue) DepartureDayShift() int { return f.departureDayShift }

// FlightDateMask holds the dates of one flight number (operating or marketing) on a leg,
// split into the dates it operates and the subset of those that are banned (blacklisted).
type FlightDateMask struct {
	carrier      int32           // marketing carrier
	flightNumber string          // marketing flight number
	flightTitle  string          // flight title (uses marketing or filing carrier code)
	operates     dtutil.DateMask // when it flies
	banned       dtutil.DateMask // when it's banned
}

// FlightBaseHashFields lists the FlightBase fields that distinguish leg variants. Flight bases
// with equal values of all these fields are treated as the same leg variant and share one hash
// (the ID of the first such flight base, see getOrCreateFlightBaseHash).
type FlightBaseHashFields struct {
	OperatingCarrier      int32
	OperatingFlightNumber string
	LegNumber             int32

	DepartureStation       int64
	DepartureTimeScheduled int32
	DepartureTerminal      string
	ArrivalStation         int64
	ArrivalTimeScheduled   int32
	ArrivalTerminal        string

	AircraftTypeID int64
}

// GetFlightsP2PSchedule builds the schedule of flights departing from any of the departFrom
// stations and arriving to any of the arriveTo stations.
//
// Dates blacklisted for nationalVersion are dropped, or reported separately as banned when
// showBanned is set. startScheduleDate is forwarded to CreateResponse, which generates the date
// masks from it. UTC offsets are formatted with the timezones of the first departure and the
// first arrival station.
//
// Both station lists must be non-empty. On error the returned response contains only the
// requested station ids and no flights.
func (service *FlightP2PServiceImpl) GetFlightsP2PSchedule(
	departFrom []*snapshots.TStationWithCodes,
	arriveTo []*snapshots.TStationWithCodes,
	nationalVersion string,
	showBanned bool,
	startScheduleDate dtutil.IntDate,
) (format.ScheduleResponse, error) {
	departureStations := make([]int32, 0, len(departFrom))
	for _, station := range departFrom {
		departureStations = append(departureStations, station.Station.Id)
	}
	arrivalStations := make([]int32, 0, len(arriveTo))
	for _, station := range arriveTo {
		arrivalStations = append(arrivalStations, station.Station.Id)
	}
	response := format.ScheduleResponse{
		DepartureStations: departureStations,
		ArrivalStations:   arrivalStations,
		Flights:           make([]format.ScheduleFlight, 0),
	}

	// Timezones are taken from the first station of each list: reject empty lists explicitly
	// instead of panicking on the index below.
	if len(departFrom) == 0 {
		return response, xerrors.Errorf("no departure stations given")
	}
	if len(arriveTo) == 0 {
		return response, xerrors.Errorf("no arrival stations given")
	}

	departFromTz := service.GetTimeZoneByStationID(int64(departFrom[0].Station.Id))
	if departFromTz == nil {
		return response, xerrors.Errorf("cannot load timezone for station  %+v", departFrom[0].Station.Id)
	}

	arriveToTz := service.GetTimeZoneByStationID(int64(arriveTo[0].Station.Id))
	if arriveToTz == nil {
		return response, xerrors.Errorf("cannot load timezone for station  %+v", arriveTo[0].Station.Id)
	}

	flightsCache, err := service.BuildFlightsCache(departFrom, arriveTo, nationalVersion, showBanned)
	if err != nil {
		return response, err
	}
	service.CreateResponse(&response, flightsCache, departFromTz, arriveToTz, startScheduleDate)
	return response, nil
}

// FlightCacheOptions tunes BuildFlightsCache.
type FlightCacheOptions struct {
	// IgnoreCodeshare skips codeshare flight patterns entirely, so only operating flights are cached.
	IgnoreCodeshare bool
}

// BuildFlightsCache collects the legs of every flight that GetFlightsP2P reports for any
// (departFrom, arriveTo) station pair and groups them as
// operating flight key -> leg number -> flight base hash -> *FlightLegValue.
//
// Every pattern returned by GetFlightsByKey for a matching flight (grouped per leg) is processed:
//   - codeshare patterns are skipped when IgnoreCodeshare is set;
//   - patterns whose flight base lacks a valid scheduled departure or arrival time are skipped;
//   - dates blacklisted for nationalVersion are removed from the operating dates, or, when
//     showBanned is set, kept and additionally recorded in a separate banned mask;
//   - patterns left without operating dates are skipped;
//   - the remaining dates are merged into the leg's operating mask or into the mask of the
//     corresponding codeshare (marketing) flight.
//
// Flight bases with identical FlightBaseHashFields share one hash (see getOrCreateFlightBaseHash).
// At most one options value is expected; if several are passed, the last one wins.
// On error a nil map is returned together with the error.
func (service *FlightP2PServiceImpl) BuildFlightsCache(
	departFrom []*snapshots.TStationWithCodes,
	arriveTo []*snapshots.TStationWithCodes,
	nationalVersion string,
	showBanned bool,
	opts ...FlightCacheOptions,
) (map[format.FlightKey]map[tFlightLeg]map[tFlightBaseHash]*FlightLegValue, error) {
	options := FlightCacheOptions{}
	if len(opts) > 0 {
		options = opts[len(opts)-1]
	}

	// Resolve the storages once so the start date and all flight data come from the same snapshot.
	flightStorage := service.Storage.FlightStorage()
	blacklistStorage := service.Storage.BlacklistRuleStorage()

	startDate := flightStorage.GetStartDate()
	startDateIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(startDate))
	legs := make(map[format.FlightKey]map[tFlightLeg]map[tFlightBaseHash]*FlightLegValue)
	flightBasesCache := make(map[FlightBaseHashFields]int32)

	for _, departureStation := range departFrom {
		for _, arrivalStation := range arriveTo {
			p2pKeys, ok := flightStorage.GetFlightsP2P(
				int64(departureStation.Station.Id),
				int64(arrivalStation.Station.Id),
			)
			if !ok {
				continue
			}

			for _, p2pKey := range p2pKeys {
				flightPatterns, ok := flightStorage.GetFlightsByKey(p2pKey)
				if !ok {
					continue
				}

				for _, legFlights := range flightPatterns {
					for _, flightPattern := range legFlights {
						if flightPattern.IsCodeshare && options.IgnoreCodeshare {
							continue
						}
						flightBase, err := flightStorage.GetFlightBase(flightPattern.FlightBaseID, flightPattern.IsDop)
						if err != nil {
							return nil, err
						}
						// skip half-flights in the point-to-point response - they could not be shown anyway
						if !dtutil.IntTime(flightBase.DepartureTimeScheduled).IsValid() ||
							!dtutil.IntTime(flightBase.ArrivalTimeScheduled).IsValid() {
							continue
						}

						operatingDateMask := dtutil.NewDateMask(startDateIndex, dtutil.MaxDaysInSchedule)
						operatingDateMask.AddRange(
							dtutil.StringDate(flightPattern.OperatingFromDate),
							dtutil.StringDate(flightPattern.OperatingUntilDate),
							flightPattern.OperatingOnDays,
						)

						bannedDateMask := dtutil.NewDateMask(startDateIndex, dtutil.MaxDaysInSchedule)
						bannedFrom, bannedUntil, isBanned := blacklistStorage.GetBannedDates(flightBase, flightPattern, nationalVersion)
						if isBanned {
							if showBanned {
								// keep the dates, but remember which of them are banned
								bannedDateMask.AddRange(dtutil.StringDate(bannedFrom), dtutil.StringDate(bannedUntil), OperatesDaily)
								bannedDateMask.IntersectWithMask(operatingDateMask)
							} else {
								operatingDateMask.RemoveRange(dtutil.StringDate(bannedFrom), dtutil.StringDate(bannedUntil), OperatesDaily)
							}
						}

						if operatingDateMask.IsEmpty() {
							continue
						}

						flightKey := format.FlightKey{
							OperatingCarrier: flightBase.OperatingCarrier,
							OperatingFlight:  flightBase.OperatingFlightNumber,
						}
						legsByNumber, ok := legs[flightKey]
						if !ok {
							legsByNumber = make(map[tFlightLeg]map[tFlightBaseHash]*FlightLegValue)
							legs[flightKey] = legsByNumber
						}
						legsByBase, ok := legsByNumber[flightPattern.LegNumber]
						if !ok {
							legsByBase = make(map[tFlightBaseHash]*FlightLegValue)
							legsByNumber[flightPattern.LegNumber] = legsByBase
						}

						flightBaseHash := getOrCreateFlightBaseHash(flightBasesCache, &flightBase)
						currentLeg, ok := legsByBase[flightBaseHash]
						if !ok {
							currentLeg = &FlightLegValue{
								FlightBase:        flightBase,
								codeshares:        make(map[CodeshareFlightKey]FlightDateMask),
								DepartureStation:  departureStation.Station.Id,
								ArrivalStation:    arrivalStation.Station.Id,
								arrivalDayShift:   int(flightPattern.ArrivalDayShift),
								departureDayShift: int(flightPattern.DepartureDayShift),
							}
							legsByBase[flightBaseHash] = currentLeg
						}
						currentLeg.addPatternMask(flightPattern.IsCodeshare, FlightDateMask{
							carrier:      flightPattern.MarketingCarrier,
							flightNumber: flightPattern.MarketingFlightNumber,
							flightTitle:  flightPattern.FlightTitle(),
							operates:     operatingDateMask,
							banned:       bannedDateMask,
						})
					}
				}
			}
		}
	}
	return legs, nil
}

// addPatternMask merges the dates of one flight pattern into the leg. Codeshare patterns are
// accumulated per marketing flight; operating patterns are accumulated in f.operating, whose
// carrier, number and title are taken from the first operating pattern seen. Dates collected
// earlier are always kept.
func (f *FlightLegValue) addPatternMask(isCodeshare bool, pattern FlightDateMask) {
	if isCodeshare {
		key := CodeshareFlightKey{
			marketingCarrier: pattern.carrier,
			marketingFlight:  pattern.flightNumber,
		}
		existing, ok := f.codeshares[key]
		if !ok {
			f.codeshares[key] = pattern
			return
		}
		existing.operates.AddMask(pattern.operates)
		existing.banned.AddMask(pattern.banned)
		// map values are copies: store the updated struct back
		f.codeshares[key] = existing
		return
	}

	if f.operating.flightTitle == "" {
		// fill in the identity without discarding dates that may have been collected already
		f.operating.carrier = pattern.carrier
		f.operating.flightNumber = pattern.flightNumber
		f.operating.flightTitle = pattern.flightTitle
		f.flightTitle = pattern.flightTitle
	}
	if f.operating.operates == nil {
		f.operating.operates = pattern.operates
	} else {
		f.operating.operates.AddMask(pattern.operates)
	}
	if f.operating.banned == nil {
		f.operating.banned = pattern.banned
	} else {
		f.operating.banned.AddMask(pattern.banned)
	}
}

// CreateResponse fills response.Flights with one format.ScheduleFlight per leg combination of
// every flight in flightsCache, sorted with sortScheduleFlightsInPlace.
//
// flightsCache maps flight legs out by flight key -> leg number -> flight base hash, as built by
// BuildFlightsCache. For each combination produced by GenerateCombinations the legs departing
// from the requested departure station and arriving to the requested arrival station are
// located; combinations lacking either are skipped. The flight is reported with the departure
// leg's operating and banned dates (and its codeshares) as masks generated from
// startScheduleDate; flights without any mask are dropped. departFromTz and arriveToTz are used
// to compute the UTC offsets at the first departure date and the matching arrival date.
//
// The cached date masks are shared by every combination that uses a leg and are read by the
// GenerateCombinations goroutine while this loop runs, so they are cloned before banned dates
// are subtracted instead of being modified in place.
func (service *FlightP2PServiceImpl) CreateResponse(
	response *format.ScheduleResponse,
	flightsCache map[format.FlightKey]map[int32]map[int32]*FlightLegValue,
	departFromTz *time.Location,
	arriveToTz *time.Location,
	startScheduleDate dtutil.IntDate,
) {
	flights := make([]format.ScheduleFlight, 0, len(flightsCache))
	for _, fv := range flightsCache {
		flightValue := ThreadVariants(fv)
		for segmentIndices := range GenerateCombinations(flightValue) {
			legCombination := flightValue.GetCombination(segmentIndices)
			route := make([]int64, 0, len(legCombination)+1)
			legNumFrom := -1
			legNumTo := -1
			// Days from the departure leg's date to the arrival at the requested station.
			arrivalDayShift := 0
			// Day offset of the first leg's date relative to the departure leg's date.
			startDayShift := 0
			for num, leg := range legCombination {
				if len(route) == 0 {
					route = append(route, leg.FlightBase.DepartureStation)
				}
				route = append(route, leg.FlightBase.ArrivalStation)
				if int32(leg.FlightBase.DepartureStation) == leg.DepartureStation {
					legNumFrom = num
				}
				if legNumFrom == -1 {
					// leg before the requested departure: step back by the shift to the next leg
					dayShift := leg.arrivalDayShift
					if num < len(legCombination)-1 {
						departureDayShift := legCombination[num+1].departureDayShift
						if departureDayShift > dayShift {
							dayShift = departureDayShift
						}
					}
					startDayShift -= dayShift
				}
				if int32(leg.FlightBase.ArrivalStation) == leg.ArrivalStation {
					legNumTo = num
					arrivalDayShift += leg.arrivalDayShift
				} else if legNumFrom >= 0 && legNumTo == -1 {
					// intermediate leg: the next leg starts max(arrival shift, next departure shift) days later
					departureDayShift := 0
					if num < len(legCombination)-1 {
						departureDayShift = legCombination[num+1].departureDayShift
					}
					arrivalDayShift += math.Max(leg.arrivalDayShift, departureDayShift)
				}
			}
			if legNumFrom == -1 || legNumTo == -1 {
				continue
			}
			departureLeg := legCombination[legNumFrom]
			arrivalLeg := legCombination[legNumTo]

			if departureLeg.operating.operates == nil {
				// codeshare-only leg: nothing operates under the operating flight number
				continue
			}

			departureIntDate := dtutil.IntDate(0)
			if !departureLeg.operating.operates.IsEmpty() {
				departureIntDate = departureLeg.operating.operates.GetFirstDate()
			} else if !departureLeg.operating.banned.IsEmpty() {
				departureIntDate = departureLeg.operating.banned.GetFirstDate()
			}

			departureIntTime := dtutil.IntTime(departureLeg.FlightBase.DepartureTimeScheduled)
			arrivalIntTime := dtutil.IntTime(arrivalLeg.FlightBase.ArrivalTimeScheduled)
			var (
				departureDate time.Time
				arrivalDate   time.Time
			)
			if int(departureIntDate) > 0 {
				departureDate = dtutil.IntToTime(departureIntDate, departureIntTime, departFromTz)
				// Use the accumulated shift: for multi-leg segments the arrival happens on the
				// arrival leg, not on the departure leg.
				arrivalDate = dtutil.IntToTime(
					departureIntDate.AddDaysP(arrivalDayShift),
					arrivalIntTime,
					arriveToTz,
				)
			} else {
				// expected only when both masks are empty; such a flight gets no masks and is dropped below
				departureDate = time.Now()
				arrivalDate = departureDate
			}

			flight := format.ScheduleFlight{
				TitledFlight: dto.TitledFlight{
					FlightID: dto.FlightID{
						AirlineID: departureLeg.FlightBase.OperatingCarrier,
						Number:    departureLeg.FlightBase.OperatingFlightNumber,
					},
					Title: departureLeg.flightTitle,
				},
				DepartureTime:     dtutil.FormatTimeHHMM(departureIntTime),
				DepartureTimezone: dtutil.FormatTimezone(departureDate),
				DepartureTerminal: departureLeg.FlightBase.DepartureTerminal,
				DepartureStation:  int32(departureLeg.FlightBase.DepartureStation),
				ArrivalTime:       dtutil.FormatTimeHHMM(arrivalIntTime),
				ArrivalTimezone:   dtutil.FormatTimezone(arrivalDate),
				ArrivalTerminal:   arrivalLeg.FlightBase.ArrivalTerminal,
				ArrivalStation:    int32(arrivalLeg.FlightBase.ArrivalStation),
				StartTime:         dtutil.FormatTimeHHMM(dtutil.IntTime(legCombination[0].FlightBase.DepartureTimeScheduled)),
				StartDayShift:     int32(startDayShift),
				TransportModelID:  departureLeg.FlightBase.AircraftTypeID,
				Route:             route,
				ArrivalDayShift:   int32(arrivalDayShift),
			}
			if len(departureLeg.codeshares) > 0 {
				flight.Codeshares = make([]format.CodeshareFlight, 0, len(departureLeg.codeshares))
				for _, codeshare := range departureLeg.codeshares {
					if codeshare.operates == nil || codeshare.flightTitle == departureLeg.flightTitle {
						continue
					}
					// work on a copy: the cached mask is shared by all combinations using this leg
					operates := dtutil.CloneDateMask(codeshare.operates)
					operates.RemoveMask(codeshare.banned)
					if operates.IsEmpty() {
						continue
					}
					flightCodeshare := format.CodeshareFlight{}
					flightCodeshare.TitledFlight = dto.TitledFlight{
						FlightID: dto.FlightID{
							AirlineID: codeshare.carrier,
							Number:    codeshare.flightNumber,
						},
						Title: codeshare.flightTitle,
					}
					flightCodeshare.Masks = GenerateMasks(operates, startScheduleDate)
					if len(flightCodeshare.Masks) > 0 {
						flight.Codeshares = append(flight.Codeshares, flightCodeshare)
					}
				}
				// codeshares come from a map: sort them so the response order is deterministic
				sort.Slice(flight.Codeshares, func(i, j int) bool {
					a, b := flight.Codeshares[i].TitledFlight, flight.Codeshares[j].TitledFlight
					if a.Title != b.Title {
						return a.Title < b.Title
					}
					if a.FlightID.AirlineID != b.FlightID.AirlineID {
						return a.FlightID.AirlineID < b.FlightID.AirlineID
					}
					return a.FlightID.Number < b.FlightID.Number
				})
			}

			// The operating mask is exposed via Mask() to the GenerateCombinations goroutine,
			// which may be reading it right now: never modify it in place.
			operates := dtutil.CloneDateMask(departureLeg.operating.operates)
			operates.RemoveMask(departureLeg.operating.banned)
			hasMasks := false
			if !operates.IsEmpty() {
				flight.Masks = GenerateMasks(operates, startScheduleDate)
				hasMasks = len(flight.Masks) > 0
			}
			if !departureLeg.operating.banned.IsEmpty() {
				flight.Banned = GenerateMasks(departureLeg.operating.banned, startScheduleDate)
				hasMasks = hasMasks || len(flight.Banned) > 0
			}
			if hasMasks {
				flights = append(flights, flight)
			}
		}
	}
	sortScheduleFlightsInPlace(flights)
	response.Flights = flights
}

// ThreadLegSegment describes a flight "thread": an ordered list of legs, each offering one or
// more alternative segments (variants). It is the input of GenerateCombinations.
type ThreadLegSegment interface {
	// LegSegments returns the segment variants for the given leg. Depending on the underlying
	// type, legNum may be a plain index in the sequence (0, 1, ...) or the real leg number from
	// the database (1, 2, ...).
	LegSegments(legNum int) map[segmentID]ThreadSegment
	// Legs returns the values used to address legs via LegSegments, in ascending order and
	// without duplicates. Passing them to LegSegments from first to last must yield the segment
	// variants in real flight order.
	Legs() []int
}

// ThreadSegment is one variant of a leg as seen by the combination generator: its operating
// dates and the day shifts used to align it with the neighbouring legs.
type ThreadSegment interface {
	Mask() dtutil.DateMask
	ArrivalDayShift() int
	DepartureDayShift() int
}

// ThreadVariants holds the variants of every leg of one operating flight:
// leg number -> flight base hash -> leg value. It implements ThreadLegSegment.
type ThreadVariants map[tFlightLeg]map[tFlightBaseHash]*FlightLegValue

// LegSegments returns the variants of the given leg number keyed by flight base hash, exposed
// through the ThreadSegment interface. An unknown leg yields an empty map.
func (tv ThreadVariants) LegSegments(leg int) map[segmentID]ThreadSegment {
	variants := tv[tFlightLeg(leg)]
	result := make(map[segmentID]ThreadSegment, len(variants))
	for hash, variant := range variants {
		result[hash] = variant
	}
	return result
}
// Legs returns the leg numbers present in tv in ascending order.
func (tv ThreadVariants) Legs() []int {
	result := make([]int, len(tv))
	i := 0
	for leg := range tv {
		result[i] = int(leg)
		i++
	}
	sort.Ints(result)
	return result
}

// GetCombination resolves combination indices produced by GenerateCombinations back to the
// concrete leg values, in the same order as indices.
func (tv ThreadVariants) GetCombination(indices ThreadIndices) []*FlightLegValue {
	legs := make([]*FlightLegValue, len(indices))
	for i, idx := range indices {
		legs[i] = tv[tFlightLeg(idx.LegID)][idx.SegmentID]
	}
	return legs
}

// ThreadVariant is a sequence of thread segments.
// NOTE(review): not referenced in this part of the file — confirm it is still used elsewhere.
type ThreadVariant []ThreadSegment

// GenerateCombinations streams the segment combinations of flightValue (see putNextCombination)
// over the returned channel. They are produced by a background goroutine that closes the channel
// when finished; the channel is unbuffered, so the caller must drain it or the goroutine blocks.
func GenerateCombinations(
	flightValue ThreadLegSegment,
) <-chan ThreadIndices {
	legNums := flightValue.Legs()
	out := make(chan ThreadIndices)
	go func() {
		defer close(out)
		seed := make(ThreadIndices, 0, len(legNums))
		putNextCombination(out, flightValue, legNums, seed, nil)
	}()
	return out
}

// segmentID identifies a segment variant within a leg (the flight base hash for ThreadVariants).
type segmentID = int32

// segmentIndex addresses one segment: the leg value from ThreadLegSegment.Legs and the variant id.
type segmentIndex = struct {
	LegID     int
	SegmentID segmentID
}

// ThreadIndices is one combination: at most one segment per leg, in leg order.
type ThreadIndices = []segmentIndex

// putNextCombination recursively extends currentValue with one segment per leg (in legNums order)
// and sends every finished combination to c. prev is the segment chosen for the previous leg
// (nil at the first level).
//
// A segment is chained after prev only if prev's mask, shifted by
// max(prev.ArrivalDayShift(), segment.DepartureDayShift()) days, shares at least one date with the
// segment's mask. When no segment of the current leg can be chained, or when a segment has a nil
// mask, the non-empty partial combination built so far is sent instead, so sent combinations may
// be shorter than len(legNums). Every value sent is a fresh copy, so receivers may keep it.
func putNextCombination(
	c chan<- ThreadIndices,
	flightValue ThreadLegSegment,
	legNums []int,
	currentValue ThreadIndices,
	prev ThreadSegment,
) {
	// pos is the index (into legNums) of the leg being chosen at this level
	pos := len(currentValue)
	if len(currentValue) >= len(legNums) {
		c <- append([]segmentIndex(nil), currentValue...)
		return
	}
	// set once currentValue has been sent or extended at this level
	currentValuePushed := false
	for segmentID, flightSegment := range flightValue.LegSegments(legNums[pos]) {
		if flightSegment.Mask() == nil {
			if len(currentValue) > 0 && !currentValuePushed {
				currentValuePushed = true
				c <- append([]segmentIndex(nil), currentValue...)
			}
			continue
		}
		// Skip combinations that don't fly on the same dates
		if pos > 0 {
			intersectionDaysMask := dtutil.CloneDateMask(prev.Mask())
			shiftDays := math.Max(prev.ArrivalDayShift(), flightSegment.DepartureDayShift())
			intersectionDaysMask.ShiftDays(shiftDays)
			intersectionDaysMask.IntersectWithMask(flightSegment.Mask())
			if intersectionDaysMask.IsEmpty() {
				continue
			}
		}

		currentValuePushed = true
		// append may write into currentValue's shared backing array; this is safe because every
		// send above copies the slice before it leaves this goroutine
		nextValue := append(currentValue, segmentIndex{legNums[pos], segmentID})
		putNextCombination(c, flightValue, legNums, nextValue, flightSegment)
	}
	// no segment of this leg could be chained: emit what has been built so far
	if len(currentValue) > 0 && !currentValuePushed {
		c <- append([]segmentIndex(nil), currentValue...)
	}
}

// maxWeeks is the number of 7-day buckets GenerateMasks allocates up front.
const maxWeeks = 60

// GenerateMasks compresses the dates of dm, as returned by dm.GetDateIndexesFromDate
// (startScheduleDate), into a list of format.Mask ranges.
//
// Dates are bucketed into 7-day windows counted from the first returned date; each window records
// the set of weekdays it contains. Consecutive windows with the same non-empty weekday set are
// merged into one mask covering whole windows, and an empty window splits masks. Returns nil
// when there are no dates.
func GenerateMasks(dm dtutil.DateMask, startScheduleDate dtutil.IntDate) []format.Mask {
	var masks []format.Mask
	dates := dm.GetDateIndexesFromDate(startScheduleDate)
	if len(dates) == 0 {
		return masks
	}
	start := dates[0]

	// One weekday bitmask per 7-day window. The slice grows on demand: a fixed length of maxWeeks
	// would index out of range for schedules spanning more than maxWeeks windows.
	weeks := make([]uint8, 0, maxWeeks)
	for _, date := range dates {
		weekIndex := (date - start) / 7
		for len(weeks) <= weekIndex {
			weeks = append(weeks, 0)
		}
		weeks[weekIndex] |= 1 << int(dtutil.DateCache.WeekDay(date))
	}
	maxWeekIndex := len(weeks) - 1

	maskStartWeek := 0
	maskEndWeek := 0
	maskOperatingDays := weeks[0]
	for weekIndex := 1; weekIndex <= maxWeekIndex; weekIndex++ {
		operatingDays := weeks[weekIndex]
		if operatingDays != 0 && operatingDays == maskOperatingDays {
			// same weekday set as the open mask: extend it
			maskEndWeek = weekIndex
			continue
		}
		if maskStartWeek >= 0 {
			// weekday set changed or a gap started: close the open mask
			masks = append(masks, createMask(start+maskStartWeek*7, start+maskEndWeek*7+6, maskOperatingDays))
			maskStartWeek = -1
			maskEndWeek = -1
			maskOperatingDays = 0
		}
		if operatingDays == 0 {
			continue
		}
		maskStartWeek = weekIndex
		maskEndWeek = weekIndex
		maskOperatingDays = operatingDays
	}
	if maskStartWeek >= 0 {
		masks = append(masks, createMask(start+maskStartWeek*7, start+maskEndWeek*7+6, maskOperatingDays))
	}

	return masks
}

// getOrCreateFlightBaseHash returns the ID of the first flight base registered in fbCache with the
// same FlightBaseHashFields as fb; if there is none yet, fb.ID is registered and returned.
func getOrCreateFlightBaseHash(fbCache map[FlightBaseHashFields]int32, fb *structs.FlightBase) int32 {
	key := FlightBaseHashFields{
		OperatingCarrier:       fb.OperatingCarrier,
		OperatingFlightNumber:  fb.OperatingFlightNumber,
		LegNumber:              fb.LegNumber,
		DepartureStation:       fb.DepartureStation,
		DepartureTimeScheduled: fb.DepartureTimeScheduled,
		DepartureTerminal:      fb.DepartureTerminal,
		ArrivalStation:         fb.ArrivalStation,
		ArrivalTimeScheduled:   fb.ArrivalTimeScheduled,
		ArrivalTerminal:        fb.ArrivalTerminal,
		AircraftTypeID:         fb.AircraftTypeID,
	}
	if id, found := fbCache[key]; found {
		return id
	}
	fbCache[key] = fb.ID
	return fb.ID
}

// createMask builds a format.Mask over the date indexes [startIndex, endIndex] operating on the
// weekdays in maskOperatingDays (converted with dtutil.FromBitmask).
//
// When the converted value is a single digit (presumably one weekday) and the range spans at most
// eight dates, the mask is narrowed to the first date in the range falling on that weekday.
func createMask(startIndex, endIndex int, maskOperatingDays uint8) format.Mask {
	operatesOn := int32(dtutil.FromBitmask(maskOperatingDays))
	from, until := startIndex, endIndex
	if operatesOn < 10 && endIndex-startIndex <= 7 {
		for idx := startIndex; idx <= endIndex; idx++ {
			if int32(dtutil.DateCache.WeekDay(idx)) == operatesOn {
				from, until = idx, idx
				break
			}
		}
	}
	return format.Mask{
		From:  dtutil.DateCache.Date(from).StringDateDashed(),
		Until: dtutil.DateCache.Date(until).StringDateDashed(),
		On:    operatesOn,
	}
}

// sortScheduleFlightsInPlace orders flights by departure time, then lexicographically by route
// (a route that is a prefix of another sorts first), then by operating airline and flight number.
//
// The comparison is a strict weak ordering, as sort.Slice requires. The previous version
// reported "less" for equal routes and for both sides of a prefix relation, which made the
// resulting order unspecified.
func sortScheduleFlightsInPlace(flights []format.ScheduleFlight) {
	sort.Slice(flights, func(i, j int) bool {
		a, b := &flights[i], &flights[j]
		if a.DepartureTime != b.DepartureTime {
			return a.DepartureTime < b.DepartureTime
		}
		for idx := 0; idx < len(a.Route) && idx < len(b.Route); idx++ {
			if a.Route[idx] != b.Route[idx] {
				return a.Route[idx] < b.Route[idx]
			}
		}
		if len(a.Route) != len(b.Route) {
			return len(a.Route) < len(b.Route)
		}
		// deterministic tie-break for flights sharing time and route
		if a.TitledFlight.FlightID.AirlineID != b.TitledFlight.FlightID.AirlineID {
			return a.TitledFlight.FlightID.AirlineID < b.TitledFlight.FlightID.AirlineID
		}
		return a.TitledFlight.FlightID.Number < b.TitledFlight.FlightID.Number
	})
}
