package ru.yandex.direct.core.entity.goal.repository;


import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import org.jooq.Field;
import org.jooq.types.ULong;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Repository;

import ru.yandex.direct.grid.schema.yt.Tables;
import ru.yandex.direct.grid.schema.yt.tables.SuggestConversionPrice;
import ru.yandex.direct.ytcomponents.service.ConversionPriceForecastDynContextProvider;
import ru.yandex.direct.ytwrapper.dynamic.dsl.YtDSL;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;

import static java.util.function.Function.identity;
import static ru.yandex.direct.ytwrapper.YtTableUtils.aliased;

/**
 * Reads conversion price forecasts (mean conversion price and clicks count) from the
 * {@code SUGGEST_CONVERSION_PRICE} YT table.
 */
@Repository
@ParametersAreNonnullByDefault
public class ConversionPriceForecastRepository {
    public static final SuggestConversionPrice CONVERSION_PRICE = Tables.SUGGEST_CONVERSION_PRICE.as("SCP");
    private static final Field<BigDecimal> MEAN_PRICE = aliased(CONVERSION_PRICE.MEAN_PRICE);
    private static final Field<BigDecimal> CLICKS_COUNT = aliased(CONVERSION_PRICE.CLICKS_COUNT);
    private static final Field<String> GOAL_TYPE = aliased(CONVERSION_PRICE.GOAL_TYPE);
    private static final Field<ULong> CATEGORY_ID = aliased(CONVERSION_PRICE.CATEGORY_ID);
    private static final Field<ULong> SIGNIFICANT_CATEGORY_ID = aliased(CONVERSION_PRICE.SIGNIFICANT_CATEGORY_ID);

    /**
     * Attribution type used by every query of this repository: "last yandex direct click".
     */
    private static final long DEFAULT_ATTRIBUTION_TYPE = 4;

    private final ConversionPriceForecastDynContextProvider dynContextProvider;

    @Autowired
    public ConversionPriceForecastRepository(ConversionPriceForecastDynContextProvider dynContextProvider) {
        this.dynContextProvider = dynContextProvider;
    }

    /**
     * For each region from {@code regions}, gets the mean conversion price and the number of clicks users made
     * on the corresponding banners, by banner category and goal type. The attribution type used is
     * "last yandex direct click".
     * <p>
     * Only rows whose category is significant for {@code categoryId} (for the same goal type) are returned;
     * {@code categoryId} itself is never treated as significant. A goal type that has rows in the requested
     * regions but no significant categories maps to an empty list. Goal types with no rows in the requested
     * regions are absent from the result.
     *
     * @param categoryId banner category
     * @param goalTypes  goal types
     * @param regions    ids of the target regions
     * @return for each goal type from {@code goalTypes} - the list of clicks and recommended conversion prices
     * for regions from {@code regions} and categories significant for {@code categoryId}
     */
    public Map<String, List<PriceWithClicks>> getPriceWithClicksByGoalTypes(long categoryId,
                                                                            List<String> goalTypes,
                                                                            List<Long> regions) {
        Map<String, List<Long>> significantCategoriesByGoalTypes =
                getSignificantCategoryIdsForCategoryIdByGoalTypes(categoryId, goalTypes);

        return getPriceWithClicksByGoalTypes(goalTypes, regions, significantCategoriesByGoalTypes);
    }

    /**
     * Queries the categories that are significant for {@code categoryId}, grouped by goal type.
     * The original description of "significant" is: at least 1000 clicks on banners of those categories.
     *
     * @param categoryId banner category
     * @param goalTypes  goal types to query
     * @return for goal types from {@code goalTypes} - the list of categories significant for {@code categoryId}
     * (excluding {@code categoryId} itself); goal types with no such rows are absent from the map
     */
    private Map<String, List<Long>> getSignificantCategoryIdsForCategoryIdByGoalTypes(long categoryId,
                                                                                      List<String> goalTypes) {
        var selectSignificantCategoriesForCategoryByGoalTypeQuery = YtDSL.ytContext()
                .select(GOAL_TYPE, SIGNIFICANT_CATEGORY_ID)
                .from(CONVERSION_PRICE)
                .where(CONVERSION_PRICE.ATTRIBUTION_TYPE.eq(DEFAULT_ATTRIBUTION_TYPE),
                        CONVERSION_PRICE.CATEGORY_ID.eq(ULong.valueOf(categoryId)),
                        GOAL_TYPE.in(goalTypes));

        List<YTreeMapNode> queryResult = dynContextProvider.getContext()
                .executeSelect(selectSignificantCategoriesForCategoryByGoalTypeQuery)
                .getYTreeRows();

        return convertToSignificantCategoriesByGoalTypes(queryResult, categoryId);
    }

