package ru.yandex.direct.core.entity.forecast;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import javax.annotation.Nullable;

import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import ru.yandex.direct.core.entity.forecast.model.CpaEstimate;
import ru.yandex.direct.core.entity.forecast.model.CpaEstimateAttributionModel;
import ru.yandex.direct.core.entity.forecast.model.CpaEstimatesContainer;
import ru.yandex.direct.core.entity.retargeting.model.MetrikaCounterGoalType;

import static java.util.Collections.emptyMap;

/**
 * Looks up CPA estimates through {@link CpaEstimatesRepository} and collapses them into a single
 * {@link CpaEstimatesContainer} per {@link MetrikaCounterGoalType}.
 */
@Service
public class CpaEstimatesService {

    private final CpaEstimatesRepository cpaEstimatesRepository;

    @Autowired
    public CpaEstimatesService(CpaEstimatesRepository cpaEstimatesRepository) {
        this.cpaEstimatesRepository = cpaEstimatesRepository;
    }

    /**
     * Fetches the CPA estimates for the given business category and reduces them to one container
     * per goal type.
     * <p>
     * The goal type of every estimate is converted to {@link MetrikaCounterGoalType} by constant name.
     * Estimates that end up under the same goal type are merged field by field; each resulting field
     * (min, max, median) holds the largest value of that field within the group.
     * <p>
     * NOTE(review): the method name mentions "Avg", yet the merge step takes the maximum of every field,
     * including min and median. Confirm with the owners that this is the intended aggregation.
     *
     * @param businessCategory category to look estimates up for; when {@code null} no lookup is performed
     * @param regionIds        region ids, handed to the repository unchanged
     * @param attributionModel attribution model, handed to the repository unchanged
     * @return merged CPA values keyed by goal type, or an empty map when {@code businessCategory} is {@code null}
     */
    public Map<MetrikaCounterGoalType, CpaEstimatesContainer> getAvgCpaByGoalType(@Nullable String businessCategory,
                                                                                  @Nullable List<Long> regionIds,
                                                                                  @Nullable CpaEstimateAttributionModel attributionModel) {
        if (businessCategory == null) {
            return emptyMap();
        }

        List<CpaEstimate> estimates =
                cpaEstimatesRepository.getCpaEstimates(businessCategory, regionIds, attributionModel);

        // Bucket the per-estimate CPA triples by goal type (translated by enum constant name).
        Map<MetrikaCounterGoalType, List<CpaEstimatesContainer>> containersByGoalType = StreamEx.of(estimates)
                .mapToEntry(CpaEstimate::getGoalType, CpaEstimatesService::toContainer)
                .mapKeys(sourceGoalType -> MetrikaCounterGoalType.valueOf(sourceGoalType.name()))
                .grouping();

        // Squash every bucket into a single container.
        return EntryStream.of(containersByGoalType)
                .mapValues(CpaEstimatesService::mergeContainers)
                .toMap();
    }

    /**
     * Copies the min, max and median CPA of an estimate into a fresh container.
     */
    private static CpaEstimatesContainer toContainer(CpaEstimate estimate) {
        return new CpaEstimatesContainer()
                .withMinCpa(estimate.getMinCpa())
                .withMaxCpa(estimate.getMaxCpa())
                .withMedianCpa(estimate.getMedianCpa());
    }

    /**
     * Merges the containers of one goal type: every field of the result is the maximum of that field
     * across {@code containers}.
     */
    private static CpaEstimatesContainer mergeContainers(List<CpaEstimatesContainer> containers) {
        double mergedMin = maxOf(containers, CpaEstimatesContainer::getMinCpa);
        double mergedMax = maxOf(containers, CpaEstimatesContainer::getMaxCpa);
        double mergedMedian = maxOf(containers, CpaEstimatesContainer::getMedianCpa);

        return new CpaEstimatesContainer()
                .withMedianCpa(mergedMedian)
                .withMinCpa(mergedMin)
                .withMaxCpa(mergedMax);
    }

    /**
     * Returns the largest value, as a {@code double}, that {@code fieldGetter} yields over {@code containers}.
     * <p>
     * An empty list makes {@code getAsDouble()} throw, and a {@code null} field value results in a
     * {@link NullPointerException}.
     */
    private static double maxOf(List<CpaEstimatesContainer> containers,
                                Function<CpaEstimatesContainer, BigDecimal> fieldGetter) {
        return StreamEx.of(containers)
                .map(fieldGetter)
                .mapToDouble(BigDecimal::doubleValue)
                .max()
                .getAsDouble();
    }
}
