package aeroflotvariants

import (
	"sort"
	"time"

	dto "a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/DTO"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/aeroflot_variants/format"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/services/storage/timezone"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/storage"
	"a.yandex-team.ru/travel/avia/shared_flights/api/internal/storage/flight"
	"a.yandex-team.ru/travel/avia/shared_flights/api/pkg/structs"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/dtutil"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/math"
	"a.yandex-team.ru/travel/proto/shared_flights/snapshots"
)

// AeroflotVariantsService builds Aeroflot flight variants (direct flights and
// two-segment connections) between sets of stations for a range of dates.
type AeroflotVariantsService interface {
	// GetAeroflotConnectingVariants returns the variants departing from any of
	// departFrom and arriving to any of arriveTo whose outbound departure date lies
	// between after and before (inclusive, as implemented in this package). When
	// isOneWay is false every variant also carries a return slice. Blacklisted
	// flights are evaluated against nationalVersion and are dropped unless
	// showBanned is set, in which case they are kept and flagged as banned.
	GetAeroflotConnectingVariants(
		departFrom []*snapshots.TStationWithCodes,
		arriveTo []*snapshots.TStationWithCodes,
		after string,
		before string,
		isOneWay bool,
		nationalVersion string,
		showBanned bool,
	) (format.Response, error)
}

// aeroflotVariantsServiceImpl is the AeroflotVariantsService implementation.
type aeroflotVariantsServiceImpl struct {
	// provides FlightStorage() and BlacklistRuleStorage()
	*storage.Storage
	// presumably provides GetTimeZoneByStationID used for local times — TODO confirm
	timezone.TimeZoneUtil
}

const (
	// preferredNumberOfNights is the preferred stay for round trips: a return
	// departing this many days after the outbound departure is chosen when one exists.
	preferredNumberOfNights = 5
	// minConnectionTime and maxConnectionTime are the layover limits, in minutes,
	// used when pairing the two segments of a connection in getFlightsForward.
	minConnectionTime       = 90      // minutes
	maxConnectionTime       = 23 * 60 // minutes
	// maxVariants caps the total number of variants in a single response.
	maxVariants             = 1000
)

// flightData is one flight pattern (direct flight) or a pair of patterns
// (two-segment connection) together with the dates on which it matched a query.
type flightData struct {
	// first segment or direct flight
	fp1             *structs.FlightPattern
	fb1             *structs.FlightBase
	// absolute date-cache indices of the first segment's departures, in ascending order
	depDateIndices1 []int
	// second segment (not filled in for the direct flights)
	fp2             *structs.FlightPattern
	fb2             *structs.FlightBase
	// departure date indices of the second segment; depDateIndices2[i] belongs to depDateIndices1[i]
	depDateIndices2 []int
	// set when banned dates are kept (showBanned) and at least one stored date is blacklisted
	hasBannedDates  bool
}

// stationPair is one departure/arrival station combination to search variants for.
type stationPair struct {
	fromStation      int32
	toStation        int32
	// whether the Aeroflot cache lists any direct flights for the pair, regardless of dates
	hasDirectFlights bool
}

// NewStationPair returns a stationPair for the given departure and arrival stations.
// hasDirectFlights records whether the pair has direct flights at all.
func NewStationPair(fromStation, toStation int32, hasDirectFlights bool) stationPair {
	var pair stationPair
	pair.fromStation = fromStation
	pair.toStation = toStation
	pair.hasDirectFlights = hasDirectFlights
	return pair
}

// NewAeroflotVariantsService returns an AeroflotVariantsService that reads flights
// and blacklist rules from storage and resolves station time zones via timeZoneUtil.
func NewAeroflotVariantsService(
	storage *storage.Storage,
	timeZoneUtil timezone.TimeZoneUtil,
) AeroflotVariantsService {
	service := new(aeroflotVariantsServiceImpl)
	service.Storage = storage
	service.TimeZoneUtil = timeZoneUtil
	return service
}

