package transfer

import (
	aviaAPI "a.yandex-team.ru/travel/app/backend/api/avia/v1"
	"a.yandex-team.ru/travel/app/backend/internal/avia/search/filtering2/filterinterface"
	aviaSearchProto "a.yandex-team.ru/travel/app/backend/internal/avia/search/proto/v1"
	"a.yandex-team.ru/travel/library/go/containers"
)

// TransferAirportsBackward is a search filter keyed on the airports that
// appear in a snippet's backward transfers (presumably the return leg —
// confirm against the forward counterpart).
type TransferAirportsBackward struct {
	// selectedAirportsBackward holds the station IDs of the backward transfer
	// airports whose State flag was set in the incoming request.
	selectedAirportsBackward containers.Set[uint64]
}

// GetFilterID returns the constant identifier of this filter,
// "TransferAirportsBackward".
func (tf *TransferAirportsBackward) GetFilterID() string {
	const filterID = "TransferAirportsBackward"
	return filterID
}

// NewTransferAirportsBackward builds the filter from the request.
//
// Every entry of filters.Transfer.Airports.Backward whose State is true
// contributes its StationId to the selected set. A nil request, or one without
// the Transfer or Airports section, yields a filter with an empty selection.
func NewTransferAirportsBackward(filters *aviaAPI.SearchFiltersReq) *TransferAirportsBackward {
	tf := &TransferAirportsBackward{
		selectedAirportsBackward: containers.SetOf[uint64](),
	}

	if filters == nil || filters.Transfer == nil || filters.Transfer.Airports == nil {
		return tf
	}

	for _, airport := range filters.Transfer.Airports.Backward {
		if !airport.State {
			continue
		}
		tf.selectedAirportsBackward.Add(airport.StationId)
	}

	return tf
}

// InitFilterResponse fills filterResponse.Transfer.Airports.Backward from the
// airports seen in backward transfers across all snippets.
//
// The Transfer and Airports sections of filterResponse are created when
// missing. The arrival and departure station of every backward transfer of
// every snippet are collected into a set. That set is passed, together with
// the airports selected in the request and the reference, to
// buildInitialDirectionAirportsResponse, which builds the response entries.
// The filters and searchContext parameters are not read here (presumably they
// are required by the shared filter interface signature).
//
// Snippets without transfer information contribute nothing instead of causing
// a nil pointer dereference. filterResponse itself must be non-nil.
func (tf *TransferAirportsBackward) InitFilterResponse(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	searchContext *aviaSearchProto.SearchContext,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	if filterResponse.Transfer == nil {
		filterResponse.Transfer = &aviaAPI.SearchFiltersRsp_TransferFilter{}
	}
	if filterResponse.Transfer.Airports == nil {
		filterResponse.Transfer.Airports = &aviaAPI.SearchFiltersRsp_TransferAirportsFilter{}
	}

	airportsBackward := containers.SetOf[uint64]()
	for _, s := range snippets {
		// Generated getters tolerate nil receivers, so a nil snippet or a
		// snippet without Transfers yields an empty slice instead of panicking.
		for _, t := range s.GetTransfers().GetBackwardTransfers() {
			airportsBackward.Add(t.ArrivalStationId)
			airportsBackward.Add(t.DepartureStationId)
		}
	}

	filterResponse.Transfer.Airports.Backward = buildInitialDirectionAirportsResponse(
		airportsBackward,
		tf.selectedAirportsBackward,
		reference,
	)

	return filterResponse
}

// needToSkip reports whether any backward transfer in transfers arrives at or
// departs from one of the airports selected in the request.
//
// A nil transfers value has no backward transfers and is never skipped.
// NOTE(review): a selected airport causes exclusion here; confirm this matches
// the intended meaning of State in the request.
func (tf *TransferAirportsBackward) needToSkip(transfers *aviaSearchProto.Transfers) bool {
	for _, t := range transfers.GetBackwardTransfers() {
		if tf.selectedAirportsBackward.Contains(t.ArrivalStationId) ||
			tf.selectedAirportsBackward.Contains(t.DepartureStationId) {
			return true
		}
	}
	return false
}

// Filter returns the keys of all snippets that needToSkip rejects, i.e.
// snippets with a backward transfer through a selected airport.
// The filters, reference and filterResponse parameters are not read here.
func (tf *TransferAirportsBackward) Filter(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *filterinterface.ExcludedKeys {
	result := filterinterface.NewExcludedKeys()
	for key, snippet := range snippets {
		if !tf.needToSkip(snippet.Transfers) {
			continue
		}
		result.AddSnippetKey(key)
	}
	return result
}

// UpdateFilterResponse disables backward transfer airport entries that no
// longer match any snippet surviving the other filters.
//
// Airports are collected from the backward transfers of every snippet not
// listed in excludedKeysByOthers. Each entry in
// filterResponse.Transfer.Airports.Backward whose StationId is not among them
// gets State.Enabled set to false. Entries are never re-enabled here. The
// reference parameter is not read.
//
// Missing pieces are tolerated instead of panicking:
//   - snippets without transfer information contribute no airports;
//   - a response without a Transfer/Airports section is left untouched;
//   - entries with a nil State are skipped.
func (tf *TransferAirportsBackward) UpdateFilterResponse(
	snippets map[string]*aviaSearchProto.Snippet,
	excludedKeysByOthers *filterinterface.ExcludedKeys,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	airportsBackward := containers.SetOf[uint64]()
	for sKey, s := range snippets {
		if excludedKeysByOthers.ContainsSnippetKey(sKey) {
			continue
		}
		// Nil-safe generated getters: no Transfers means no airports.
		for _, t := range s.GetTransfers().GetBackwardTransfers() {
			airportsBackward.Add(t.ArrivalStationId)
			airportsBackward.Add(t.DepartureStationId)
		}
	}

	// The getter chain yields nil (an empty range) when the response was never
	// initialised, rather than dereferencing a nil Transfer or Airports.
	for _, a := range filterResponse.GetTransfer().GetAirports().GetBackward() {
		if a.State == nil || airportsBackward.Contains(a.StationId) {
			continue
		}
		a.State.Enabled = false
	}

	return filterResponse
}
