package ru.yandex.travel.api.services.avia.fares;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.SetMultimap;
import lombok.RequiredArgsConstructor;
import org.dom4j.Document;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Service;

import ru.yandex.avia.booking.ff.model.FareFamily;
import ru.yandex.avia.booking.ff.model.SegmentFare;
import ru.yandex.avia.booking.ff.model.TermAvailability;
import ru.yandex.avia.booking.ff.model.TermValue;
import ru.yandex.avia.booking.partners.gateways.model.search.CategoryPrice;
import ru.yandex.avia.booking.partners.gateways.model.search.FareInfo;
import ru.yandex.avia.booking.partners.gateways.model.search.Flight;
import ru.yandex.avia.booking.partners.gateways.model.search.PriceInfo;
import ru.yandex.avia.booking.partners.gateways.model.search.Segment;
import ru.yandex.avia.booking.partners.gateways.model.search.Variant;

import static java.util.stream.Collectors.toList;

/**
 * Resolves fare-family terms (refund / exchange conditions, etc.) for avia variants and filters out
 * offers whose fare codes are not supported by {@link AviaFareFamilyReference}.
 */
@Service
@RequiredArgsConstructor
public class AviaFareFamilyService {
    /**
     * Terms whose values must agree across all segments of a ticket. When segments disagree, the worst
     * value is applied to every segment (see {@link #combineTerms(Map)}).
     */
    private static final Collection<String> TERMS_TO_VALIDATE = List.of(
            "refundable",
            "refundable_no_show",
            "changing_carriage",
            "changing_carriage_no_show"
            //"baggage"
    );

    private static final String RUB_CURRENCY = "RUB";

    private final AviaFareFamilyReference fareFamilyReference;
    private final AviaFareFamilyInputFactory fareFamilyInputFactory;

    /**
     * Builds the fare terms for the price currently selected on the variant.
     *
     * <p>Each segment's fare code is mapped to a fare family. If every segment belongs to one family,
     * that family's fares are used directly. Otherwise, all families must share the same non-null
     * brand, and the partial fares of each family (restricted to its own segments) are merged.
     * Finally, the validated terms are unified across segments.
     *
     * @param variant variant whose {@code getPriceInfo()} provides the fare codes
     * @param lang    language passed through to the fare-family reference
     * @return the segment fares produced by the fare-family reference (presumably keyed by segment
     *         id — confirm against {@code AviaFareFamilyReference}), with validated terms unified
     * @throws AviaFareRulesException if a segment has conflicting fare codes or the fare families
     *                                cannot be combined
     */
    public Map<String, SegmentFare> getFareTerms(Variant variant, String lang) {
        Map<String, String> fareCodes = getFareCodes(variant.getPriceInfo());
        SetMultimap<FareFamily, String> ff2seg = MultimapBuilder.hashKeys().hashSetValues().build();
        for (Map.Entry<String, String> segCode : fareCodes.entrySet()) {
            ff2seg.put(fareFamilyReference.getFareFamily(segCode.getValue()), segCode.getKey());
        }
        Document ticketSegmentData = fareFamilyInputFactory.convertVariant(variant, fareCodes);
        Set<FareFamily> families = ff2seg.keySet();
        Map<String, SegmentFare> ticketSegments;
        if (families.size() == 1) {
            FareFamily fareFamily = families.iterator().next();
            ticketSegments = fareFamilyReference.getFares(fareFamily, ticketSegmentData, lang);
        } else {
            // several fare families, we have to combine them if possible
            ensureSameBrand(families);
            ticketSegments = new HashMap<>();
            for (FareFamily fareFamily : families) {
                ticketSegments.putAll(fareFamilyReference.getFaresPartial(
                        fareFamily, ticketSegmentData, lang, ff2seg.get(fareFamily)));
            }
        }
        return combineTerms(ticketSegments);
    }

