package ru.yandex.travel.api.services.avia.fares;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import lombok.extern.slf4j.Slf4j;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Node;
import org.dom4j.XPath;

import ru.yandex.avia.booking.ff.model.FareFamily;
import ru.yandex.avia.booking.ff.model.FareFamilyRule;
import ru.yandex.avia.booking.ff.model.FareFamilyTerm;
import ru.yandex.avia.booking.ff.model.SegmentFare;
import ru.yandex.avia.booking.ff.model.TermValue;

import static java.util.stream.Collectors.toMap;

/**
 * Holds a list of fare families together with named external XPath expressions and resolves,
 * for a fare code and a leg XML document, the term values that apply to each segment of the leg.
 *
 * <p>Tariff code regular expressions and XPath expressions are compiled once in the constructor
 * and looked up by their source text afterwards.
 */
@Slf4j
public class AviaFareFamilyReference {
    private final List<FareFamily> families;
    /** External XPath expressions: reference name -> XPath source text. */
    private final Map<String, String> externalExpressions;

    // caches
    /** Tariff code pattern source -> compiled pattern. */
    private final Map<String, Pattern> compiledPatterns;
    /** XPath source text (inline or external) -> compiled expression. */
    private final Map<String, XPath> compiledXpathExpressions;

    /**
     * Validates the rules of the given families and pre-compiles all patterns and XPath expressions.
     *
     * @param families            fare families to search through, in matching priority order
     * @param externalExpressions named XPath expressions that rules may refer to via external_xpath_ref
     * @throws IllegalArgumentException if a rule refers to an undefined external expression
     *                                  or defines both an inline xpath and an external reference
     */
    public AviaFareFamilyReference(List<FareFamily> families, Map<String, String> externalExpressions) {
        this.families = families;
        this.externalExpressions = externalExpressions;

        validateFormat();

        // toMap without a merge function throws IllegalStateException on duplicate keys,
        // so families sharing the same tariff code pattern have to be de-duplicated first
        compiledPatterns = ImmutableMap.copyOf(families.stream()
                .map(FareFamily::getTariffCodePattern)
                .distinct()
                .collect(toMap(cp -> cp, Pattern::compile)));
        // collectXpathExpressions() returns a Set, so the keys are already unique here
        compiledXpathExpressions = ImmutableMap.copyOf(collectXpathExpressions()
                .stream().collect(toMap(e -> e, DocumentHelper::createXPath)));
    }

    /**
     * Checks every rule that uses an external XPath reference: the reference has to be defined
     * and must not be combined with an inline xpath.
     *
     * @throws IllegalArgumentException if any rule violates these constraints
     */
    private void validateFormat() {
        for (FareFamily family : families) {
            for (FareFamilyTerm term : family.getTerms()) {
                for (FareFamilyRule rule : term.getRules()) {
                    String externalRef = rule.getExternalXpathRef();
                    if (Strings.isNullOrEmpty(externalRef)) {
                        continue;
                    }
                    Preconditions.checkArgument(externalExpressions.containsKey(externalRef),
                            "Undefined external_xpath_ref: %s", externalRef);
                    Preconditions.checkArgument(Strings.isNullOrEmpty(rule.getXpath()),
                            "Properties xpath or external_xpath_ref can't be used together");
                }
            }
        }
    }

    /**
     * Gathers the source text of every XPath expression used by the rules: inline ones as is,
     * external references resolved through {@code externalExpressions}.
     *
     * @return the distinct XPath source texts
     */
    private Set<String> collectXpathExpressions() {
        Set<String> expressions = new HashSet<>();
        for (FareFamily family : families) {
            for (FareFamilyTerm term : family.getTerms()) {
                for (FareFamilyRule rule : term.getRules()) {
                    if (!Strings.isNullOrEmpty(rule.getXpath())) {
                        expressions.add(rule.getXpath());
                    }
                    if (!Strings.isNullOrEmpty(rule.getExternalXpathRef())) {
                        expressions.add(externalExpressions.get(rule.getExternalXpathRef()));
                    }
                }
            }
        }
        return expressions;
    }

    /**
     * Finds the first family (in list order) whose tariff code pattern matches the whole fare code.
     *
     * @param fareCode fare code to look up
     * @return the matching family or {@code null} if none matches
     */
    public FareFamily getFareFamily(String fareCode) {
        for (FareFamily family : families) {
            Pattern p = compiledPatterns.get(family.getTariffCodePattern());
            if (p.matcher(fareCode).matches()) {
                return family;
            }
        }
        return null;
    }

    /**
     * For the given leg and fareCode finds an appropriate fare family and fills the terms table for every segment.
     *
     * @param fareCode fare code used to pick the fare family
     * @param legDoc   leg XML document; its segments are read from {@code Leg/Seg}
     * @param lang     language used for the tariff group name and term values
     * @return segment id -> fare of the segment
     * @throws IllegalArgumentException if no family matches the fare code, the leg has no segments,
     *                                  or there is a segment without a term value for one of the predefined terms
     */
    public Map<String, SegmentFare> getFares(String fareCode, Document legDoc, String lang) {
        FareFamily fareFamily = getFareFamily(fareCode);
        Preconditions.checkArgument(fareFamily != null, "No fare family matches the fare family code: %s", fareCode);
        return getFares(fareFamily, legDoc, lang);
    }

