package tariffmatcher

import (
	"context"
	"crypto"
	_ "crypto/md5"
	"fmt"
	"strings"
	"sync"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"github.com/opentracing/opentracing-go"

	farefamiliesstructs "a.yandex-team.ru/travel/avia/fare_families/internal/services/fare_families/data_structs/fare_families"
	"a.yandex-team.ru/travel/avia/fare_families/internal/services/fare_families/data_structs/payloads"
	"a.yandex-team.ru/travel/avia/fare_families/internal/services/fare_families/dicts"
	"a.yandex-team.ru/travel/avia/fare_families/internal/services/fare_families/storage"
	"a.yandex-team.ru/travel/proto/dicts/rasp"
)

// Constants used while matching fare families.
//
// EmptyFareFamilyKey is the key emitted for a segment that has no fare code.
// RussiaCountryID is the country ID that station country IDs are compared with
// by the domestic-RU and international condition checks.
const (
	EmptyFareFamilyKey = ""
	RussiaCountryID    = 225
)

// TariffMatcher matches fare (tariff) codes received from partners to fare families.
type TariffMatcher interface {
	// MatchFareFamilies matches every fare of the given variants and returns, per fare key,
	// the matched fare family keys with their hash, plus a reference map of the fare
	// families that those keys point to.
	MatchFareFamilies(variants *payloads.VariantsFromPartner, context context.Context) (result *payloads.FareFamiliesMap, err error)
	// MatchFareFamily matches a single fare code for the given carrier (IATA code),
	// "from"/"to" station codes and local departure time.
	MatchFareFamily(carrier, fareCode, from, to, departure string) (*farefamiliesstructs.FareFamilyForVariant, error)
}

// tariffMatcherImpl is the TariffMatcher implementation returned by NewTariffMatcher.
//
// storage is queried for fare families by carrier and tariff code, dicts resolves
// carriers, stations and countries, and logger receives warnings about data that
// cannot be resolved.
type tariffMatcherImpl struct {
	storage storage.Storage
	dicts   dicts.DictsRegistry
	logger  log.Logger
}

// flightIndex identifies a segment of a fare by its position: the index of the
// direction and the index of the segment inside that direction. It is used as a
// map key for the fare families matched per segment.
type flightIndex struct {
	directionIndex int
	segmentIndex   int
}

// fareFamiliesMapWithLock is a map of fare families keyed by fare family key,
// paired with the mutex that Set takes around every write.
type fareFamiliesMapWithLock struct {
	data map[string]farefamiliesstructs.FareFamilyForVariant
	lock sync.Mutex
}

// Set stores a copy of the given fare family under its Key while holding the lock.
func (m *fareFamiliesMapWithLock) Set(flightFareFamily *farefamiliesstructs.FareFamilyForVariant) {
	m.lock.Lock()
	defer m.lock.Unlock()

	m.data[flightFareFamily.Key] = *flightFareFamily
}

// variantsMapWithLock is a map of fare families entries keyed by fare key,
// paired with the mutex that Set takes around every write.
type variantsMapWithLock struct {
	data map[string]payloads.FareFamiliesEntry
	lock sync.Mutex
}

// Set stores a copy of the given entry under fareKey while holding the lock.
func (vm *variantsMapWithLock) Set(fareKey string, fareFamiliesEntry *payloads.FareFamiliesEntry) {
	vm.lock.Lock()
	defer vm.lock.Unlock()

	vm.data[fareKey] = *fareFamiliesEntry
}

// NewTariffMatcher builds a TariffMatcher on top of the given storage, dictionaries
// registry and logger. The returned error is currently always nil.
func NewTariffMatcher(storage storage.Storage, dicts dicts.DictsRegistry, logger log.Logger) (TariffMatcher, error) {
	return &tariffMatcherImpl{
		storage: storage,
		dicts:   dicts,
		logger:  logger,
	}, nil
}

