package search

import (
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/wizard/pkg/wizard/domain/models"
	"a.yandex-team.ru/travel/avia/wizard/pkg/wizard/domain/parameters"
	"a.yandex-team.ru/travel/avia/wizard/pkg/wizard/lib/containers"
	"a.yandex-team.ru/travel/avia/wizard/pkg/wizard/repositories/ydb"
	containerslib "a.yandex-team.ru/travel/library/go/containers"
	"a.yandex-team.ru/travel/library/go/syncutil"
)

// partnersForHidden is the set of partner codes this file treats as hidden:
// their results are dropped from partner results and they are removed from
// the requested partner filter unless the query's EnableTutuPartnerResults
// flag is on, and they are always removed from the filter state copied onto
// partner results by updateCommonFields.
var partnersForHidden = containerslib.SetOf("tutu")

// resultProvider is the signature of the fetch callbacks held by
// resultByPartnerFetcher. It receives the query parameters, the origin and
// destination settlements, the departure date, a reference time and a
// context, and returns the wizard search results or an error.
// NOTE(review): the exact meaning of `now` (presumably the request time used
// for freshness checks) is not visible in this file — confirm with the
// concrete providers.
type resultProvider func(
	queryParameters *parameters.QueryParameters,
	fromSettlement, toSettlement *models.Settlement,
	departureDate *time.Time,
	now time.Time,
	ctx context.Context,
) (
	wizardResult []*ydb.WizardSearchResult,
	err error,
)

// resultByPartnerFetcher combines two result providers: a mandatory
// "all partners" provider and an optional per-partner provider whose results
// are preferred by getResults when they are available.
type resultByPartnerFetcher struct {
	// logger receives an error entry when the per-partner provider fails.
	logger                      log.Logger
	// providerType is attached to that log entry under the "providerType" key.
	providerType                string
	// resultProviderByAllPartners is always called; its error fails getResults.
	resultProviderByAllPartners resultProvider
	// resultProviderByPartners is called only when usePartnerCache is true and
	// the query's EnablePartnerCache flag is on; its error is logged, not returned.
	resultProviderByPartners    resultProvider
	// usePartnerCache enables the per-partner provider for this fetcher.
	usePartnerCache             bool
}

// newResultByPartnerFetcher builds a resultByPartnerFetcher from its
// collaborators. Every argument is stored as-is in the field of the same
// name; no validation is performed.
func newResultByPartnerFetcher(
	logger log.Logger,
	providerType string,
	resultProviderByAllPartners resultProvider,
	resultProviderByPartners resultProvider,
	usePartnerCache bool,
) *resultByPartnerFetcher {
	fetcher := new(resultByPartnerFetcher)
	fetcher.logger = logger
	fetcher.providerType = providerType
	fetcher.resultProviderByAllPartners = resultProviderByAllPartners
	fetcher.resultProviderByPartners = resultProviderByPartners
	fetcher.usePartnerCache = usePartnerCache
	return fetcher
}

// getResults runs the "all partners" provider and, when enabled, the
// per-partner provider concurrently, then decides which result set to return.
//
// The per-partner provider is only invoked when both the query's
// EnablePartnerCache flag and the fetcher's usePartnerCache are true; its
// error is logged and never returned. An error from the "all partners"
// provider fails the whole call.
//
// Per-partner results are stripped of hidden partners (unless the query's
// EnableTutuPartnerResults flag is on), narrowed by the requested partner
// filter when that narrowing leaves at least one result, and patched with
// the common fields of the "all partners" results. They are returned when
// non-empty and fetched without error; otherwise the "all partners" results
// are returned.
func (f *resultByPartnerFetcher) getResults(
	queryParameters *parameters.QueryParameters,
	fromSettlement, toSettlement *models.Settlement,
	departureDate *time.Time,
	now time.Time,
	ctx context.Context,
) (
	wizardResult []*ydb.WizardSearchResult,
	err error,
) {
	var (
		allPartners, perPartner       []*ydb.WizardSearchResult
		allPartnersErr, perPartnerErr error
	)

	var fetches syncutil.WaitGroup
	fetches.Go(func() {
		allPartners, allPartnersErr = f.resultProviderByAllPartners(
			queryParameters, fromSettlement, toSettlement, departureDate, now, ctx,
		)
	})
	fetches.Go(func() {
		if !(queryParameters.Flags.EnablePartnerCache() && f.usePartnerCache) {
			return
		}
		perPartner, perPartnerErr = f.resultProviderByPartners(
			queryParameters, fromSettlement, toSettlement, departureDate, now, ctx,
		)
		if perPartnerErr == nil {
			return
		}
		// A per-partner failure is not fatal: it is only reported here.
		f.logger.Error(
			"getByPartner",
			log.Error(perPartnerErr),
			log.String("providerType", f.providerType),
		)
	})
	fetches.Wait()

	if allPartnersErr != nil {
		return nil, fmt.Errorf("failed fetching by all partner: %w", allPartnersErr)
	}

	perPartner = removeHiddenPartnersFromPartnerResults(queryParameters.Flags.EnableTutuPartnerResults(), perPartner)
	// Only adopt the narrowed list when the filter matched something;
	// otherwise the unfiltered per-partner results are kept.
	if narrowed := filterResultByPartner(buildPartnerFilter(queryParameters), perPartner); len(narrowed) > 0 {
		perPartner = narrowed
	}
	updateCommonFields(allPartners, perPartner)

	if perPartnerErr != nil || len(perPartner) == 0 {
		return allPartners, nil
	}
	return perPartner, nil
}

