package filtering

import (
	"context"
	"sort"

	"a.yandex-team.ru/library/go/core/log"
	aviaAPI "a.yandex-team.ru/travel/app/backend/api/avia/v1"
	"a.yandex-team.ru/travel/app/backend/internal/avia/search/filtering/helpers"
	aviaSearchProto "a.yandex-team.ru/travel/app/backend/internal/avia/search/proto/v1"
)

// aviacompanyFilter excludes search snippets based on which aviacompanies
// operate their forward and backward flights.
type aviacompanyFilter struct {
	// needToApplyFilter is false when the request selects no aviacompanies;
	// filter() then excludes nothing.
	needToApplyFilter     bool
	// selectedAviacompanies holds the ids of aviacompanies whose State is true
	// in the request.
	selectedAviacompanies map[uint64]struct{}
	// combinationsAllowed mirrors the request's AviacompanyCombinations flag.
	// When true, a snippet is kept if at least one of its flights is operated by
	// a selected aviacompany. When false, every flight must be.
	combinationsAllowed   bool
	// logger is the logger passed to newAviacompanyFilter.
	logger                log.Logger
}

// getFilterID returns the constant identifier of the aviacompany filter.
func (af *aviacompanyFilter) getFilterID() string {
	const filterID = "AviacompanyFilter"
	return filterID
}

// newAviacompanyFilter builds an aviacompanyFilter from the request filters.
//
// Aviacompanies whose State is true form the selection, and the filter is
// applied only when that selection is non-empty. A nil filters or a nil
// filters.Aviacompany yields a filter that excludes nothing and does not allow
// combinations.
func newAviacompanyFilter(logger log.Logger, filters *aviaAPI.SearchFiltersReq) *aviacompanyFilter {
	result := &aviacompanyFilter{logger: logger}
	if filters == nil || filters.Aviacompany == nil {
		return result
	}

	req := filters.Aviacompany
	result.combinationsAllowed = req.AviacompanyCombinations
	result.selectedAviacompanies = make(map[uint64]struct{}, len(req.Aviacompanies))
	for _, item := range req.Aviacompanies {
		if !item.State {
			continue
		}
		result.selectedAviacompanies[item.AviacompanyId] = struct{}{}
	}
	// An empty selection means "no restriction", so there is nothing to apply.
	result.needToApplyFilter = len(result.selectedAviacompanies) > 0
	return result
}

// isAviacompanySelected reports whether id is in the user's selection.
// When nothing is selected, every id counts as selected.
func (af *aviacompanyFilter) isAviacompanySelected(id uint64) bool {
	if len(af.selectedAviacompanies) > 0 {
		_, ok := af.selectedAviacompanies[id]
		return ok
	}
	return true
}