// MatchFareFamilies matches fare families for every fare in variants.
//
// When context is non-nil, the work is wrapped into an opentracing span. For each
// fare the per-segment fare family keys and their hash are stored under the fare
// key, and every referenced fare family is collected into the shared reference
// map. Matching stops at the first fare that fails: a warning is logged and the
// error is returned together with a nil result.
func (tm *tariffMatcherImpl) MatchFareFamilies(
	variants *payloads.VariantsFromPartner, context context.Context) (result *payloads.FareFamiliesMap, err error) {
	if context != nil {
		tracingSpan, _ := opentracing.StartSpanFromContext(
			context,
			"internal.services.fare_families...:MatchFareFamilies",
		)
		defer tracingSpan.Finish()
	}

	// The zero value of sync.Mutex is ready to use, so only the maps are initialized.
	reference := fareFamiliesMapWithLock{
		data: map[string]farefamiliesstructs.FareFamilyForVariant{},
	}
	perFare := variantsMapWithLock{
		data: map[string]payloads.FareFamiliesEntry{},
	}

	for fareKey, fare := range variants.Fares {
		keys, matchErr := tm.matchFareFamilyAndUpdateMap(fareKey, fare, variants.Flights, &reference)
		if matchErr != nil {
			tm.logger.Warn("Unable to match fare family", log.Error(matchErr), log.Any("fare", fare))
			return nil, matchErr
		}
		perFare.Set(fareKey, &payloads.FareFamiliesEntry{
			FareFamilyKeys:   keys,
			FareFamiliesHash: getHash(keys),
		})
	}

	return &payloads.FareFamiliesMap{
		VariantsMap:           perFare.data,
		FareFamiliesReference: reference.data,
	}, nil
}

// MatchFareFamily matches a single fare code.
//
// The carrier is resolved by its IATA code and both points by their station
// codes; an error is returned when any of them is unknown. The resolved values
// are packed into a synthetic flight that is then matched like a regular segment.
func (tm *tariffMatcherImpl) MatchFareFamily(
	carrier, fareCode, from, to, departure string) (*farefamiliesstructs.FareFamilyForVariant, error) {
	carrierID, found := tm.dicts.Carriers().GetByIata(carrier)
	if !found {
		return nil, xerrors.Errorf("Unknown carrier: %s", carrier)
	}

	departureStationID, found := tm.dicts.StationCodes().GetStationIDByCode(from)
	if !found {
		return nil, xerrors.Errorf("Unknown \"from\" point: %s", from)
	}

	arrivalStationID, found := tm.dicts.StationCodes().GetStationIDByCode(to)
	if !found {
		return nil, xerrors.Errorf("Unknown \"to\" point: %s", to)
	}

	flight := payloads.VariantsFlight{
		AviaCompany: carrierID,
		Departure:   payloads.FlightTime{LocalTime: departure},
		From:        int64(departureStationID),
		To:          int64(arrivalStationID),
	}
	return tm.matchFareFamilyInternal(carrierID, fareCode, flight)
}