    /**
     * Groups significant category ids by goal type, dropping {@code categoryId} itself.
     * <p>
     * The filtering happens before grouping instead of mutating the grouped lists afterwards, because the
     * mutability of collector-produced lists is not guaranteed. A goal type whose only significant category
     * is {@code categoryId} is therefore absent from the result (callers treat absence as "no categories").
     */
    private static Map<String, List<Long>> convertToSignificantCategoriesByGoalTypes(List<YTreeMapNode> queryResult,
                                                                                     long categoryId) {
        return StreamEx.of(queryResult)
                .mapToEntry(r -> r.getString(GOAL_TYPE.getName()),
                        r -> r.getLong(SIGNIFICANT_CATEGORY_ID.getName()))
                // the requested category is not reported as significant for itself
                .filterValues(significantCategoryId -> significantCategoryId != categoryId)
                .grouping();
    }

    /**
     * Queries price/clicks rows for the given goal types and regions, then keeps only rows whose category is
     * significant for the row's goal type.
     */
    private Map<String, List<PriceWithClicks>> getPriceWithClicksByGoalTypes(List<String> goalTypes,
                                                                             List<Long> regions,
                                                                             Map<String, List<Long>> significantCategoriesByGoalType) {
        var regionsPriceAndClicksByCategoryIdByGoalTypesQuery = YtDSL.ytContext()
                .select(GOAL_TYPE, CATEGORY_ID, MEAN_PRICE, CLICKS_COUNT)
                .from(CONVERSION_PRICE)
                .where(CONVERSION_PRICE.ATTRIBUTION_TYPE.eq(DEFAULT_ATTRIBUTION_TYPE),
                        GOAL_TYPE.in(goalTypes),
                        CONVERSION_PRICE.REGION_ID.in(regions));

        List<YTreeMapNode> queryResult = dynContextProvider.getContext()
                .executeSelect(regionsPriceAndClicksByCategoryIdByGoalTypesQuery)
                .getYTreeRows();

        return convertToPriceWithClicksByGoalTypes(queryResult, significantCategoriesByGoalType);
    }

    /**
     * Converts rows of the regions query into price/clicks lists per goal type.
     * <p>
     * The significant-categories map may lack a goal type that is present in {@code queryResult} (the two
     * queries filter differently), so a missing entry is treated as "no significant categories" and yields
     * an empty list rather than a {@link NullPointerException}. Within a goal type, rows keep query order.
     */
    private static Map<String, List<PriceWithClicks>> convertToPriceWithClicksByGoalTypes(
            List<YTreeMapNode> queryResult,
            Map<String, List<Long>> significantCategoriesByGoalType) {
        return EntryStream.of(groupByGoalType(queryResult))
                .mapToValue((goalType, rows) -> filterBySignificantCategories(rows,
                        significantCategoriesByGoalType.getOrDefault(goalType, List.of())))
                .mapValues(ConversionPriceForecastRepository::getPriceWithClicks)
                .toMap();
    }

    private static Map<String, List<YTreeMapNode>> groupByGoalType(List<YTreeMapNode> rows) {
        return StreamEx.of(rows)
                .mapToEntry(r -> r.getString(GOAL_TYPE.getName()), identity())
                .grouping();
    }

    /**
     * Keeps only the rows whose category id is contained in {@code significantCategoryIds}, preserving order.
     */
    private static List<YTreeMapNode> filterBySignificantCategories(List<YTreeMapNode> rows,
                                                                   List<Long> significantCategoryIds) {
        // a set gives constant-time membership checks instead of scanning the list for every row
        Set<Long> significant = Set.copyOf(significantCategoryIds);
        return rows.stream()
                .filter(r -> significant.contains(r.getLong(CATEGORY_ID.getName())))
                .collect(Collectors.toList());
    }

    private static List<PriceWithClicks> getPriceWithClicks(List<YTreeMapNode> rows) {
        return rows.stream()
                .map(r -> new PriceWithClicks(
                        r.getDouble(MEAN_PRICE.getName()),
                        r.getDouble(CLICKS_COUNT.getName())
                ))
                .collect(Collectors.toList());
    }

}
