package filtering

import (
	"context"
	"sort"

	"golang.org/x/exp/maps"

	aviaAPI "a.yandex-team.ru/travel/app/backend/api/avia/v1"
	"a.yandex-team.ru/travel/app/backend/internal/avia/search/filtering/helpers"
	aviaSearchProto "a.yandex-team.ru/travel/app/backend/internal/avia/search/proto/v1"
	aviaProtoV2 "a.yandex-team.ru/travel/avia/library/proto/common/v2"
)

// airportFilter filters search snippets by the airports at the ends of their legs.
//
// For each leg (forward and, when present, backward) only two stations are considered:
// the departure station of the leg's first flight and the arrival station of its last flight.
type airportFilter struct {
	// needToApplyFilter is false when the request selection cannot exclude any snippet
	// (see newAirportFilter); filter then returns early with nil results.
	needToApplyFilter bool

	// Station IDs found in the snippets, per leg and endpoint.
	forwardDepartureIds  map[uint64]struct{}
	forwardArrivalIds    map[uint64]struct{}
	backwardDepartureIds map[uint64]struct{}
	backwardArrivalIds   map[uint64]struct{}

	// Station IDs selected in the request (entries whose State is true), per leg and endpoint.
	forwardDepartureIdsInRequest  map[uint64]struct{}
	forwardArrivalIdsInRequest    map[uint64]struct{}
	backwardDepartureIdsInRequest map[uint64]struct{}
	backwardArrivalIdsInRequest   map[uint64]struct{}

	// anyX is true when nothing is selected for that endpoint, i.e. every airport is accepted there.
	anyForwardDeparture  bool
	anyForwardArrival    bool
	anyBackwardDeparture bool
	anyBackwardArrival   bool
}

// getFilterID returns the identifier of the airport filter.
// NOTE(review): "AirportFiler" looks like a typo of "AirportFilter"; it is kept verbatim
// because callers may depend on the exact value — confirm before changing it.
func (*airportFilter) getFilterID() string {
	const id = "AirportFiler"
	return id
}