    /**
     * For every term in {@link #TERMS_TO_VALIDATE}, replaces that term with the worst value found
     * across all segments whenever the segments disagree.
     *
     * <p>The {@link SegmentFare} instances are mutated in place (via {@code setTerms}), and the same
     * map is returned. NOTE(review): if the reference returns cached/shared {@code SegmentFare}
     * objects, this mutation would leak between requests — confirm that they are fresh instances.
     * Distinct values are detected with a {@link HashSet}, so this relies on {@code TermValue}
     * implementing {@code equals}/{@code hashCode}.
     */
    private Map<String, SegmentFare> combineTerms(Map<String, SegmentFare> allTerms) {
        // Hack for MOW-AER-IST: https://st.yandex-team.ru/RASPTICKETS-22893
        for (String validatedTerm : TERMS_TO_VALIDATE) {
            Set<TermValue> values = new HashSet<>();
            for (SegmentFare segmentFare : allTerms.values()) {
                // May add null when a segment does not define the term; getWorstValue ignores it.
                values.add(segmentFare.getTerms().get(validatedTerm));
            }
            if (values.size() > 1) {
                TermValue worstValue = getWorstValue(values);
                for (SegmentFare segmentFare : allTerms.values()) {
                    Map<String, TermValue> updatedTerms = new LinkedHashMap<>(segmentFare.getTerms());
                    updatedTerms.put(validatedTerm, worstValue);
                    segmentFare.setTerms(updatedTerms);
                }
            }
        }
        return allTerms;
    }

    /**
     * Picks the worst possible fare condition among {@code values}.
     *
     * <p>Original intent: if some fare has the term "NOT_AVAILABLE" (e.g. no refund), pick it.
     * Otherwise, pick the condition with the worst fee. Currently, a fee in a foreign (non-RUB)
     * currency is considered the worst, if one exists. If all fees are in RUB, any of them is picked,
     * because in the current configuration they are all expected to be equal. The original author
     * noted that a mismatch is expected to show up in Sentry and be fixed in configuration.
     *
     * <p>The availability with the greater enum order is treated as worse. NOTE(review): this assumes
     * {@code TermAvailability} constants are declared from best to worst — confirm.
     *
     * @param values candidate values. {@code null} entries (a segment without the term) are skipped.
     * @return the worst non-null value, or {@code null} if there is none
     */
    private TermValue getWorstValue(@NotNull Set<TermValue> values) {
        TermValue result = null;
        for (TermValue value : values) {
            if (value == null) {
                // Dereferencing a missing term used to throw an NPE, depending on hash iteration order.
                continue;
            }
            if (result == null || isWorse(value, result)) {
                result = value;
            }
        }
        return result;
    }

    /**
     * Returns {@code true} if {@code candidate} is a strictly worse term than {@code current}. A known
     * availability beats an unknown ({@code null}) one. A later enum constant beats an earlier one.
     * Between two CHARGE values, a non-RUB charge beats a RUB charge.
     */
    private static boolean isWorse(TermValue candidate, TermValue current) {
        TermAvailability candidateAvailability = candidate.getAvailability();
        TermAvailability currentAvailability = current.getAvailability();
        if (candidateAvailability == null) {
            return false;
        }
        if (currentAvailability == null) {
            return true;
        }
        int order = candidateAvailability.compareTo(currentAvailability);
        if (order != 0) {
            return order > 0;
        }
        // Same availability: only CHARGE vs CHARGE is refined further, by currency.
        return candidateAvailability == TermAvailability.CHARGE
                && RUB_CURRENCY.equals(current.getCharge().getCurrency())
                && !RUB_CURRENCY.equals(candidate.getCharge().getCurrency());
    }

