package ru.yandex.direct.core.entity.campaign.service.pricerecalculation;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.jetbrains.annotations.Nullable;

import ru.yandex.direct.core.entity.campaign.model.CampaignWithPricePackage;
import ru.yandex.direct.core.entity.campaign.model.CpmPriceCampaign;
import ru.yandex.direct.core.entity.pricepackage.model.MarkupCondition;
import ru.yandex.direct.core.entity.pricepackage.model.MarkupConditionsOperator;
import ru.yandex.direct.core.entity.pricepackage.model.PriceCalculation;
import ru.yandex.direct.core.entity.pricepackage.model.PriceMarkup;
import ru.yandex.direct.core.entity.pricepackage.model.PricePackage;
import ru.yandex.direct.core.entity.pricepackage.model.TargetingMarkup;
import ru.yandex.direct.core.entity.retargeting.model.Goal;
import ru.yandex.direct.core.entity.retargeting.model.GoalType;
import ru.yandex.direct.utils.TriFunction;

import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.maxBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;
import static ru.yandex.direct.core.entity.pricepackage.model.MarkupConditionsOperator.AND;
import static ru.yandex.direct.utils.CommonUtils.ifNotNull;
import static ru.yandex.direct.utils.CommonUtils.nvl;
import static ru.yandex.direct.utils.FunctionalUtils.mapList;

/**
 * Calculates the price of a CPM price campaign for a given targeting.
 *
 * <p>The resulting price is the package base price multiplied by the seasonal markup ratio
 * and by the product of the ratios of all targeting markups whose conditions match the
 * supplied geo / goals / project parameter conditions.
 *
 * <p>The instance is bound to one campaign; {@link #apply} can be invoked repeatedly with
 * different targeting lists.
 */
public class PriceCalculator implements TriFunction<List<Long>, List<Long>, List<Long>, PriceCalculation> {
    /** Scale of the intermediate "percent to ratio" values. */
    private static final int RATIO_SCALE = 6;
    private static final BigDecimal HUNDRED = BigDecimal.valueOf(100);
    /**
     * When the campaign is covered by two seasonal periods, the cheaper one is used only if the
     * computed number of days attributed to the more expensive one is strictly below this limit.
     */
    private static final long EXPENSIVE_PERIOD_DAYS_LIMIT = 7;

    private final CpmPriceCampaign cpmPriceCampaign;
    private final Map<Long, PricePackage> pricePackages;
    private final Map<Long, MarkupCondition> markupConditionMap;

    /**
     * @param cpmPriceCampaign   campaign the prices are calculated for
     * @param pricePackages      price packages by id; must contain the package of the campaign
     * @param markupConditionMap markup conditions by key (the key is matched against condition ids)
     */
    public PriceCalculator(CpmPriceCampaign cpmPriceCampaign, Map<Long, PricePackage> pricePackages, Map<Long, MarkupCondition> markupConditionMap) {
        this.cpmPriceCampaign = cpmPriceCampaign;
        this.pricePackages = pricePackages;
        this.markupConditionMap = markupConditionMap;
    }

    /**
     * Calculates the price for the given targeting.
     *
     * @param geo                    geo targeting ids
     * @param goalIds                goal (segment) ids
     * @param projectParamConditions project parameter condition ids
     * @return the calculation, or {@code null} when no single seasonal markup can be chosen
     *         for the campaign dates
     * @throws NullPointerException if the campaign's price package is absent from the supplied map
     */
    @Override
    public PriceCalculation apply(List<Long> geo, List<Long> goalIds,
                                  List<Long> projectParamConditions) {
        return getPackagePrice(cpmPriceCampaign, campaignPricePackage(), markupConditionMap, geo,
                goalIds, projectParamConditions);
    }