// GetAeroflotConnectingVariants collects variants for every (departure, arrival)
// station combination, capped at maxVariants in total.
//
// Station pairs that have direct flights in the Aeroflot cache are processed first.
// If such a pair only yields connecting variants for the requested dates, those
// variants are held back and appended after everything else, trimmed to whatever
// room is left. The final list puts direct variants (a single forward flight) first
// and orders variants within each group by their departure times string.
func (service *aeroflotVariantsServiceImpl) GetAeroflotConnectingVariants(
	departFrom []*snapshots.TStationWithCodes,
	arriveTo []*snapshots.TStationWithCodes,
	after string,
	before string,
	isOneWay bool,
	nationalVersion string,
	showBanned bool,
) (format.Response, error) {
	departureStations := make([]int32, 0, len(departFrom))
	for _, station := range departFrom {
		departureStations = append(departureStations, station.Station.Id)
	}
	arrivalStations := make([]int32, 0, len(arriveTo))
	for _, station := range arriveTo {
		arrivalStations = append(arrivalStations, station.Station.Id)
	}
	response := format.Response{
		DepartureStations: departureStations,
		ArrivalStations:   arrivalStations,
		Variants:          make([]format.Variant, 0),
	}

	flightStorage := service.FlightStorage()

	stationPairs := make([]stationPair, 0, len(departureStations)*len(arrivalStations))
	for _, departureStation := range departureStations {
		for _, arrivalStation := range arrivalStations {
			directFlights := flightStorage.AeroflotCache().GetDirectFlights(
				int64(departureStation),
				int64(arrivalStation),
			)
			stationPairs = append(stationPairs, NewStationPair(departureStation, arrivalStation, len(directFlights) > 0))
		}
	}
	// Pairs with direct flights go first; the rest of the order is only for determinism.
	sort.Slice(stationPairs, func(i, j int) bool {
		if stationPairs[i].hasDirectFlights != stationPairs[j].hasDirectFlights {
			return stationPairs[i].hasDirectFlights
		}
		if stationPairs[i].fromStation != stationPairs[j].fromStation {
			return stationPairs[i].fromStation < stationPairs[j].fromStation
		}
		return stationPairs[i].toStation < stationPairs[j].toStation
	})

	remainingVariantSlots := maxVariants
	// Connecting variants of pairs that do have direct flights, just not on the requested
	// dates. They are appended last so they cannot crowd out other pairs' results.
	var deferredVariants []format.Variant
	for _, pair := range stationPairs {
		variants, err := service.getVariantsInternal(
			flightStorage,
			pair.fromStation,
			pair.toStation,
			after,
			before,
			isOneWay,
			nationalVersion,
			showBanned,
			remainingVariantSlots,
		)
		if err != nil {
			return response, err
		}
		// getFlightsForward returns either only direct flights or only connections,
		// so the first variant tells which kind this pair produced.
		gotDirectFlights := len(variants) > 0 && len(variants[0].Forward.Flights) == 1
		if pair.hasDirectFlights && !gotDirectFlights {
			deferredVariants = append(deferredVariants, variants...)
			continue
		}
		response.Variants = append(response.Variants, variants...)
		remainingVariantSlots -= len(variants)
		if remainingVariantSlots <= 0 {
			break
		}
	}

	if room := maxVariants - len(response.Variants); len(deferredVariants) > room {
		deferredVariants = deferredVariants[:math.Max(room, 0)]
	}
	response.Variants = append(response.Variants, deferredVariants...)

	sort.Slice(response.Variants, func(i, j int) bool {
		isDirectI := len(response.Variants[i].Forward.Flights) == 1
		isDirectJ := len(response.Variants[j].Forward.Flights) == 1
		if isDirectI != isDirectJ {
			return isDirectI
		}
		return response.Variants[i].GetDepartureTimesString() < response.Variants[j].GetDepartureTimesString()
	})

	return response, nil
}