// newAirportFilter builds an airportFilter for the given snippets and request filters.
//
// It records the endpoint stations of every snippet (departure station of the first flight and
// arrival station of the last flight, for the forward leg and, when present, the backward leg),
// reads the airports selected in filters.Airport (only entries with State == true, see
// fillAirportsFromRequest) and decides whether the filter can exclude anything at all.
//
// needToApplyFilter ends up false when the request has no airport section, when nothing is
// selected in any direction, or when every selection equals exactly the set of airports found
// in the snippets. searchContext is currently unused.
func newAirportFilter(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	searchContext *aviaSearchProto.SearchContext,
) *airportFilter {
	af := &airportFilter{
		forwardDepartureIds:  make(map[uint64]struct{}),
		forwardArrivalIds:    make(map[uint64]struct{}),
		backwardDepartureIds: make(map[uint64]struct{}),
		backwardArrivalIds:   make(map[uint64]struct{}),

		forwardDepartureIdsInRequest:  make(map[uint64]struct{}),
		forwardArrivalIdsInRequest:    make(map[uint64]struct{}),
		backwardDepartureIdsInRequest: make(map[uint64]struct{}),
		backwardArrivalIdsInRequest:   make(map[uint64]struct{}),
	}

	for _, s := range snippets {
		// Both legs are guarded: indexing [0] on a leg without flights would panic.
		if n := len(s.Forward); n > 0 {
			af.forwardDepartureIds[reference.Flights[s.Forward[0]].StationFromId] = struct{}{}
			af.forwardArrivalIds[reference.Flights[s.Forward[n-1]].StationToId] = struct{}{}
		}
		if n := len(s.Backward); n > 0 {
			af.backwardDepartureIds[reference.Flights[s.Backward[0]].StationFromId] = struct{}{}
			af.backwardArrivalIds[reference.Flights[s.Backward[n-1]].StationToId] = struct{}{}
		}
	}

	hasAirportFilter := filters != nil && filters.Airport != nil
	if hasAirportFilter {
		if forward := filters.Airport.Forward; forward != nil {
			fillAirportsFromRequest(forward.Departure, af.forwardDepartureIdsInRequest)
			fillAirportsFromRequest(forward.Arrival, af.forwardArrivalIdsInRequest)
		}
		if backward := filters.Airport.Backward; backward != nil {
			fillAirportsFromRequest(backward.Departure, af.backwardDepartureIdsInRequest)
			fillAirportsFromRequest(backward.Arrival, af.backwardArrivalIdsInRequest)
		}
	}

	// An endpoint with no selected airports accepts any airport.
	af.anyForwardDeparture = len(af.forwardDepartureIdsInRequest) == 0
	af.anyForwardArrival = len(af.forwardArrivalIdsInRequest) == 0
	af.anyBackwardDeparture = len(af.backwardDepartureIdsInRequest) == 0
	af.anyBackwardArrival = len(af.backwardArrivalIdsInRequest) == 0

	// needToApplyFilter keeps its zero value (false) unless the selection can exclude something.
	switch {
	case !hasAirportFilter:
		// The request carries no airport section.
	case af.anyForwardDeparture && af.anyForwardArrival && af.anyBackwardDeparture && af.anyBackwardArrival:
		// Nothing is selected in any direction.
	case maps.Equal(af.forwardDepartureIdsInRequest, af.forwardDepartureIds) &&
		maps.Equal(af.forwardArrivalIdsInRequest, af.forwardArrivalIds) &&
		maps.Equal(af.backwardDepartureIdsInRequest, af.backwardDepartureIds) &&
		maps.Equal(af.backwardArrivalIdsInRequest, af.backwardArrivalIds):
		// Exactly the available airports are selected everywhere, so every snippet passes.
	default:
		af.needToApplyFilter = true
	}

	return af
}

// initFilterResponse stores the airport options built from the snippets in
// filterResponse.Airport and returns filterResponse.
//
// The forward leg always gets a departure list (around searchContext.PointFrom) and an
// arrival list (around searchContext.PointTo). The backward leg is added only when
// searchContext.DateBackward is set, with the two points swapped. Options come from
// getAirports, so they start enabled and carry Value == true when selected in the request.
// ctx, filters and snippets are not used.
func (af *airportFilter) initFilterResponse(
	ctx context.Context,
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	searchContext *aviaSearchProto.SearchContext,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	from, to := searchContext.PointFrom, searchContext.PointTo

	airport := &aviaAPI.SearchFiltersRsp_AirportFilter{
		Forward: &aviaAPI.SearchFiltersRsp_DirectionAirports{
			Departure: &aviaAPI.SearchFiltersRsp_DepartureOrArrivalAirports{
				SettlementId: getSettlementID(from),
				Airports:     getAirports(af.forwardDepartureIds, af.forwardDepartureIdsInRequest, reference),
				Title:        helpers.BuildDepartureTitle(from, reference),
			},
			Arrival: &aviaAPI.SearchFiltersRsp_DepartureOrArrivalAirports{
				SettlementId: getSettlementID(to),
				Airports:     getAirports(af.forwardArrivalIds, af.forwardArrivalIdsInRequest, reference),
				Title:        helpers.BuildArrivalTitle(to, reference),
			},
		},
	}

	if searchContext.DateBackward != nil {
		// On the way back the trip starts at PointTo and ends at PointFrom.
		airport.Backward = &aviaAPI.SearchFiltersRsp_DirectionAirports{
			Departure: &aviaAPI.SearchFiltersRsp_DepartureOrArrivalAirports{
				SettlementId: getSettlementID(to),
				Airports:     getAirports(af.backwardDepartureIds, af.backwardDepartureIdsInRequest, reference),
				Title:        helpers.BuildDepartureTitle(to, reference),
			},
			Arrival: &aviaAPI.SearchFiltersRsp_DepartureOrArrivalAirports{
				SettlementId: getSettlementID(from),
				Airports:     getAirports(af.backwardArrivalIds, af.backwardArrivalIdsInRequest, reference),
				Title:        helpers.BuildArrivalTitle(from, reference),
			},
		}
	}

	filterResponse.Airport = airport
	return filterResponse
}