    /** Looks up the campaign's price package, failing with a descriptive message when it is missing. */
    private PricePackage campaignPricePackage() {
        var pricePackageId = cpmPriceCampaign.getPricePackageId();
        return Objects.requireNonNull(pricePackages.get(pricePackageId),
                () -> "Price package " + pricePackageId + " of the campaign was not supplied to PriceCalculator");
    }

    /**
     * Combines the base price with the seasonal and targeting ratios.
     *
     * @return the calculation, or {@code null} if {@link #getSeasonalPriceRatio} yields no markup
     * @throws ArithmeticException if the package price (or, for CPD packages, the minimal order
     *                             volume) is zero, as it is used as a divisor
     */
    @Nullable
    private static PriceCalculation getPackagePrice(
            CpmPriceCampaign c, PricePackage pricePackage, Map<Long, MarkupCondition> markupConditionMap,
            List<Long> geo, List<Long> segments, List<Long> projectParamConditions) {
        PriceMarkup seasonalPriceMarkup = getSeasonalPriceRatio(c, pricePackage);
        if (seasonalPriceMarkup == null) {
            return null;
        }
        Set<GoalType> goalTypes = ifNotNull(mapList(segments, Goal::computeType), Set::copyOf);
        Triple<BigDecimal, List<TargetingMarkup>, List<MarkupCondition>> targetingPrice =
                getTargetingPriceRatio(pricePackage, markupConditionMap, geo, segments, goalTypes,
                        projectParamConditions);

        BigDecimal basePrice = pricePackage.getPrice();
        BigDecimal seasonalPriceRatio = percentIntegerToBigDecimal(seasonalPriceMarkup.getPercent());
        BigDecimal targetingPriceRatio = targetingPrice.getLeft();
        BigDecimal totalPrice = basePrice.multiply(seasonalPriceRatio).multiply(targetingPriceRatio);
        if (nvl(pricePackage.getIsCpd(), false)) {
            // presumably converts a per-day package price into a price per thousand impressions
            // using the minimal order volume — TODO confirm
            totalPrice = totalPrice.multiply(BigDecimal.valueOf(1000))
                    .divide(BigDecimal.valueOf(pricePackage.getOrderVolumeMin()), RoundingMode.HALF_UP);
        }
        // NOTE(review): for CPD packages the percent is derived from the already converted price — confirm intended
        BigDecimal totalPercent = totalPrice.divide(basePrice, RoundingMode.HALF_UP)
                .multiply(HUNDRED)
                .subtract(HUNDRED);
        return new PriceCalculation()
                .withPrice(totalPrice.setScale(2, RoundingMode.HALF_UP))
                .withPercent(totalPercent.setScale(2, RoundingMode.HALF_UP))
                .withPriceMarkup(seasonalPriceMarkup)
                .withTargetingMarkups(targetingPrice.getMiddle())
                .withMarkupConditions(targetingPrice.getRight());
    }