// getVariantsInternal builds the variants for a single departure/arrival station pair.
// after and before bound the outbound departure dates (both inclusive). Round trips get
// their return legs from getFlightsBackward. At most remainingVariantSlots variants are
// produced.
func (service *aeroflotVariantsServiceImpl) getVariantsInternal(
	flightStorage flight.FlightStorage,
	departureStation int32,
	arrivalStation int32,
	after string,
	before string,
	isOneWay bool,
	nationalVersion string,
	showBanned bool,
	remainingVariantSlots int,
) ([]format.Variant, error) {
	variants := make([]format.Variant, 0)

	fromDateIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(after))
	toDateIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(before))
	outbound, err := service.getFlightsForward(
		flightStorage,
		int64(departureStation),
		int64(arrivalStation),
		fromDateIndex,
		toDateIndex,
		nationalVersion,
		showBanned,
	)
	switch {
	case err != nil:
		return variants, err
	case len(outbound) == 0:
		return variants, nil
	case !isOneWay:
		return service.getFlightsBackward(flightStorage, outbound, nationalVersion, showBanned, remainingVariantSlots)
	}

	// One-way: one variant per matched outbound date.
	for _, data := range outbound {
		for i := range data.depDateIndices1 {
			variants = append(variants, service.createOneWayVariant(data, i))
			remainingVariantSlots--
			if remainingVariantSlots <= 0 {
				return variants, nil
			}
		}
	}
	return variants, nil
}