// matchFareFamilyAndUpdateMap matches a fare family for every segment of the fare
// and registers the matched fare families in fareFamiliesMap.
//
// It works in three steps:
//  1. every non-empty segment fare code is matched against the operating carrier's
//     fare families, and the worst charges are accumulated per carrier;
//  2. the total charges are applied to the matching terms of every matched fare
//     family, extending the fare family key with a suffix when the rule changes;
//  3. the resulting keys are laid out in the same direction/segment shape as
//     fare.FareCodes (nil for a nil direction, EmptyFareFamilyKey for a segment
//     without a fare code or without a match).
//
// An error is returned when a segment has no route entry, refers to an unknown
// flight, or when matching itself fails. Segments whose flight has no operating
// carrier are logged and left unmatched.
func (tm *tariffMatcherImpl) matchFareFamilyAndUpdateMap(
	fareKey string,
	fare payloads.VariantsFare,
	flights map[string]payloads.VariantsFlight,
	fareFamiliesMap *fareFamiliesMapWithLock) (fareFamilyKeys [][]string, err error) {
	chargesPerCarrier := map[int32]chargeMap{}
	matchedFareFamilies := map[flightIndex]farefamiliesstructs.FareFamilyForVariant{}

	// First, match fare families to the fare tariffs and calculate the total refund/exchange charges
	for directionIndex, direction := range fare.FareCodes {
		for segmentIndex, segmentFare := range direction {
			if len(segmentFare) == 0 {
				continue
			}
			// Route comes from the partner and is not guaranteed to have the same shape
			// as FareCodes; report the mismatch instead of panicking on the index.
			if directionIndex >= len(fare.Route) || segmentIndex >= len(fare.Route[directionIndex]) {
				return nil, xerrors.Errorf("No route entry for segment %d / %d of fare %s", directionIndex, segmentIndex, fareKey)
			}
			flight, known := flights[fare.Route[directionIndex][segmentIndex]]
			if !known {
				return nil, xerrors.Errorf("Unknown flight: %d / %d for fare %s", directionIndex, segmentIndex, fareKey)
			}
			operatingCompany, carrierErr := getOperatingCompany(flight)
			if carrierErr != nil {
				tm.logger.Warn("Flight has no operating carrier", log.String("error", carrierErr.Error()))
				continue
			}

			updatedFareFamily, matchErr := tm.matchFareFamilyInternal(operatingCompany, segmentFare, flight)
			if matchErr != nil {
				return nil, xerrors.Errorf("Error while matching fare tariff: %s for farekey %s: %w", segmentFare, fareKey, matchErr)
			}
			if updatedFareFamily == nil {
				continue
			}
			charges, exists := chargesPerCarrier[operatingCompany]
			if !exists {
				charges = chargeMap{}
			}
			charges.updateWorstCharges(updatedFareFamily)
			chargesPerCarrier[operatingCompany] = charges

			matchedFareFamilies[flightIndex{directionIndex, segmentIndex}] = *updatedFareFamily
		}
	}
	totalCharges := calculateTotalCharges(chargesPerCarrier)

	// Second, apply the calculated charges to the fare families in each segment.
	// Terms are walked in their own order and the charge is looked up by term code,
	// so the key suffixes are appended in a stable order. Ranging over totalCharges
	// instead would make the key depend on the unspecified map iteration order.
	for index, fareFamily := range matchedFareFamilies {
		for termIndex, term := range fareFamily.Terms {
			chargeValue, hasCharge := totalCharges[term.Code]
			if !hasCharge {
				continue
			}
			updatedRule, keySuffix := chargeValue.updateRule(term.Rule)
			term.Rule = updatedRule
			fareFamily.Terms[termIndex] = term
			if len(keySuffix) > 0 {
				fareFamily.Key = fmt.Sprintf("%s;%s=%s", fareFamily.Key, term.Code, keySuffix)
			}
		}
		matchedFareFamilies[index] = fareFamily
	}

	// Finally, reference the fare families from the variant's flights
	result := make([][]string, 0, len(fare.FareCodes))
	for directionIndex, direction := range fare.FareCodes {
		if direction == nil {
			result = append(result, nil)
			continue
		}
		directionResult := make([]string, 0, len(direction))
		for segmentIndex, segmentFare := range direction {
			if len(segmentFare) == 0 {
				directionResult = append(directionResult, EmptyFareFamilyKey)
				continue
			}
			// A segment without a match yields the zero value, whose Key is empty.
			flightFareFamily := matchedFareFamilies[flightIndex{directionIndex, segmentIndex}]
			directionResult = append(directionResult, flightFareFamily.Key)
			if len(flightFareFamily.Key) > 0 {
				fareFamiliesMap.Set(&flightFareFamily)
			}
		}

		result = append(result, directionResult)
	}
	return result, nil
}

// matchFareFamilyInternal looks up the fare families of the operating company that
// pass the tariff-code filter built from segmentFare and applies the first one to
// the flight. It returns (nil, nil) when the lookup yields no fare families.
func (tm *tariffMatcherImpl) matchFareFamilyInternal(
	operatingCompany int32,
	segmentFare string,
	flight payloads.VariantsFlight) (*farefamiliesstructs.FareFamilyForVariant, error) {
	// Both failure paths report the same message.
	const failureFormat = "Error matching for the fare %s tariff, airline %d: %w"

	dump, lookupErr := tm.storage.GetDump(operatingCompany, NewFareFamilyFilterByTariffCode(segmentFare))
	if lookupErr != nil {
		return nil, xerrors.Errorf(failureFormat, segmentFare, operatingCompany, lookupErr)
	}
	if len(dump.FareFamilies) == 0 {
		return nil, nil
	}

	applied, applyErr := tm.applyFareFamilyToVariant(flight, dump.FareFamilies[0])
	if applyErr != nil {
		return nil, xerrors.Errorf(failureFormat, segmentFare, operatingCompany, applyErr)
	}
	return applied, nil
}