    /**
     * Chooses the seasonal markup that applies to the campaign.
     *
     * <p>Markup periods intersecting the campaign dates are collected, and the parts of the
     * campaign not covered by any markup are filled with synthetic zero-percent periods.
     * <ul>
     *     <li>one resulting period: that period is returned;</li>
     *     <li>two resulting periods: the lower-percent one is returned if the number of days
     *     computed for the higher-percent one is less than 7, otherwise {@code null};</li>
     *     <li>more than two periods: {@code null}.</li>
     * </ul>
     *
     * @param c            campaign providing start and end dates
     * @param pricePackage package providing the seasonal markups (may have none)
     * @return the chosen markup (possibly a synthetic zero-percent one) or {@code null}
     */
    @Nullable
    public static PriceMarkup getSeasonalPriceRatio(CampaignWithPricePackage c, PricePackage pricePackage) {
        List<PriceMarkup> priceMarkups = nvl(pricePackage.getPriceMarkups(), emptyList());
        List<PriceMarkup> possiblePeriods = priceMarkups.stream()
                .filter(pm -> intersectsCampaign(c, pm))
                .sorted(Comparator.comparing(PriceMarkup::getDateStart))
                .collect(toList());
        List<PriceMarkup> actualPeriods = enrichSeasonPriceMarkups(c, possiblePeriods);
        if (actualPeriods.size() > 2) {
            return null;
        }
        Pair<PriceMarkup, PriceMarkup> minMaxPeriods = getMinMaxPeriods(actualPeriods);
        if (minMaxPeriods == null) {
            return null;
        }
        PriceMarkup minPeriod = minMaxPeriods.getLeft();
        PriceMarkup maxPeriod = minMaxPeriods.getRight();
        if (maxPeriod == null) {
            // the whole campaign lies inside a single period
            return minPeriod;
        }
        // true when the higher-percent period starts on or before the campaign start
        boolean maxPeriodIsFirst = maxPeriod.getDateStart().minusDays(1).isBefore(c.getStartDate());
        long days = maxPeriodIsFirst
                ? c.getStartDate().until(maxPeriod.getDateEnd(), ChronoUnit.DAYS)
                : maxPeriod.getDateStart().until(c.getEndDate(), ChronoUnit.DAYS);
        return days < EXPENSIVE_PERIOD_DAYS_LIMIT ? minPeriod : null;
    }

    /**
     * Tells whether a markup period has at least one boundary inside the campaign dates
     * (both ends inclusive) or fully contains the campaign.
     */
    private static boolean intersectsCampaign(CampaignWithPricePackage c, PriceMarkup pm) {
        var campaignStart = c.getStartDate();
        var campaignEnd = c.getEndDate();
        boolean startsInsideCampaign = campaignStart.minusDays(1).isBefore(pm.getDateStart())
                && campaignEnd.plusDays(1).isAfter(pm.getDateStart());
        boolean endsInsideCampaign = campaignEnd.plusDays(1).isAfter(pm.getDateEnd())
                && campaignStart.minusDays(1).isBefore(pm.getDateEnd());
        boolean containsCampaign = campaignStart.plusDays(1).isAfter(pm.getDateStart())
                && campaignEnd.minusDays(1).isBefore(pm.getDateEnd());
        return startsInsideCampaign || endsInsideCampaign || containsCampaign;
    }

    /**
     * Fills the parts of the campaign that are not covered by seasonal markups with synthetic
     * zero-percent markups, so that the returned periods cover the campaign without gaps.
     *
     * @param c               campaign providing start and end dates
     * @param originalPeriods markup periods intersecting the campaign, sorted by start date
     * @return original and synthetic periods in chronological order; never empty
     */
    private static List<PriceMarkup> enrichSeasonPriceMarkups(CampaignWithPricePackage c,
                                                              List<PriceMarkup> originalPeriods) {
        List<PriceMarkup> actual = new ArrayList<>();
        PriceMarkup previousPeriod = null;
        for (PriceMarkup period : originalPeriods) {
            if (previousPeriod == null) {
                if (period.getDateStart().isAfter(c.getStartDate())) {
                    // The first markup starts after the campaign start: add a zero-percent markup
                    // for the leading stretch so the whole campaign is covered for later calculation.
                    actual.add(new PriceMarkup()
                            .withDateStart(c.getStartDate())
                            .withDateEnd(period.getDateStart().minusDays(1))
                            .withPercent(0));
                }
                // otherwise the first period starts on or before the campaign start, nothing to fill
            } else if (!previousPeriod.getDateEnd().plusDays(1).isEqual(period.getDateStart())) {
                // The period does not follow the previous one immediately:
                // cover the stretch without markups by a zero-percent markup.
                actual.add(new PriceMarkup()
                        .withDateStart(previousPeriod.getDateEnd().plusDays(1))
                        .withDateEnd(period.getDateStart().minusDays(1))
                        .withPercent(0));
            }
            actual.add(period);
            // always track the real period: its end date drives the next gap check and the trailing filler
            previousPeriod = period;
        }
        if (previousPeriod == null) {
            // No markups found at all: a single zero-percent markup spans the whole campaign.
            actual.add(new PriceMarkup()
                    .withDateStart(c.getStartDate())
                    .withDateEnd(c.getEndDate())
                    .withPercent(0));
        } else if (previousPeriod.getDateEnd().isBefore(c.getEndDate())) {
            // Markups end earlier than the campaign: add a zero-percent markup for the remainder.
            actual.add(new PriceMarkup()
                    .withDateStart(previousPeriod.getDateEnd().plusDays(1))
                    .withDateEnd(c.getEndDate())
                    .withPercent(0));
        }
        return actual;
    }