// getFlightsForward collects the flights from departureStation to arrivalStation whose
// (first-segment) departure date index lies in [startDateIndex, endDateIndex].
//
// Direct flights from the Aeroflot cache take precedence: if at least one of them
// operates on a date in the range, only direct flights are returned. Otherwise the
// cached two-segment connections are examined. A connection is kept for a date when the
// second segment departs on the first segment's arrival date or the day after, with a
// layover of at least minConnectionTime and at most maxConnectionTime minutes, measured
// in the local time of the transfer station.
//
// Blacklisted dates are skipped unless showBanned is set. In that case they are kept
// and the resulting flightData is flagged with hasBannedDates. Only entries with at
// least one matching date are returned.
func (service *aeroflotVariantsServiceImpl) getFlightsForward(
	flightStorage flight.FlightStorage,
	departureStation int64,
	arrivalStation int64,
	startDateIndex int,
	endDateIndex int,
	nationalVersion string,
	showBanned bool,
) ([]flightData, error) {
	result := make([]flightData, 0)

	for _, fp := range flightStorage.AeroflotCache().GetDirectFlights(departureStation, arrivalStation) {
		operatesFromIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp.OperatingFromDate))
		operatesUntilIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp.OperatingUntilDate))
		if operatesFromIndex > endDateIndex || operatesUntilIndex < startDateIndex {
			continue
		}
		fb, err := flightStorage.GetFlightBase(fp.FlightBaseID, false)
		if err != nil {
			return result, err
		}
		data := flightData{
			fp1:             fp,
			fb1:             &fb,
			depDateIndices1: make([]int, 0),
		}
		// Only dates inside both the requested range and the operating period can match.
		firstIndex := math.Max(startDateIndex, operatesFromIndex)
		lastIndex := math.Min(endDateIndex, operatesUntilIndex)
		for depDateIndex := firstIndex; depDateIndex <= lastIndex; depDateIndex++ {
			if !dtutil.OperatesOn(fp.OperatingOnDays, dtutil.DateCache.WeekDay(depDateIndex)) {
				continue
			}
			if service.Storage.BlacklistRuleStorage().IsBanned(
				fb, fp, dtutil.DateCache.Date(depDateIndex).StringDateDashed(), nationalVersion) {
				if !showBanned {
					continue
				}
				data.hasBannedDates = true
			}
			data.depDateIndices1 = append(data.depDateIndices1, depDateIndex)
		}
		if len(data.depDateIndices1) > 0 {
			result = append(result, data)
		}
	}
	if len(result) > 0 {
		return result, nil
	}

	for _, variant := range flightStorage.AeroflotCache().GetConnectingFlights(departureStation, arrivalStation) {
		for _, fp1 := range variant.FirstSegments {
			operatesFromIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp1.OperatingFromDate))
			operatesUntilIndex := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp1.OperatingUntilDate))
			if operatesFromIndex > endDateIndex || operatesUntilIndex < startDateIndex {
				continue
			}
			fb1, err := flightStorage.GetFlightBase(fp1.FlightBaseID, false)
			if err != nil {
				return result, err
			}
			// Both ends of the layover are expressed in the transfer station's local time.
			transferTz := service.GetTimeZoneByStationID(int64(fb1.ArrivalStation))
			firstIndex := math.Max(startDateIndex, operatesFromIndex)
			lastIndex := math.Min(endDateIndex, operatesUntilIndex)
			for _, fp2 := range variant.SecondSegments {
				operatesFromIndex2 := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp2.OperatingFromDate))
				operatesUntilIndex2 := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(fp2.OperatingUntilDate))
				// endDateIndex+2 because the second segment is allowed to depart on the next day after the first one
				// and the first segment may be an overnight segment, hence adding another day
				if operatesFromIndex2 > endDateIndex+2 || operatesUntilIndex2 < startDateIndex {
					continue
				}
				fb2, err := flightStorage.GetFlightBase(fp2.FlightBaseID, false)
				if err != nil {
					return result, err
				}
				if fb1.ArrivalStation != fb2.DepartureStation {
					// this is impossible, so the cache has somehow got corrupted
					continue
				}
				data := flightData{
					fp1:             fp1,
					fb1:             &fb1,
					fp2:             fp2,
					fb2:             &fb2,
					depDateIndices1: make([]int, 0),
					depDateIndices2: make([]int, 0),
				}
				for depDateIndex := firstIndex; depDateIndex <= lastIndex; depDateIndex++ {
					if !dtutil.OperatesOn(fp1.OperatingOnDays, dtutil.DateCache.WeekDay(depDateIndex)) {
						continue
					}
					// Compare fb1 arrival with fb2 departure on the first segment's arrival date
					// and enforce the minimum/maximum connection time.
					depDateIndex2 := depDateIndex + int(fp1.ArrivalDayShift)
					flight1ArrivalTime := dtutil.LocalTime(depDateIndex2, fb1.ArrivalTimeScheduled, transferTz)
					flight2DepartureTime := dtutil.LocalTime(depDateIndex2, fb2.DepartureTimeScheduled, transferTz)
					connectionTime := int(flight2DepartureTime.Sub(flight1ArrivalTime).Minutes())
					if connectionTime < minConnectionTime {
						// The same-day departure is too early; try the next day's one.
						depDateIndex2++
						connectionTime += 24 * 60
					}
					// Re-check both limits: a next-day layover may still be shorter than the
					// minimum, and a same-day one may exceed the maximum.
					if connectionTime < minConnectionTime || connectionTime > maxConnectionTime {
						continue
					}

					if operatesFromIndex2 > depDateIndex2 || operatesUntilIndex2 < depDateIndex2 {
						continue
					}
					if !dtutil.OperatesOn(fp2.OperatingOnDays, dtutil.DateCache.WeekDay(depDateIndex2)) {
						continue
					}
					if service.Storage.BlacklistRuleStorage().IsBanned(
						fb1, fp1, dtutil.DateCache.Date(depDateIndex).StringDateDashed(), nationalVersion) {
						if !showBanned {
							continue
						}
						data.hasBannedDates = true
					}
					if service.Storage.BlacklistRuleStorage().IsBanned(
						fb2, fp2, dtutil.DateCache.Date(depDateIndex2).StringDateDashed(), nationalVersion) {
						if !showBanned {
							continue
						}
						data.hasBannedDates = true
					}
					data.depDateIndices1 = append(data.depDateIndices1, depDateIndex)
					data.depDateIndices2 = append(data.depDateIndices2, depDateIndex2)
				}
				if len(data.depDateIndices1) > 0 {
					result = append(result, data)
				}
			}
		}
	}

	return result, nil
}