// applyFareFamilyToVariant specializes a compiled fare family for a flight.
//
// For every term the first rule that either has no conditions or whose conditions
// all match the flight is selected; terms without such a rule are dropped. The
// resulting key is the fare family key followed by "code=ruleIndex" for every
// kept term, joined with ";".
func (tm *tariffMatcherImpl) applyFareFamilyToVariant(
	flight payloads.VariantsFlight,
	fareFamily farefamiliesstructs.CompiledFareFamily) (*farefamiliesstructs.FareFamilyForVariant, error) {
	variantFamily := &farefamiliesstructs.FareFamilyForVariant{
		BaseClass:         fareFamily.BaseClass,
		Brand:             fareFamily.Brand,
		TariffCodePattern: fareFamily.TariffCodePattern,
		TariffGroupName:   fareFamily.TariffGroupName,
		Terms:             []farefamiliesstructs.FareFamilyTermForVariant{},
	}
	keyParts := []string{fareFamily.Key}

	for _, term := range fareFamily.Terms {
		variantTerm := farefamiliesstructs.FareFamilyTermForVariant{
			Code:         term.Code,
			SpecialNotes: term.SpecialNotes, // It's okay to copy a reference here, we never modify this data anyway
		}

		selected := false
		for ruleIndex, rule := range term.Rules {
			// A conditional rule is skipped unless all of its conditions match.
			if len(rule.Conditions) > 0 && !tm.matches(rule.Conditions, flight) {
				continue
			}
			variantTerm.Rule = rule
			variantTerm.ID = ruleIndex
			selected = true
			break
		}
		if !selected {
			continue
		}

		if len(variantTerm.Rule.Availability) == 0 {
			// TODO(u-jeen): make sure tests fail when even one json file has this problem
			variantTerm.Rule.Availability = string(farefamiliesstructs.NotAvailable)
		}
		variantFamily.Terms = append(variantFamily.Terms, variantTerm)
		keyParts = append(keyParts, fmt.Sprintf("%s=%d", variantTerm.Code, variantTerm.ID))
	}

	variantFamily.Key = strings.Join(keyParts, ";")
	// Error is always nil at the moment, but it's gonna change as we implement real tariff rules matching
	return variantFamily, nil
}

// matches reports whether the flight satisfies every condition in the list.
// An empty list is satisfied by any flight.
func (tm *tariffMatcherImpl) matches(conditions []farefamiliesstructs.Condition, flight payloads.VariantsFlight) bool {
	for i := range conditions {
		if satisfied := tm.matchesCondition(conditions[i], flight); !satisfied {
			return false
		}
	}
	return true
}

// matchesCondition reports whether the flight satisfies a single condition.
// Only the parts of the condition that are set are checked, in a fixed order,
// and the first failing part rejects the flight.
func (tm *tariffMatcherImpl) matchesCondition(condition farefamiliesstructs.Condition, flight payloads.VariantsFlight) bool {
	// Each case reads "this part is set and the flight violates it"; cases are
	// evaluated top to bottom and evaluation stops at the first violation.
	switch {
	case condition.IsDomesticRU && !tm.matchesDomesticRU(flight):
		return false
	case condition.IsInternational && !tm.matchesInternational(flight):
		return false
	case len(condition.DepartBefore) > 0 && !tm.matchesDepartBefore(condition.DepartBefore, flight):
		return false
	case len(condition.DepartAfter) > 0 && !tm.matchesDepartAfter(condition.DepartAfter, flight):
		return false
	case len(condition.CountriesFrom) > 0 && !tm.matchesCountriesFrom(condition.CountriesFrom, flight):
		return false
	case len(condition.CountriesTo) > 0 && !tm.matchesCountriesTo(condition.CountriesTo, flight):
		return false
	case len(condition.CountriesOnly) > 0 && !tm.matchesCountriesOnly(condition.CountriesOnly, flight):
		return false
	case len(condition.AirportsOnly) > 0 && !tm.matchesAirportsOnly(condition.AirportsOnly, flight):
		return false
	case len(condition.AirportsBetween.PointsA) > 0 && !tm.matchesAirportsBetween(condition.AirportsBetween, flight):
		return false
	default:
		return true
	}
}