    /**
     * Fills the terms table of the given fare family for every segment of the leg.
     *
     * @param fareFamily fare family whose terms are evaluated, must not be null
     * @param legDoc     leg XML document; its segments are read from {@code Leg/Seg}
     * @param lang       language used for the tariff group name and term values
     * @return segment id -> fare of the segment, in the order the segments appear in the document
     * @throws NullPointerException     if fareFamily is null
     * @throws IllegalArgumentException if the leg has no segments or there is a segment without
     *                                  a term value for one of the family terms
     */
    public Map<String, SegmentFare> getFares(FareFamily fareFamily, Document legDoc, String lang) {
        Preconditions.checkNotNull(fareFamily, "Fare family cannot be null");
        // an insertion-ordered set keeps the document order of segments in the resulting map
        Set<String> segmentIds = legDoc.selectNodes("Leg/Seg").stream()
                .map(seg -> seg.valueOf("Id"))
                .collect(Collectors.toCollection(LinkedHashSet::new));
        return matchFamilyTerms(fareFamily, legDoc, lang, segmentIds);
    }

    /**
     * Same as {@link #getFares(FareFamily, Document, String)} but restricted to the given segments.
     *
     * @param fareFamily fare family whose terms are evaluated, must not be null
     * @param legDoc     leg XML document
     * @param lang       language used for the tariff group name and term values
     * @param segmentIds ids of the segments to build fares for
     * @return segment id -> fare of the segment, only for the requested segments
     * @throws NullPointerException     if fareFamily is null
     * @throws IllegalArgumentException if segmentIds is empty or a requested segment has no
     *                                  term value for one of the family terms
     */
    public Map<String, SegmentFare> getFaresPartial(FareFamily fareFamily, Document legDoc, String lang, Set<String> segmentIds) {
        Preconditions.checkNotNull(fareFamily, "Fare family cannot be null");
        return matchFamilyTerms(fareFamily, legDoc, lang, segmentIds);
    }

    /**
     * Tells whether any known fare family matches the given fare code.
     *
     * @param fareCode fare code to check
     * @return {@code true} if {@link #getFareFamily(String)} finds a family
     */
    public boolean isFareCodeSupported(String fareCode) {
        return getFareFamily(fareCode) != null;
    }

    /**
     * Builds a fare for each requested segment and fills in the value of every term of the family.
     *
     * @throws IllegalArgumentException if segmentIds is empty or some term has no value
     *                                  for at least one of the requested segments
     */
    private Map<String, SegmentFare> matchFamilyTerms(FareFamily fareFamily, Document legDoc, String lang, Set<String> segmentIds) {
        if (segmentIds.isEmpty()) {
            throw new IllegalArgumentException("Segments list is empty: xml=" + legDoc.asXML());
        }
        Map<String, SegmentFare> segmentFares = new LinkedHashMap<>();
        for (String segmentId : segmentIds) {
            segmentFares.put(segmentId, SegmentFare.builder()
                    .baseClass(fareFamily.getBaseClass())
                    .tariffGroupName(fareFamily.getTariffGroupName().getForLang(lang))
                    .terms(new HashMap<>())
                    .build());
        }
        for (FareFamilyTerm term : fareFamily.getTerms()) {
            Map<String, TermValue> foundTerms = matchSingleTerm(term, legDoc, lang);
            // rules are evaluated against the whole document, drop segments that were not requested
            foundTerms.keySet().retainAll(segmentIds);
            if (!foundTerms.keySet().equals(segmentIds)) {
                throw new IllegalArgumentException("Can't find term value for every segment: " +
                        "termCode=" + term.getCode() + ", segments=" + segmentIds + ", found=" + foundTerms);
            }
            for (Map.Entry<String, TermValue> found : foundTerms.entrySet()) {
                segmentFares.get(found.getKey()).getTerms().put(term.getCode(), found.getValue());
            }
        }
        return segmentFares;
    }

    /**
     * Evaluates the rules of a single term in their declared order and assigns to each matched
     * segment the value of the first rule that matched it.
     *
     * @return segment id -> term value, for every segment matched by at least one non-ignored rule
     */
    private Map<String, TermValue> matchSingleTerm(FareFamilyTerm term, Document legDoc, String lang) {
        // segment_id -> term_value
        Map<String, TermValue> foundTerms = new LinkedHashMap<>();
        for (FareFamilyRule rule : term.getRules()) {
            // equals() instead of ==: the flag is an object, identity comparison would miss
            // a true value that is not the canonical Boolean.TRUE instance
            if (Boolean.TRUE.equals(rule.getIgnore())) {
                continue;
            }
            // every rule has a term value (e.g. 2x23 kg baggage or 150% miles) and a condition specifying
            // when the rule can be applied to one or more segments of a leg;
            // segments matched by this condition will have the term value of the rule
            List<Node> segments;
            if (!Strings.isNullOrEmpty(rule.getXpath()) || !Strings.isNullOrEmpty(rule.getExternalXpathRef())) {
                // an inline xpath takes precedence, otherwise the external reference is resolved
                String xpathSrc = !Strings.isNullOrEmpty(rule.getXpath()) ?
                        rule.getXpath() :
                        externalExpressions.get(rule.getExternalXpathRef());
                XPath xpath = compiledXpathExpressions.get(xpathSrc);
                segments = xpath.selectNodes(legDoc);
            } else {
                // no filters -> matches all
                segments = legDoc.selectNodes("Leg/Seg");
            }
            for (Node segment : segments) {
                String id = segment.valueOf("Id");
                if (!foundTerms.containsKey(id)) {
                    // more specific rules go earlier, we shouldn't override their values by more generic rules
                    foundTerms.put(id, rule.toTermValue(lang, term.getSpecialNotes()));
                }
            }
        }
        return foundTerms;
    }
}