// getFlightsBackward pairs every forward flight/date in flightsForward with return
// flights from the forward destination back to the forward origin. It produces at most
// remainingVariantSlots round-trip variants.
//
// For each forward date, a return departing exactly preferredNumberOfNights days later
// is preferred, with one variant per return flight on that date. If there is none,
// every return flight contributes its earliest date that is strictly after the
// departure date of the last forward segment. Return flights with no such date are
// skipped, so a return never precedes its outbound trip.
func (service *aeroflotVariantsServiceImpl) getFlightsBackward(
	flightStorage flight.FlightStorage,
	flightsForward []flightData,
	nationalVersion string,
	showBanned bool,
	remainingVariantSlots int,
) ([]format.Variant, error) {
	result := make([]format.Variant, 0)
	for _, slice1 := range flightsForward {
		// For a connection, the return is constrained by the second segment's dates.
		lastSegmentDepDateIndices := slice1.depDateIndices1
		slice1ArrivalStation := slice1.fb1.ArrivalStation
		if slice1.fb2 != nil {
			lastSegmentDepDateIndices = slice1.depDateIndices2
			slice1ArrivalStation = slice1.fb2.ArrivalStation
		}
		firstDepDateIndex := slice1.depDateIndices1[0]
		lastDepDateIndex := slice1.depDateIndices1[len(slice1.depDateIndices1)-1]
		// The search window starts the day after the earliest forward departure. It extends
		// far enough to cover the preferred return date of the latest forward departure.
		minDepDateIndex := lastSegmentDepDateIndices[0] + 1
		maxDepDateIndex := math.Max(lastDepDateIndex+1, firstDepDateIndex+1+preferredNumberOfNights)

		flightsBackward, err := service.getFlightsForward(
			flightStorage,
			slice1ArrivalStation,
			slice1.fb1.DepartureStation,
			minDepDateIndex,
			maxDepDateIndex+preferredNumberOfNights,
			nationalVersion,
			showBanned,
		)
		if err != nil {
			return result, err
		}

		for idx1, depDateIndex := range slice1.depDateIndices1 {
			// Prefer returning back in preferredNumberOfNights days.
			preferredReturnDateIndex := depDateIndex + preferredNumberOfNights
			foundPreferredReturnDate := false
			for _, slice2 := range flightsBackward {
				for idx2, retDateIndex := range slice2.depDateIndices1 {
					if retDateIndex != preferredReturnDateIndex {
						continue
					}
					foundPreferredReturnDate = true
					result = append(result, service.createRoundTripVariant(slice1, idx1, slice2, idx2))
					remainingVariantSlots--
					if remainingVariantSlots <= 0 {
						return result, nil
					}
				}
			}
			if foundPreferredReturnDate {
				continue
			}

			// No preferred return: take each return flight's earliest date after this
			// particular forward trip. depDateIndices1 is built in ascending order, so
			// a binary search finds it.
			earliestReturnDateIndex := lastSegmentDepDateIndices[idx1] + 1
			for _, slice2 := range flightsBackward {
				idx2 := sort.SearchInts(slice2.depDateIndices1, earliestReturnDateIndex)
				if idx2 == len(slice2.depDateIndices1) {
					continue
				}
				result = append(result, service.createRoundTripVariant(slice1, idx1, slice2, idx2))
				remainingVariantSlots--
				if remainingVariantSlots <= 0 {
					return result, nil
				}
			}
		}
	}
	return result, nil
}

// createRoundTripVariant combines an outbound and a return slice into one variant.
// depDateIndex and retDateIndex are positions in the respective depDateIndices lists,
// not absolute date indices. The variant is marked banned if either side has banned dates.
func (service *aeroflotVariantsServiceImpl) createRoundTripVariant(
	slice1 flightData,
	depDateIndex int,
	slice2 flightData,
	retDateIndex int,
) format.Variant {
	forward := service.createSlice(slice1, depDateIndex)
	backward := service.createSlice(slice2, retDateIndex)
	variant := format.Variant{Forward: forward, Backward: backward}
	if anyBanned := slice1.hasBannedDates || slice2.hasBannedDates; anyBanned {
		variant.Banned = "true"
	}
	return variant
}