    /**
     * Converts a markup percent into a multiplier, e.g. {@code 20} becomes {@code 1.200000}.
     *
     * @param percent markup percent, must not be {@code null}
     * @return {@code (100 + percent) / 100} with scale 6, rounded half-up
     */
    public static BigDecimal percentIntegerToBigDecimal(Integer percent) {
        return percentToRatio(BigDecimal.valueOf(percent.longValue()));
    }

    /** Shared "percent to multiplier" conversion: {@code (100 + percent) / 100} at scale 6, half-up. */
    private static BigDecimal percentToRatio(BigDecimal percent) {
        return HUNDRED.add(percent)
                .setScale(RATIO_SCALE, RoundingMode.HALF_UP)
                .divide(HUNDRED, RoundingMode.HALF_UP);
    }

    /**
     * Finds the periods with the lowest and the highest percent.
     *
     * @return {@code null} for an empty collection; a pair with a {@code null} right element for
     *         a single period; otherwise (lowest, highest), the first one winning on ties
     */
    @Nullable
    private static Pair<PriceMarkup, PriceMarkup> getMinMaxPeriods(Collection<PriceMarkup> allPeriods) {
        if (allPeriods.isEmpty()) {
            return null;
        }
        if (allPeriods.size() == 1) {
            return Pair.of(allPeriods.iterator().next(), null);
        }
        Comparator<PriceMarkup> byPercent = Comparator.comparing(PriceMarkup::getPercent);
        return Pair.of(Collections.min(allPeriods, byPercent), Collections.max(allPeriods, byPercent));
    }

    /**
     * Computes the combined ratio of the targeting markups applicable to the given targeting.
     *
     * <p>Only conditions that have a markup in the package and match the targeting are considered.
     * Conditions are grouped by {@code group} (a condition without a group forms its own group
     * keyed by its id) and the one with the highest priority is taken from every group.
     *
     * @return (ratio, applied markups, applied conditions); when nothing applies the ratio is
     *         {@link BigDecimal#ONE} and both lists are {@code null}
     */
    private static Triple<BigDecimal, List<TargetingMarkup>, List<MarkupCondition>> getTargetingPriceRatio(
            PricePackage pricePackage, Map<Long, MarkupCondition> markupConditionMap, List<Long> geo,
            List<Long> segments, Set<GoalType> goalTypes, List<Long> projectParamConditions) {

        List<TargetingMarkup> targetingMarkups = nvl(pricePackage.getTargetingMarkups(), emptyList());
        Map<Long, BigDecimal> conditionIdToPercent = targetingMarkups.stream()
                .collect(toMap(TargetingMarkup::getConditionId, TargetingMarkup::getPercent));

        Map<String, Optional<MarkupCondition>> topPriorityConditionMatchByGroup = markupConditionMap.values().stream()
                .filter(condition -> conditionIdToPercent.containsKey(condition.getId()))
                .filter(condition -> conditionMatchesTargeting(condition, geo, segments, goalTypes,
                        projectParamConditions))
                .collect(groupingBy(condition -> nvl(condition.getGroup(), condition.getId().toString()), HashMap::new,
                        maxBy(Comparator.comparing(MarkupCondition::getPriority))));

        List<Long> relevantConditionIds = topPriorityConditionMatchByGroup.values().stream()
                .filter(Optional::isPresent)
                .map(Optional::get)
                .map(MarkupCondition::getId)
                .collect(toList());
        if (relevantConditionIds.isEmpty()) {
            return Triple.of(BigDecimal.ONE, null, null);
        }
        // hash-based lookup for the filters below; the list itself is kept for the multiplication
        Set<Long> relevantConditionIdLookup = relevantConditionIds.stream().collect(toSet());

        List<TargetingMarkup> relevantMarkups = targetingMarkups.stream()
                .filter(tm -> relevantConditionIdLookup.contains(tm.getConditionId()))
                .collect(toList());

        List<MarkupCondition> relevantConditions = EntryStream.of(markupConditionMap)
                .filterKeys(relevantConditionIdLookup::contains)
                .values()
                .toList();

        BigDecimal targetingPriceRatio = relevantConditionIds.stream()
                .map(conditionIdToPercent::get)
                .map(PriceCalculator::percentToRatio)
                .reduce(BigDecimal.ONE, BigDecimal::multiply);

        return Triple.of(targetingPriceRatio, relevantMarkups, relevantConditions);
    }