// filter returns the keys of snippets rejected by the airport selection as its first result;
// the second result (excluded variant keys) is always nil for this filter.
//
// When the filter has nothing to do (needToApplyFilter is false) both results are nil.
// filters and filterResponse are not used.
func (af *airportFilter) filter(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) (map[string]struct{}, map[string]struct{}) {
	if af.needToApplyFilter {
		excluded := make(map[string]struct{}, len(snippets))
		for key, snippet := range snippets {
			if !af.needToSkipSnippet(snippet, reference) {
				continue
			}
			excluded[key] = struct{}{}
		}
		return excluded, nil
	}
	return nil, nil
}

// needToSkipSnippet reports whether snippet s must be excluded by the airport selection.
//
// For every endpoint whose selection is not "any", the corresponding station must be among
// the selected airports. The endpoints are the departure station of the first flight and the
// arrival station of the last flight, for the forward leg and, when present, the backward leg.
// A leg without flights is not checked, so it never causes exclusion (previously an empty
// forward leg panicked on s.Forward[0]).
func (af *airportFilter) needToSkipSnippet(s *aviaSearchProto.Snippet, reference *aviaSearchProto.Reference) bool {
	if n := len(s.Forward); n > 0 {
		if !af.anyForwardDeparture {
			if _, found := af.forwardDepartureIdsInRequest[reference.Flights[s.Forward[0]].StationFromId]; !found {
				return true
			}
		}
		if !af.anyForwardArrival {
			if _, found := af.forwardArrivalIdsInRequest[reference.Flights[s.Forward[n-1]].StationToId]; !found {
				return true
			}
		}
	}
	if n := len(s.Backward); n > 0 {
		if !af.anyBackwardDeparture {
			if _, found := af.backwardDepartureIdsInRequest[reference.Flights[s.Backward[0]].StationFromId]; !found {
				return true
			}
		}
		if !af.anyBackwardArrival {
			if _, found := af.backwardArrivalIdsInRequest[reference.Flights[s.Backward[n-1]].StationToId]; !found {
				return true
			}
		}
	}
	return false
}

// updateFilterResponse disables airport options that became unreachable because of the other
// filters, and returns filterResponse.
//
// Only airport options that still appear in some snippet surviving the other filters should
// stay enabled. So the snippets listed in excludedSnippetKeysByOthers are dropped, the endpoint
// stations of the remaining snippets are collected, and every option whose station is not among
// them gets State.Enabled = false. Options are never re-enabled here. ctx and
// excludedVariantKeysByOthers are not used. If filterResponse carries no airport section
// (e.g. initFilterResponse was not applied), it is returned unchanged instead of panicking.
func (af *airportFilter) updateFilterResponse(
	ctx context.Context,
	snippets map[string]*aviaSearchProto.Snippet,
	excludedSnippetKeysByOthers map[string]struct{},
	excludedVariantKeysByOthers map[string]struct{},
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	if filterResponse == nil || filterResponse.Airport == nil {
		return filterResponse
	}

	forwardDepartureIds := make(map[uint64]struct{})
	forwardArrivalIds := make(map[uint64]struct{})
	backwardDepartureIds := make(map[uint64]struct{})
	backwardArrivalIds := make(map[uint64]struct{})

	for sKey, s := range snippets {
		if _, found := excludedSnippetKeysByOthers[sKey]; found {
			continue
		}
		// Both legs are guarded: indexing [0] on a leg without flights would panic.
		if n := len(s.Forward); n > 0 {
			forwardDepartureIds[reference.Flights[s.Forward[0]].StationFromId] = struct{}{}
			forwardArrivalIds[reference.Flights[s.Forward[n-1]].StationToId] = struct{}{}
		}
		if n := len(s.Backward); n > 0 {
			backwardDepartureIds[reference.Flights[s.Backward[0]].StationFromId] = struct{}{}
			backwardArrivalIds[reference.Flights[s.Backward[n-1]].StationToId] = struct{}{}
		}
	}

	// disableMissing turns off every option whose station is not in available.
	disableMissing := func(options []*aviaAPI.SearchFiltersRsp_Airport, available map[uint64]struct{}) {
		for _, option := range options {
			if _, found := available[option.StationId]; !found {
				option.State.Enabled = false
			}
		}
	}

	if forward := filterResponse.Airport.Forward; forward != nil {
		disableMissing(forward.Departure.Airports, forwardDepartureIds)
		disableMissing(forward.Arrival.Airports, forwardArrivalIds)
	}
	if backward := filterResponse.Airport.Backward; backward != nil {
		disableMissing(backward.Departure.Airports, backwardDepartureIds)
		disableMissing(backward.Arrival.Airports, backwardArrivalIds)
	}

	return filterResponse
}