// initFilterResponse fills filterResponse.Aviacompany with the aviacompany
// filter state and returns filterResponse.
//
// It lists every non-zero aviacompany id found on the forward and backward
// flights of snippets. Each entry starts enabled, and its value is true when
// the id is in the request's selection. The combinations flag is reported as
// enabled, with its value taken from the request.
//
// When reference carries aviacompanies, entries are ordered by title using
// helpers.CompareCyrillicFirst. Entries that compare equal, and all entries
// when reference has no aviacompanies, are ordered by id. This keeps the
// response independent of map iteration order.
//
// ctx, filters and searchContext are not used. A nil filterResponse is
// replaced with a freshly allocated one.
func (af *aviacompanyFilter) initFilterResponse(
	ctx context.Context,
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	searchContext *aviaSearchProto.SearchContext,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	allAviacompanies := make(map[uint64]struct{})
	if reference != nil && reference.Flights != nil {
		for _, s := range snippets {
			// The generated getter is nil-safe: a flight key missing from
			// reference.Flights yields 0, which is skipped, instead of panicking.
			for _, f := range s.Forward {
				if id := reference.Flights[f].GetAviaCompanyId(); id != 0 {
					allAviacompanies[id] = struct{}{}
				}
			}
			for _, b := range s.Backward {
				if id := reference.Flights[b].GetAviaCompanyId(); id != 0 {
					allAviacompanies[id] = struct{}{}
				}
			}
		}
	}

	// Map iteration order is random, so establish a deterministic base order.
	ids := make([]uint64, 0, len(allAviacompanies))
	for id := range allAviacompanies {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	ac := make([]*aviaAPI.SearchFiltersRsp_AviacompanyState, 0, len(ids))
	for _, id := range ids {
		_, selected := af.selectedAviacompanies[id]
		ac = append(ac, &aviaAPI.SearchFiltersRsp_AviacompanyState{
			AviacompanyId: id,
			State: &aviaAPI.SearchFiltersRsp_BoolFilterState{
				Enabled: true,
				Value:   selected,
			},
		})
	}

	if reference != nil && reference.AviaCompanies != nil {
		// Unknown ids (and nil map values) get an empty title via the nil-safe getter.
		title := func(i int) string {
			return reference.AviaCompanies[ac[i].AviacompanyId].GetTitle()
		}
		// The stable sort keeps the id order among entries whose titles compare equal.
		sort.SliceStable(ac, func(i, j int) bool {
			return helpers.CompareCyrillicFirst(title(i), title(j))
		})
	}

	if filterResponse == nil {
		filterResponse = &aviaAPI.SearchFiltersRsp{}
	}
	filterResponse.Aviacompany = &aviaAPI.SearchFiltersRsp_AviacompanyFilter{
		AviacompanyCombinations: &aviaAPI.SearchFiltersRsp_BoolFilterState{
			Enabled: true,
			Value:   af.combinationsAllowed,
		},
		Aviacompanies: ac,
	}
	return filterResponse
}

// needToSkipWithCombinations reports whether snippet must be excluded when
// aviacompany combinations are allowed. The snippet is kept as soon as any
// forward or backward flight is operated by a selected aviacompany. Otherwise
// it is skipped, including when it has no flights at all.
//
// The generated protobuf getters are nil-safe. A nil reference or a flight key
// missing from it therefore resolves to aviacompany id 0 instead of panicking.
func (af *aviacompanyFilter) needToSkipWithCombinations(snippet *aviaSearchProto.Snippet, reference *aviaSearchProto.Reference) bool {
	flights := reference.GetFlights()
	for _, f := range snippet.Forward {
		if af.isAviacompanySelected(flights[f].GetAviaCompanyId()) {
			return false
		}
	}
	for _, f := range snippet.Backward {
		if af.isAviacompanySelected(flights[f].GetAviaCompanyId()) {
			return false
		}
	}
	return true
}

// needToSkip reports whether snippet must be excluded when aviacompany
// combinations are not allowed. The snippet is skipped as soon as any forward
// or backward flight is operated by an aviacompany outside the selection.
//
// The generated protobuf getters are nil-safe. A nil reference or a flight key
// missing from it therefore resolves to aviacompany id 0 instead of panicking.
func (af *aviacompanyFilter) needToSkip(snippet *aviaSearchProto.Snippet, reference *aviaSearchProto.Reference) bool {
	flights := reference.GetFlights()
	for _, f := range snippet.Forward {
		if !af.isAviacompanySelected(flights[f].GetAviaCompanyId()) {
			return true
		}
	}
	for _, f := range snippet.Backward {
		if !af.isAviacompanySelected(flights[f].GetAviaCompanyId()) {
			return true
		}
	}
	return false
}

// filter returns the keys of snippets rejected by the aviacompany selection.
//
// It returns nil, nil when no aviacompany is selected. Otherwise the first
// result holds the rejected snippet keys and the second result is always nil.
// The rejection rule depends on combinationsAllowed: needToSkipWithCombinations
// when true, needToSkip when false. filters and filterResponse are not used.
func (af *aviacompanyFilter) filter(
	filters *aviaAPI.SearchFiltersReq,
	snippets map[string]*aviaSearchProto.Snippet,
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) (map[string]struct{}, map[string]struct{}) {
	if af.needToApplyFilter {
		skip := af.needToSkip
		if af.combinationsAllowed {
			skip = af.needToSkipWithCombinations
		}

		excluded := make(map[string]struct{}, len(snippets))
		for key, snippet := range snippets {
			if !skip(snippet, reference) {
				continue
			}
			excluded[key] = struct{}{}
		}
		return excluded, nil
	}
	return nil, nil
}

// updateFilterResponse disables aviacompany entries in
// filterResponse.Aviacompany that no longer match any snippet left by the
// other filters. Snippets whose keys are in excludedSnippetKeysByOthers are
// ignored. It only ever clears State.Enabled and never re-enables an entry.
//
// With combinations allowed, an aviacompany stays enabled if it operates at
// least one flight of a remaining snippet. Without combinations, it stays
// enabled only if some remaining snippet has all of its forward and backward
// flights operated by that aviacompany. If reference has no flights, every
// entry is disabled.
//
// ctx and excludedVariantKeysByOthers are not used. A nil filterResponse, or
// one without Aviacompany state, is returned unchanged.
func (af *aviacompanyFilter) updateFilterResponse(
	ctx context.Context,
	snippets map[string]*aviaSearchProto.Snippet,
	excludedSnippetKeysByOthers map[string]struct{},
	excludedVariantKeysByOthers map[string]struct{},
	reference *aviaSearchProto.Reference,
	filterResponse *aviaAPI.SearchFiltersRsp,
) *aviaAPI.SearchFiltersRsp {
	if filterResponse == nil || filterResponse.Aviacompany == nil {
		return filterResponse
	}

	remainingAviacompanies := make(map[uint64]struct{})
	hasFullVariantAviacompanies := make(map[uint64]struct{})
	if reference != nil && reference.Flights != nil {
		// Per-snippet state used to detect snippets operated by a single aviacompany.
		var (
			firstAC  uint64
			seenAny  bool
			singleAC bool
		)
		record := func(id uint64) {
			remainingAviacompanies[id] = struct{}{}
			if !seenAny {
				firstAC, seenAny = id, true
			} else if id != firstAC {
				singleAC = false
			}
		}

		for sKey, s := range snippets {
			if _, excluded := excludedSnippetKeysByOthers[sKey]; excluded {
				continue
			}
			seenAny, singleAC = false, true
			// The generated getter is nil-safe: a flight key missing from
			// reference.Flights yields 0 instead of panicking.
			for _, f := range s.Forward {
				record(reference.Flights[f].GetAviaCompanyId())
			}
			for _, b := range s.Backward {
				record(reference.Flights[b].GetAviaCompanyId())
			}
			// The reference aviacompany is the first one seen on either leg.
			// Indexing s.Forward[0] directly would panic on an empty Forward leg.
			if seenAny && singleAC {
				hasFullVariantAviacompanies[firstAC] = struct{}{}
			}
		}
	}

	// Choose which aviacompanies may stay enabled in the current mode.
	keepEnabled := hasFullVariantAviacompanies
	if af.combinationsAllowed {
		keepEnabled = remainingAviacompanies
	}
	for _, a := range filterResponse.Aviacompany.Aviacompanies {
		if a == nil || a.State == nil {
			continue
		}
		if _, ok := keepEnabled[a.AviacompanyId]; !ok {
			a.State.Enabled = false
		}
	}

	return filterResponse
}