// removeHiddenPartnersFromPartnerResults drops every result whose first fare
// belongs to a partner listed in partnersForHidden. When enablePartners is
// true the input slice is returned untouched.
//
// A result without fares cannot be attributed to any partner, so it is kept
// instead of being indexed into: Fares[0] on an empty slice would panic.
// Fares are read through GetFares, the same accessor filterResultByPartner
// uses for its emptiness check.
// NOTE(review): assumes GetFares tolerates a nil Value, as generated getters
// normally do — confirm against the ydb model.
func removeHiddenPartnersFromPartnerResults(enablePartners bool, results []*ydb.WizardSearchResult) []*ydb.WizardSearchResult {
	if enablePartners {
		return results
	}
	return containerslib.FilterBy(results, func(result *ydb.WizardSearchResult) bool {
		fares := result.SearchResult.Value.GetFares()
		if len(fares) == 0 {
			return true
		}
		return !partnersForHidden.Contains(fares[0].GetPartner())
	})
}

// updateCommonFields copies the polling status, version and filter state of
// the first source result onto every destination result.
//
// Nothing happens when either slice is empty or when the first source result
// has no search-result value. Before the filter state is copied, every
// partner in partnersForHidden is removed from its Partners set.
// NOTE(review): the removal is applied to the source's own filter state value
// and all destinations end up sharing that value — confirm this aliasing is
// intended. The removal also happens regardless of any query flag.
func updateCommonFields(srcResults []*ydb.WizardSearchResult, dstResults []*ydb.WizardSearchResult) {
	if len(srcResults) == 0 || len(dstResults) == 0 {
		return
	}
	reference := srcResults[0]
	if reference.SearchResult.Value == nil {
		return
	}

	pollingStatus := reference.SearchResult.Value.PollingStatus
	resultVersion := reference.SearchResult.Value.Version
	sharedFilter := reference.FilterState.Value
	for _, hidden := range partnersForHidden.Values() {
		sharedFilter.Partners.Remove(hidden)
	}

	for _, dst := range dstResults {
		// The filter state is always overwritten; status and version only
		// when the destination actually carries a search-result value.
		dst.FilterState.Value = sharedFilter
		if dst.SearchResult.Value == nil {
			continue
		}
		dst.SearchResult.Value.PollingStatus = pollingStatus
		dst.SearchResult.Value.Version = resultVersion
	}
}

// filterResultByPartner keeps only the results whose first fare belongs to a
// partner contained in filter.
//
// With an empty filter or an empty input the input slice is returned as-is.
// Otherwise a new, non-nil slice is returned (possibly empty); results
// without fares are never included in it.
func filterResultByPartner(filter containers.SetOfString, srcResults []*ydb.WizardSearchResult) []*ydb.WizardSearchResult {
	if len(srcResults) == 0 || len(filter) == 0 {
		return srcResults
	}

	matched := []*ydb.WizardSearchResult{}
	for _, candidate := range srcResults {
		payload := candidate.SearchResult.Value
		if len(payload.GetFares()) == 0 {
			// No fares means no partner to match against.
			continue
		}
		if !filter.Contains(payload.Fares[0].GetPartner()) {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched
}

// buildPartnerFilter derives the set of partners the query asks for.
//
// The partners from the query's filters win when present and non-empty;
// otherwise a non-empty PartnerCode yields a single-element set; otherwise
// nil is returned. Unless the query's EnableTutuPartnerResults flag is on,
// every partner in partnersForHidden is removed from a non-nil result, which
// may leave it empty.
// NOTE(review): when the set comes from queryParameters.Filters().Partners()
// the removal is applied to that returned set directly — if it is not a copy,
// this mutates the query's own filter. Confirm.
func buildPartnerFilter(queryParameters *parameters.QueryParameters) (partners containers.SetOfString) {
	requested, found := queryParameters.Filters().Partners()
	switch {
	case found && len(requested) > 0:
		partners = requested
	case queryParameters.PartnerCode != nil && *queryParameters.PartnerCode != "":
		partners = containers.NewSetOfString(*queryParameters.PartnerCode)
	default:
		return nil
	}

	if partners == nil || queryParameters.Flags.EnableTutuPartnerResults() {
		return partners
	}
	for _, hidden := range partnersForHidden.Values() {
		partners.Remove(hidden)
	}
	return partners
}
