package airport

import (
	"golang.org/x/exp/maps"

	aviaAPI "a.yandex-team.ru/travel/app/backend/api/avia/v1"
	"a.yandex-team.ru/travel/app/backend/internal/avia/search/filtering2/filterinterface"
	"a.yandex-team.ru/travel/app/backend/internal/avia/search/filtering2/helpers"
	aviaSearchProto "a.yandex-team.ru/travel/app/backend/internal/avia/search/proto/v1"
	"a.yandex-team.ru/travel/library/go/containers"
)

// AirportBackwardArrivalFilter filters search snippets by the arrival station
// of their backward leg, i.e. the StationToId of the flight referenced by the
// last element of a snippet's Backward list.
type AirportBackwardArrivalFilter struct {
	// needToApplyFilter reports whether Filter should exclude anything; it is
	// recomputed by InitFilterResponse from the request and the snippets.
	needToApplyFilter bool
	// idsInRequest holds the station IDs the request marked as selected
	// (State == true).
	idsInRequest      containers.Set[uint64]
}

// GetFilterID returns the fixed identifier of this filter. The receiver is
// not read, so the method is safe to call on a nil pointer.
func (af *AirportBackwardArrivalFilter) GetFilterID() string {
	const filterID = "AirportBackwardArrivalFilter"
	return filterID
}

// NewAirportBackwardArrivalFilter builds the filter from the request's
// backward-leg arrival airport selection.
//
// Only airports whose State is true are remembered as requested. filters, or
// any level of the nested Airport.Backward.Arrival selection, may be nil; in
// that case no airports are requested.
//
// needToApplyFilter starts out true only when at least one airport was
// requested; InitFilterResponse later refines it. Starting from false for an
// empty selection keeps Filter from excluding every snippet should it run
// before InitFilterResponse (an empty idsInRequest contains no station).
func NewAirportBackwardArrivalFilter(filters *aviaAPI.SearchFiltersReq) *AirportBackwardArrivalFilter {
	idsInRequest := containers.SetOf[uint64]()
	if filters != nil && filters.Airport != nil && filters.Airport.Backward != nil && filters.Airport.Backward.Arrival != nil {
		for _, a := range filters.Airport.Backward.Arrival.Airports {
			if a.State {
				idsInRequest.Add(a.StationId)
			}
		}
	}

	return &AirportBackwardArrivalFilter{
		needToApplyFilter: len(idsInRequest) > 0,
		idsInRequest:      idsInRequest,
	}
}

// InitFilterResponse fills filterResponse.Airport.Backward.Arrival with the
// backward-leg arrival airports and decides whether Filter must run.
//
// Missing Airport and Airport.Backward sections are created. The available
// station IDs are the StationToId of the flight (looked up in
// reference.Flights) referenced by the last element of each snippet's
// Backward list. Snippets with an empty Backward list contribute no station
// instead of panicking on index -1.
//
// The filter is applied only when the request selected at least one airport
// and that selection differs from the available set.
//
// SettlementId and Title are derived from searchContext.PointFrom.
// NOTE(review): presumably because the backward leg ends at the search origin —
// confirm. filters is not read here.
func (af *AirportBackwardArrivalFilter) InitFilterResponse(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	searchContext *aviaSearchProto.SearchContext,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	if filterResponse.Airport == nil {
		filterResponse.Airport = &aviaAPI.SearchFiltersRsp_AirportFilter{}
	}
	if filterResponse.Airport.Backward == nil {
		filterResponse.Airport.Backward = &aviaAPI.SearchFiltersRsp_DirectionAirports{}
	}

	idsInSnippets := containers.SetOf[uint64]()
	for _, s := range snippets {
		if len(s.Backward) == 0 {
			// No backward flights means no backward arrival station to collect.
			continue
		}
		idsInSnippets.Add(reference.Flights[s.Backward[len(s.Backward)-1]].StationToId)
	}
	// Filtering is pointless when nothing is selected or when the selection
	// equals exactly what the snippets offer.
	af.needToApplyFilter = len(af.idsInRequest) > 0 && !maps.Equal(idsInSnippets, af.idsInRequest)

	filterResponse.Airport.Backward.Arrival = &aviaAPI.SearchFiltersRsp_DepartureOrArrivalAirports{
		SettlementId: getSettlementID(searchContext.PointFrom),
		Airports:     getAirports(idsInSnippets, af.idsInRequest, reference),
		Title:        helpers.BuildArrivalTitle(searchContext.PointFrom, reference),
	}

	return filterResponse
}

// Filter returns the keys of snippets whose backward leg arrives at a station
// that was not selected in the request.
//
// Nothing is excluded when needToApplyFilter is false. The arrival station is
// the StationToId of the flight in reference.Flights referenced by the last
// element of the snippet's Backward list. Snippets with an empty Backward list
// are kept: there is no backward arrival to compare against, and indexing
// them would panic.
//
// filters and filterResponse are not read here.
func (af *AirportBackwardArrivalFilter) Filter(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *filterinterface.ExcludedKeys {
	excludedKeys := filterinterface.NewExcludedKeys()
	if !af.needToApplyFilter {
		return excludedKeys
	}

	for sKey, s := range snippets {
		if len(s.Backward) == 0 {
			continue
		}
		lastFlightKey := s.Backward[len(s.Backward)-1]
		if !af.idsInRequest.Contains(reference.Flights[lastFlightKey].StationToId) {
			excludedKeys.AddSnippetKey(sKey)
		}
	}
	return excludedKeys
}

// UpdateFilterResponse sets State.Enabled = false on every backward-arrival
// airport option whose station does not occur among the snippets that other
// filters did not exclude. Options are never re-enabled here.
//
// If filterResponse lacks the Airport.Backward.Arrival section (for example
// because InitFilterResponse was not run), it is returned unchanged instead
// of panicking on a nil dereference. Snippets with an empty Backward list
// contribute no station.
func (af *AirportBackwardArrivalFilter) UpdateFilterResponse(
	snippets map[string]*aviaSearchProto.Snippet,
	excludedKeysByOthers *filterinterface.ExcludedKeys,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	if filterResponse == nil || filterResponse.Airport == nil ||
		filterResponse.Airport.Backward == nil || filterResponse.Airport.Backward.Arrival == nil {
		return filterResponse
	}

	idsInSnippets := containers.SetOf[uint64]()
	for sKey, s := range snippets {
		if excludedKeysByOthers.ContainsSnippetKey(sKey) || len(s.Backward) == 0 {
			continue
		}
		idsInSnippets.Add(reference.Flights[s.Backward[len(s.Backward)-1]].StationToId)
	}

	for _, f := range filterResponse.Airport.Backward.Arrival.Airports {
		if !idsInSnippets.Contains(f.StationId) {
			f.State.Enabled = false
		}
	}

	return filterResponse
}