    /**
     * Collects the fare basis code of every segment (flight id) of the price, and checks that all
     * passenger categories use the same fare basis for a given segment.
     *
     * @throws AviaFareRulesException if two categories disagree on a segment's fare basis
     */
    private Map<String, String> getFareCodes(PriceInfo priceInfo) {
        Map<String, String> segmentFareCodes = new HashMap<>();
        for (CategoryPrice categoryPrice : priceInfo.getCategoryPrices()) {
            for (FareInfo fareInfo : categoryPrice.getFareInfo()) {
                String segmentId = fareInfo.getFlightId();
                String fareBasis = fareInfo.getCanonicalFareBasis();
                if (!segmentFareCodes.containsKey(segmentId)) {
                    segmentFareCodes.put(segmentId, fareBasis);
                    continue;
                }
                String knownFareBasis = segmentFareCodes.get(segmentId);
                // Null-safe: a segment whose first fare basis is null must not cause an NPE here.
                if (!Objects.equals(knownFareBasis, fareBasis)) {
                    throw new AviaFareRulesException(String.format(
                            "Multiple fare codes for a single segment detected: segmentId=%s, codes=[%s, %s]",
                            segmentId, knownFareBasis, fareBasis));
                }
            }
        }
        return segmentFareCodes;
    }

    /**
     * Verifies that all given fare families share one non-null brand, so their terms can be combined.
     *
     * @throws AviaFareRulesException listing the families' Russian tariff-group names otherwise
     */
    private void ensureSameBrand(Set<FareFamily> families) {
        Set<String> brands = new HashSet<>();
        for (FareFamily family : families) {
            brands.add(family.getBrand());
        }
        if (brands.size() != 1 || brands.iterator().next() == null) {
            throw new AviaFareRulesException("Don't know how to combine fare families: " +
                    families.stream().map(FareFamily::getTariffGroupName).map(n -> n.getForLang("ru")).collect(toList()));
        }
    }

    /**
     * Returns the fare codes used by the variant's flights under {@code priceInfo} that the
     * fare-family reference does not support. The result is empty if all codes are supported.
     */
    private Set<String> checkUnknownFaresForRequestedVariant(Variant variant, PriceInfo priceInfo) {
        Map<String, String> fareCodes = getFareCodes(priceInfo);
        Set<String> unknownFareCodes = new HashSet<>();
        for (Segment segment : variant.getSegments()) {
            for (Flight flight : segment.getFlights()) {
                // NOTE(review): null when the price has no fare info for this flight; how
                // isFareCodeSupported treats null is not visible here — confirm.
                String fareCode = fareCodes.get(flight.getId());
                if (!fareFamilyReference.isFareCodeSupported(fareCode)) {
                    unknownFareCodes.add(fareCode);
                }
            }
        }
        return unknownFareCodes;
    }

    /**
     * Checks that every fare code of the variant's selected price is supported.
     *
     * @throws AviaUnknownFareFamilyException listing the unsupported fare codes, if any
     */
    public void checkUnknownFareCodesForRequestedVariant(Variant variant) throws AviaUnknownFareFamilyException {
        Set<String> unknownFareCodes = checkUnknownFaresForRequestedVariant(variant, variant.getPriceInfo());
        if (!unknownFareCodes.isEmpty()) {
            throw new AviaUnknownFareFamilyException("Unknown fare code(s): " + unknownFareCodes);
        }
    }

    /**
     * Removes from {@code variant.getAllTariffs()} (in place, so the list must be mutable) every offer
     * that uses at least one unsupported fare code. All offers are evaluated before anything is
     * removed, so an exception from fare-code resolution leaves the list untouched.
     *
     * @throws AviaUnknownFareFamilyException never thrown by this body; the clause is kept so that
     *                                        existing callers continue to compile
     */
    public void removeOffersWithUnknownFareCodes(Variant variant) throws AviaUnknownFareFamilyException {
        List<PriceInfo> unsupportedOffers = new ArrayList<>();
        for (PriceInfo offer : variant.getAllTariffs()) {
            Set<String> unknownFareCodes = checkUnknownFaresForRequestedVariant(variant, offer);
            if (!unknownFareCodes.isEmpty()) {
                unsupportedOffers.add(offer);
            }
        }
        variant.getAllTariffs().removeAll(unsupportedOffers);
    }
}