// matchesAirportsOnly reports whether the departure or the arrival station of the
// flight matches the airports list.
func (tm *tariffMatcherImpl) matchesAirportsOnly(airportsOnly []string, flight payloads.VariantsFlight) bool {
	if tm.matchesAirportsList(airportsOnly, flight.From) {
		return true
	}
	return tm.matchesAirportsList(airportsOnly, flight.To)
}

// matchesAirportsBetween reports whether the flight connects the two groups of
// airports in either direction: PointsA to PointsB or PointsB to PointsA.
func (tm *tariffMatcherImpl) matchesAirportsBetween(airports farefamiliesstructs.Between, flight payloads.VariantsFlight) bool {
	return (tm.matchesAirportsList(airports.PointsA, flight.From) && tm.matchesAirportsList(airports.PointsB, flight.To)) ||
		(tm.matchesAirportsList(airports.PointsB, flight.From) && tm.matchesAirportsList(airports.PointsA, flight.To))
}

// matchesAirportsList reports whether the airport code of the station is in the
// list. A station without a known code matches only an empty list.
func (tm *tariffMatcherImpl) matchesAirportsList(airports []string, stationID int64) bool {
	stationCode := tm.getAirportCode(stationID)
	if stationCode == "" {
		return len(airports) == 0
	}

	// stationCode is non-empty here, so an empty entry in the list can never
	// be equal to it: empty airports do not match anything.
	for i := range airports {
		if airports[i] == stationCode {
			return true
		}
	}
	return false
}

// matchesCountriesOnly reports whether the departure or the arrival country of the
// flight matches the countries list.
func (tm *tariffMatcherImpl) matchesCountriesOnly(countries []string, flight payloads.VariantsFlight) bool {
	if tm.matchesCountriesFrom(countries, flight) {
		return true
	}
	return tm.matchesCountriesTo(countries, flight)
}

// matchesCountriesFrom checks the country of the flight's departure station
// against the countries list.
func (tm *tariffMatcherImpl) matchesCountriesFrom(countriesFrom []string, flight payloads.VariantsFlight) bool {
	departureStation := flight.From
	return tm.matchesCountriesList(countriesFrom, departureStation)
}

// matchesCountriesTo checks the country of the flight's arrival station
// against the countries list.
func (tm *tariffMatcherImpl) matchesCountriesTo(countriesTo []string, flight payloads.VariantsFlight) bool {
	arrivalStation := flight.To
	return tm.matchesCountriesList(countriesTo, arrivalStation)
}

// matchesCountriesList reports whether the country code of the station is in the
// list. A station without a known country code matches only an empty list.
func (tm *tariffMatcherImpl) matchesCountriesList(countries []string, stationID int64) bool {
	stationCountry := tm.getCountryCode(stationID)
	if stationCountry == "" {
		return len(countries) == 0
	}

	for i := range countries {
		if countries[i] == stationCountry {
			return true
		}
	}
	return false
}

// matchesDomesticRU reports whether both the departure and the arrival stations
// of the flight belong to the country with RussiaCountryID. An unknown station
// is logged and treated as a mismatch.
func (tm *tariffMatcherImpl) matchesDomesticRU(flight payloads.VariantsFlight) bool {
	departure, found := tm.dicts.Stations().Get(int(flight.From))
	if !found {
		tm.logger.Warn("Unknown \"from\" station", log.Int64("station", flight.From))
		return false
	}
	// The arrival station is not looked up at all when the departure is abroad.
	if departure.CountryId != RussiaCountryID {
		return false
	}

	arrival, found := tm.dicts.Stations().Get(int(flight.To))
	if !found {
		tm.logger.Warn("Unknown \"to\" station", log.Int64("station", flight.To))
		return false
	}
	return arrival.CountryId == RussiaCountryID
}