// createOneWayVariant wraps a single outbound slice into a variant. depDateIndex is a
// position in slice1's depDateIndices lists. The variant is marked banned when slice1
// has banned dates.
func (service *aeroflotVariantsServiceImpl) createOneWayVariant(slice1 flightData, depDateIndex int) format.Variant {
	var variant format.Variant
	variant.Forward = service.createSlice(slice1, depDateIndex)
	if slice1.hasBannedDates {
		variant.Banned = "true"
	}
	return variant
}

// Note: depDateIndex is the index in the depDateIndices list, not an absolute date index in cache
func (service *aeroflotVariantsServiceImpl) createSlice(slice flightData, depDateIndex int) format.Slice {
	result := format.Slice{}
	if slice.hasBannedDates {
		result.Banned = "true"
	}
	depDateIndex1 := slice.depDateIndices1[depDateIndex]
	flight1DepartureTz := service.GetTimeZoneByStationID(int64(slice.fb1.DepartureStation))
	flight1ArrivalTz := service.GetTimeZoneByStationID(int64(slice.fb1.ArrivalStation))
	flight1DepartureTime := dtutil.LocalTime(depDateIndex1, slice.fb1.DepartureTimeScheduled, flight1DepartureTz)
	flight1ArrivalTime := dtutil.LocalTime(
		depDateIndex1+int(slice.fp1.ArrivalDayShift), slice.fb1.ArrivalTimeScheduled, flight1ArrivalTz)
	flight1 := format.Flight{
		ArrivalDatetime:   flight1ArrivalTime.Format(time.RFC3339),
		ArrivalStation:    int32(slice.fb1.ArrivalStation),
		ArrivalTerminal:   slice.fb1.ArrivalTerminal,
		DepartureDatetime: flight1DepartureTime.Format(time.RFC3339),
		DepartureStation:  int32(slice.fb1.DepartureStation),
		DepartureTerminal: slice.fb1.DepartureTerminal,
		TransportModelID:  slice.fb1.AircraftTypeID,
		TitledFlight: dto.TitledFlight{
			FlightID: dto.FlightID{
				AirlineID: slice.fp1.MarketingCarrier,
				Number:    slice.fp1.MarketingFlightNumber,
			},
			Title: slice.fp1.FlightTitle(),
		},
	}
	if slice.fp2 == nil {
		result.Duration = int(flight1ArrivalTime.Sub(flight1DepartureTime).Minutes())
		result.Flights = []format.Flight{flight1}
		return result
	}
	depDateIndex2 := slice.depDateIndices2[depDateIndex]
	flight2DepartureTz := service.GetTimeZoneByStationID(int64(slice.fb2.DepartureStation))
	flight2ArrivalTz := service.GetTimeZoneByStationID(int64(slice.fb2.ArrivalStation))
	flight2DepartureTime := dtutil.LocalTime(depDateIndex2, slice.fb2.DepartureTimeScheduled, flight2DepartureTz)
	flight2ArrivalTime := dtutil.LocalTime(
		depDateIndex2+int(slice.fp2.ArrivalDayShift), slice.fb2.ArrivalTimeScheduled, flight2ArrivalTz)
	flight2 := format.Flight{
		ArrivalDatetime:   flight2ArrivalTime.Format(time.RFC3339),
		ArrivalStation:    int32(slice.fb2.ArrivalStation),
		ArrivalTerminal:   slice.fb2.ArrivalTerminal,
		DepartureDatetime: flight2DepartureTime.Format(time.RFC3339),
		DepartureStation:  int32(slice.fb2.DepartureStation),
		DepartureTerminal: slice.fb2.DepartureTerminal,
		TransportModelID:  slice.fb2.AircraftTypeID,
		TitledFlight: dto.TitledFlight{
			FlightID: dto.FlightID{
				AirlineID: slice.fp2.MarketingCarrier,
				Number:    slice.fp2.MarketingFlightNumber,
			},
			Title: slice.fp2.FlightTitle(),
		},
	}
	result.Flights = []format.Flight{flight1, flight2}
	result.Duration = int(flight2ArrivalTime.Sub(flight1DepartureTime).Minutes())
	return result
}