// fillAirportsFromRequest adds to dst the station ID of every airport in filter whose State
// is true. A nil filter adds nothing.
func fillAirportsFromRequest(filter *aviaAPI.SearchFiltersReq_DepartureOrArrivalAirports, dst map[uint64]struct{}) {
	if filter == nil {
		return
	}
	for _, airport := range filter.Airports {
		if !airport.State {
			continue
		}
		dst[airport.StationId] = struct{}{}
	}
}

// getAirports builds the airport options for one endpoint, ordered by ascending station ID.
//
// Every option starts enabled. Its Value is true when the station is in idsInRequest.
// The settlement title is filled in only when the stations span more than one distinct
// SettlementId (0 counts as a value), and only for stations with a non-zero SettlementId.
func getAirports(
	idsInSnippets map[uint64]struct{},
	idsInRequest map[uint64]struct{},
	reference *aviaSearchProto.Reference,
) []*aviaAPI.SearchFiltersRsp_Airport {
	stationIDs := make([]uint64, 0, len(idsInSnippets))
	distinctSettlements := make(map[uint64]struct{})
	for id := range idsInSnippets {
		stationIDs = append(stationIDs, id)
		distinctSettlements[reference.Stations[id].SettlementId] = struct{}{}
	}
	// Station IDs are unique map keys, so this yields a deterministic ascending order.
	sort.Slice(stationIDs, func(i, j int) bool { return stationIDs[i] < stationIDs[j] })
	withSettlementTitle := len(distinctSettlements) > 1

	result := make([]*aviaAPI.SearchFiltersRsp_Airport, 0, len(stationIDs))
	for _, id := range stationIDs {
		station := reference.Stations[id]
		_, selected := idsInRequest[id]

		settlementTitle := ""
		if withSettlementTitle && station.SettlementId != 0 {
			settlementTitle = reference.Settlements[station.SettlementId].Title
		}

		result = append(result, &aviaAPI.SearchFiltersRsp_Airport{
			State: &aviaAPI.SearchFiltersRsp_BoolFilterState{
				Enabled: true,
				Value:   selected,
			},
			StationId:       id,
			StationTitle:    station.Title,
			SettlementTitle: settlementTitle,
			AviaCode:        station.AviaCode,
		})
	}
	return result
}

// getSettlementID returns point.Id when point is a settlement, and 0 otherwise.
// A nil point also yields 0 instead of a nil-pointer dereference.
func getSettlementID(point *aviaProtoV2.Point) uint64 {
	if point != nil && point.Type == aviaProtoV2.PointType_POINT_TYPE_SETTLEMENT {
		return point.Id
	}
	return 0
}