    /**
     * Checks a markup condition against the targeting.
     *
     * <p>Every targeting part (segments, geo, goal types, project parameter conditions) is checked
     * separately; parts whose check yields {@code null} are ignored. With the {@code AND} operator
     * at least one part must be checked and none may fail; otherwise one matching part is enough.
     */
    private static boolean conditionMatchesTargeting(MarkupCondition condition, List<Long> geo, List<Long> segments,
                                                     Set<GoalType> goalTypes, List<Long> projectParamConditions) {
        MarkupConditionsOperator operator = condition.getOperator();
        Boolean cryptaSegmentsMatches = conditionMatchesPieceOfTargeting(segments,
                condition.getTargeting().getCryptaSegments(), operator);
        Boolean geoMatches = conditionMatchesPieceOfTargeting(geo,
                condition.getTargeting().getGeo(), operator);
        Boolean goalTypesMatches = conditionMatchesPieceOfTargeting(goalTypes,
                condition.getTargeting().getGoalTypes(), operator);
        Boolean projectParamConditionsMatches = conditionMatchesPieceOfTargeting(projectParamConditions,
                condition.getTargeting().getProjectParamConditions(), operator);

        Set<Boolean> matches = StreamEx.of(cryptaSegmentsMatches, geoMatches, goalTypesMatches,
                        projectParamConditionsMatches)
                .filter(Objects::nonNull)
                .collect(toSet());
        if (operator == AND) {
            return !matches.isEmpty() && !matches.contains(Boolean.FALSE);
        }
        return matches.contains(Boolean.TRUE);
    }

    /**
     * Checks one part of the targeting against the corresponding part of a condition.
     *
     * <p>With the {@code AND} operator the targeting must contain every condition value,
     * otherwise a single common value is enough.
     *
     * @return the check result; presumably {@code null} when either collection is {@code null}
     *         (relies on {@code ifNotNull} yielding {@code null} for a {@code null} argument — verify)
     */
    @Nullable
    private static <T> Boolean conditionMatchesPieceOfTargeting(Collection<T> targeting,
                                                                Collection<T> conditionTargeting, MarkupConditionsOperator conditionsOperator) {
        return ifNotNull(conditionTargeting,
                ct -> conditionsOperator == AND
                        ? ifNotNull(targeting, t -> t.containsAll(ct))
                        : ifNotNull(targeting, t -> !Collections.disjoint(ct, t)));
    }

    /**
     * Returns the base price of the campaign's price package without any markups.
     *
     * @return the package price rounded half-up to two decimal places
     * @throws NullPointerException if the campaign's price package is absent from the supplied map
     */
    public BigDecimal getDefaultPricePackagePrice() {
        return campaignPricePackage().getPrice().setScale(2, RoundingMode.HALF_UP);
    }
}