// matchesInternational reports whether at least one of the flight's stations is
// outside the country with RussiaCountryID. An unknown station is logged and
// treated as a mismatch.
func (tm *tariffMatcherImpl) matchesInternational(flight payloads.VariantsFlight) bool {
	departure, found := tm.dicts.Stations().Get(int(flight.From))
	if !found {
		tm.logger.Warn("Unknown \"from\" station", log.Int64("station", flight.From))
		return false
	}

	arrival, found := tm.dicts.Stations().Get(int(flight.To))
	if !found {
		tm.logger.Warn("Unknown \"to\" station", log.Int64("station", flight.To))
		return false
	}

	return departure.CountryId != RussiaCountryID || arrival.CountryId != RussiaCountryID
}

// matchesDepartBefore reports whether the flight has a local departure time that
// sorts strictly before departBefore. The values are compared as strings.
func (tm *tariffMatcherImpl) matchesDepartBefore(departBefore string, flight payloads.VariantsFlight) bool {
	localDeparture := flight.Departure.LocalTime
	return len(localDeparture) > 0 && localDeparture < departBefore
}

// matchesDepartAfter reports whether the flight has a local departure time that
// sorts strictly after departAfter. The values are compared as strings.
func (tm *tariffMatcherImpl) matchesDepartAfter(departAfter string, flight payloads.VariantsFlight) bool {
	localDeparture := flight.Departure.LocalTime
	return len(localDeparture) > 0 && localDeparture > departAfter
}

// getCountryCode returns the code of the country the station belongs to.
// It logs a warning and returns an empty string when the station is unknown,
// has a non-positive country ID, or its country is missing from the dictionary.
func (tm *tariffMatcherImpl) getCountryCode(stationID int64) string {
	station, stationFound := tm.dicts.Stations().Get(int(stationID))
	if !stationFound || station.CountryId <= 0 {
		tm.logger.Warn("Unknown station to find out country code", log.Int64("station", stationID))
		return ""
	}

	if country, countryFound := tm.dicts.Countries().Get(int32(station.CountryId)); countryFound {
		return country.Code
	}
	tm.logger.Warn("Unknown country for station", log.Int64("station", stationID))
	return ""
}

// getAirportCode returns the code of the station in the first code system that
// has an entry for it, trying IATA, then Sirena, then ICAO. It returns an empty
// string when the station is unknown (a warning is logged) or has none of these
// codes.
func (tm *tariffMatcherImpl) getAirportCode(stationID int64) string {
	station, found := tm.dicts.Stations().Get(int(stationID))
	if !found {
		tm.logger.Warn("Unknown station", log.Int64("station", stationID))
		return ""
	}

	// Code systems in order of preference.
	preferredSystems := []int32{
		int32(rasp.ECodeSystem_CODE_SYSTEM_IATA),
		int32(rasp.ECodeSystem_CODE_SYSTEM_SIRENA),
		int32(rasp.ECodeSystem_CODE_SYSTEM_ICAO),
	}

	// Looking up a key in a nil map is safe and reports "not found", so a station
	// without codes falls through to the empty result.
	codes := station.GetStationCodes()
	for _, system := range preferredSystems {
		if code, present := codes[system]; present {
			return code
		}
	}
	return ""
}

// getOperatingCompany returns the carrier that operates the flight: the operating
// company when it is set, otherwise the flight's avia company. An error is
// returned when both are zero.
func getOperatingCompany(flight payloads.VariantsFlight) (int32, error) {
	switch {
	case flight.Operating.Company != 0:
		return flight.Operating.Company, nil
	case flight.AviaCompany != 0:
		return flight.AviaCompany, nil
	default:
		return 0, xerrors.Errorf("Unable to determine an operating carrier for a flight: %+v", flight)
	}
}

// getHash returns the hex-encoded MD5 digest of all fare family keys, fed to the
// digest in order, direction by direction, without separators. An empty key is
// replaced with "_" before hashing.
func getHash(fareFamilyKeys [][]string) string {
	md5Digest := crypto.MD5.New()
	for _, directionKeys := range fareFamilyKeys {
		// Ranging over a nil direction is a no-op, so it contributes nothing.
		for _, key := range directionKeys {
			if key == "" {
				// Presumably, makes better hashes
				key = "_"
			}
			// hash.Hash.Write never returns an error.
			_, _ = md5Digest.Write([]byte(key))
		}
	}
	return fmt.Sprintf("%x", md5Digest.Sum(nil))
}
