package ru.yandex.direct.core.entity.campaign.repository;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.Sets;
import org.jooq.Field;
import org.jooq.Select;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Repository;

import ru.yandex.direct.core.entity.campaign.model.BasicUplift;
import ru.yandex.direct.core.entity.campaign.model.BrandSurveyStatus;
import ru.yandex.direct.core.entity.campaign.model.BrandSurveyStatusRow;
import ru.yandex.direct.core.entity.campaign.model.BrandSurveyStopReason;
import ru.yandex.direct.core.entity.campaign.model.PythiaModerationStatus;
import ru.yandex.direct.core.entity.campaign.model.PythiaSurveyStatus;
import ru.yandex.direct.core.entity.campaign.model.SurveyStatus;
import ru.yandex.direct.grid.schema.yt.tables.BrandliftSurveys;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.ytcomponents.config.BrandSurveyStatusDynConfig;
import ru.yandex.direct.ytcomponents.config.BsExportYtDynConfig;
import ru.yandex.direct.ytcomponents.service.BrandSurveyStatusDynContextProvider;
import ru.yandex.direct.ytwrapper.client.YtProvider;
import ru.yandex.direct.ytwrapper.dynamic.dsl.YtDSL;
import ru.yandex.direct.ytwrapper.model.YtCluster;
import ru.yandex.inside.yt.kosher.impl.ytree.YTreeListNodeImpl;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yt.rpcproxy.ETransactionType;
import ru.yandex.yt.ytclient.proxy.ApiServiceTransaction;
import ru.yandex.yt.ytclient.proxy.ApiServiceTransactionOptions;
import ru.yandex.yt.ytclient.proxy.ModifyRowsRequest;
import ru.yandex.yt.ytclient.tables.ColumnValueType;
import ru.yandex.yt.ytclient.tables.TableSchema;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
import static ru.yandex.direct.grid.schema.yt.tables.BrandliftSurveys.BRANDLIFT_SURVEYS;
import static ru.yandex.direct.ytwrapper.YtTableUtils.aliased;

/**
 * Repository for brand survey statuses kept in a YT dynamic table.
 * <p>
 * Reads go through the {@link BrandSurveyStatusDynContextProvider} select context; writes go through a tablet
 * transaction on the dynamic operator of a single cluster chosen at construction time.
 * Read failures are logged and produce empty results; write failures are logged and rethrown.
 */
@ParametersAreNonnullByDefault
@Repository
public class CampaignBrandSurveyYtRepository {
    private static final Logger logger = LoggerFactory.getLogger(CampaignBrandSurveyYtRepository.class);
    private static final BrandliftSurveys BRANDLIFT_SURVEYS_ALIAS = BRANDLIFT_SURVEYS.as("B");
    private static final Field<String> STOP_REASONS = aliased(BRANDLIFT_SURVEYS_ALIAS.STOP_REASONS);
    private static final Field<String> MODERATION_REASONS = aliased(BRANDLIFT_SURVEYS_ALIAS.MODERATION_REASONS);
    private static final Field<String> MODERATION_STATE = aliased(BRANDLIFT_SURVEYS_ALIAS.MODERATION_STATE);
    private static final Field<String> STATE = aliased(BRANDLIFT_SURVEYS_ALIAS.STATE);
    private static final Field<String> BRAND_SURVEY_ID = aliased(BRANDLIFT_SURVEYS_ALIAS.BRAND_SURVEY_ID);
    private static final Field<String> BASIC_UPLIFT = aliased(BRANDLIFT_SURVEYS_ALIAS.BASIC_UPLIFT);

    private final BrandSurveyStatusDynContextProvider dynContextProvider;

    private final YtProvider ytProvider;

    private final YtCluster ytCluster;

    private final String tablePath;

    private final TableSchema tableSchema;

    /**
     * @param dynContextProvider provider of the YT context used for selects and for picking the write cluster
     * @param ytProvider         provider of the dynamic operator used for writes
     * @param ytConfig           not used by this class; kept so the constructor signature stays unchanged
     * @param config             supplies the path of the brand survey status table
     */
    public CampaignBrandSurveyYtRepository(BrandSurveyStatusDynContextProvider dynContextProvider,
                                           YtProvider ytProvider,
                                           BsExportYtDynConfig ytConfig,
                                           BrandSurveyStatusDynConfig config) {
        this.dynContextProvider = dynContextProvider;
        this.ytProvider = ytProvider;
        // Use the very list that was checked instead of fetching the context a second time.
        var clusters = dynContextProvider.getContext().getClustersByPriority();
        this.ytCluster = (clusters == null || clusters.isEmpty()) ? YtCluster.ARNOLD : clusters.get(0);
        this.tablePath = config.getBrandSurveyStatusTablePath();
        this.tableSchema = getTableSchema();
    }

    /**
     * Resolves the brand survey status for every campaign that has a brand survey id.
     *
     * @param campaignIdToBrandSurveyIdMap campaign id to brand survey id; entries with a {@code null} survey id
     *                                     are skipped
     * @return mutable map from campaign id to status; campaigns whose survey status could not be found or parsed
     * are absent. Campaigns sharing a survey share the same status instance.
     */
    public Map<Long, BrandSurveyStatus> getBrandSurveyStatusForCampaigns(Map<Long, String> campaignIdToBrandSurveyIdMap) {
        var brandSurveyIdToCampaignIds = campaignIdToBrandSurveyIdMap.entrySet().stream()
                .filter(entry -> entry.getValue() != null)
                .collect(Collectors.groupingBy(entry -> entry.getValue(), mapping(entry -> entry.getKey(), toSet())));

        var result = new HashMap<Long, BrandSurveyStatus>();
        if (brandSurveyIdToCampaignIds.isEmpty()) {
            return result;
        }

        // Query only distinct, non-null survey ids rather than the raw map values.
        Map<String, BrandSurveyStatus> statusBySurveyId =
                getStatusForBrandSurveys(brandSurveyIdToCampaignIds.keySet());

        brandSurveyIdToCampaignIds.forEach((brandSurveyId, campaignIds) -> {
            var status = statusBySurveyId.get(brandSurveyId);
            if (status != null) {
                campaignIds.forEach(campaignId -> result.put(campaignId, status));
            }
        });
        return result;
    }

    /**
     * @param brandSurveyIds brand survey ids to look up
     * @return map from brand survey id to its status; ids that are missing or whose row could not be converted
     * are absent
     */
    public Map<String, BrandSurveyStatus> getStatusForBrandSurveys(Collection<String> brandSurveyIds) {
        return getBrandSurveyStatuses(brandSurveyIds, BrandSurveyStatusRow::getBrandSurveyId);
    }

    /**
     * Loads rows for the given ids and converts each into a {@link BrandSurveyStatus} keyed by {@code keyExtractor}.
     * Rows that fail conversion are logged and skipped.
     */
    private <T> Map<T, BrandSurveyStatus> getBrandSurveyStatuses(Collection<String> brandSurveyIds,
                                                                 Function<BrandSurveyStatusRow, T> keyExtractor) {
        var map = new HashMap<T, BrandSurveyStatus>();
        getBrandSurveyStatuseRows(brandSurveyIds)
                .forEach(row -> fillBrandSurveyStatusMap(row, map, keyExtractor));

        return map;
    }

    /**
     * Selects raw status rows for the given brand survey ids.
     *
     * @param brandSurveyIds brand survey ids to look up
     * @return the rows found; an empty list when {@code brandSurveyIds} is empty or when the select fails
     * (the failure is logged, not propagated)
     */
    public List<BrandSurveyStatusRow> getBrandSurveyStatuseRows(Collection<String> brandSurveyIds) {
        if (brandSurveyIds.isEmpty()) {
            // Nothing can match an empty IN list, so skip the round trip.
            return emptyList();
        }
        Select query = YtDSL.ytContext()
                .select(BRAND_SURVEY_ID, STATE, MODERATION_STATE, MODERATION_REASONS, STOP_REASONS, BASIC_UPLIFT)
                .from(BRANDLIFT_SURVEYS_ALIAS)
                .where(BRANDLIFT_SURVEYS_ALIAS.BRAND_SURVEY_ID.in(brandSurveyIds));
        List<BrandSurveyStatusRow> result = emptyList();
        try (var ignore = Trace.current().profile("brand_survey:yql", "brand_survey_status")) {
            try {
                result = dynContextProvider.getContext()
                        .executeSelect(query)
                        .getYTreeRows()
                        .stream()
                        .map(this::readBrandSurveyStatusRow)
                        .collect(toList());
            } catch (RuntimeException ex) {
                // Best effort: a read failure degrades to "no statuses" instead of failing the caller.
                logger.error("Error while getting brand survey statuses for ids {}", brandSurveyIds, ex);
            }
        }
        return result;
    }

    /**
     * Writes the given rows to the status table inside a sticky tablet transaction.
     *
     * @param brandSurveyStatusRows rows to insert; nothing is done when the collection is empty
     * @throws RuntimeException whatever the transaction threw, after it has been logged
     */
    public void updateBrandSurveyStates(Collection<BrandSurveyStatusRow> brandSurveyStatusRows) {
        if (brandSurveyStatusRows.isEmpty()) {
            return;
        }
        try {
            ytProvider.getDynamicOperator(ytCluster).runInTransaction(tr -> modifyBrandSurveys(tr, brandSurveyStatusRows),
                    new ApiServiceTransactionOptions(ETransactionType.TT_TABLET).setSticky(true));
        } catch (RuntimeException ex) {
            logger.error("Error while updating brand survey statuses for ids {}", brandSurveyStatusRows, ex);
            throw ex;
        }
    }

    /**
     * Adds an insert for every row to a single modify request and executes it in the given transaction,
     * blocking until the request completes.
     *
     * @param transaction         transaction to execute the modification in
     * @param brandSurveyStatuses rows to insert
     */
    public void modifyBrandSurveys(ApiServiceTransaction transaction, Collection<BrandSurveyStatusRow> brandSurveyStatuses) {
        ModifyRowsRequest request = new ModifyRowsRequest(tablePath, tableSchema);
        request.setRequireSyncReplica(false);
        for (var brandSurveyStatusRow : brandSurveyStatuses) {
            // Values follow the column order of getTableSchema().
            // Arrays.asList is used because optional columns may be null and List.of rejects null elements.
            request.addInsert(
                    Arrays.asList(
                            brandSurveyStatusRow.getBrandSurveyId(),
                            brandSurveyStatusRow.getState(),
                            brandSurveyStatusRow.getModerationState(),
                            brandSurveyStatusRow.getModerationReasons(),
                            brandSurveyStatusRow.getBasicUplift(),
                            brandSurveyStatusRow.getStopReasons()
                    )
            );
        }

        transaction.modifyRows(request).join(); // IGNORE-BAD-JOIN DIRECT-149116
    }

    /**
     * Maps a selected YT row to a {@link BrandSurveyStatusRow}. Every column except the id and the state
     * is optional and becomes {@code null} when absent.
     */
    private BrandSurveyStatusRow readBrandSurveyStatusRow(YTreeMapNode row) {
        return new BrandSurveyStatusRow()
                .withBrandSurveyId(row.getString(BRAND_SURVEY_ID.getName()))
                .withState(row.getString(STATE.getName()))
                .withModerationState(row.getStringO(MODERATION_STATE.getName()).orElse((String) null))
                .withModerationReasons(row.get(MODERATION_REASONS.getName()).orElse(null))
                .withStopReasons(row.get(STOP_REASONS.getName()).orElse(null))
                .withBasicUplift(row.get(BASIC_UPLIFT.getName()).orElse(null));
    }

    /**
     * Converts the row and stores it under the extracted key; a row that cannot be converted is logged and skipped.
     */
    private <T> void fillBrandSurveyStatusMap(BrandSurveyStatusRow row, Map<T, BrandSurveyStatus> map,
                                              Function<BrandSurveyStatusRow, T> keyExtractor) {
        try {
            map.put(keyExtractor.apply(row), toBrandSurveyStatus(row));
        } catch (RuntimeException ex) {
            logger.error("Couldn't read pythia brand survey status for brand survey with id: {}",
                    row.getBrandSurveyId(), ex);
        }
    }

    /**
     * Builds a {@link BrandSurveyStatus} from a raw row.
     * <p>
     * The daily status is derived from the row state, then overridden by {@code MODERATION_REJECTED} when moderation
     * rejected the survey, and finally by {@code UNFEASIBLE} when there is at least one stop reason.
     *
     * @throws IllegalArgumentException if the state or the moderation state is not a known enum constant
     */
    private static BrandSurveyStatus toBrandSurveyStatus(BrandSurveyStatusRow row) {
        var brandSurveyStatus = new BrandSurveyStatus()
                .withBrandSurveyStopReasonsDaily(emptySet());

        var pythiaSurveyStatus = PythiaSurveyStatus.valueOf(row.getState());
        switch (pythiaSurveyStatus) {
            case DRAFT:
                brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.DRAFT);
                break;
            case ACTIVE:
                brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.ACTIVE);
                break;
            case COMPLETED:
                brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.COMPLETED);
                break;
            case MODERATION:
                brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.MODERATION);
                break;
            case CANCELED:
                brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.UNFEASIBLE);
                break;
            default:
                logger.error("Invalid pythia survey status {}", pythiaSurveyStatus);
        }

        var moderationState = row.getModerationState();
        if (moderationState != null && !moderationState.isEmpty()
                && PythiaModerationStatus.valueOf(moderationState) == PythiaModerationStatus.REJECTED) {
            brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.MODERATION_REJECTED);
        }

        brandSurveyStatus.setReasonIds(convertListNode(row.getModerationReasons(), YTreeNode::longValue));

        // Locale.ROOT keeps the enum lookup independent of the JVM default locale.
        List<BrandSurveyStopReason> stopReasons = convertListNode(row.getStopReasons(),
                cell -> BrandSurveyStopReason.valueOf(cell.stringValue().toUpperCase(Locale.ROOT)));

        brandSurveyStatus.setBrandSurveyStopReasonsDaily(Sets.immutableEnumSet(stopReasons));

        if (!stopReasons.isEmpty()) {
            brandSurveyStatus.setSurveyStatusDaily(SurveyStatus.UNFEASIBLE);
        }

        brandSurveyStatus.setBasicUplift(readBasicUplift(row));
        return brandSurveyStatus;
    }

    /**
     * Reads the uplift metrics of a row; returns an empty {@link BasicUplift} when the uplift node is absent
     * or is not a map node.
     */
    private static BasicUplift readBasicUplift(BrandSurveyStatusRow row) {
        var basicUplift = new BasicUplift();
        var basicUpliftNode = row.getBasicUplift();
        // The node is null when the column is missing from the row; that must not discard the whole status.
        if (basicUpliftNode == null || !basicUpliftNode.isMapNode()) {
            return basicUplift;
        }

        var basicUpliftMap = basicUpliftNode.asMap();
        var brandSurveyId = row.getBrandSurveyId();
        setMetric(basicUpliftMap, "adRecall", basicUplift::setAdRecall, brandSurveyId);
        setMetric(basicUpliftMap, "brandAwareness", basicUplift::setBrandAwareness, brandSurveyId);
        setMetric(basicUpliftMap, "brandFavorability", basicUplift::setBrandFavorability, brandSurveyId);
        setMetric(basicUpliftMap, "productConsideration", basicUplift::setProductConsideration, brandSurveyId);
        setMetric(basicUpliftMap, "adMessageRecall", basicUplift::setAdMessageRecall, brandSurveyId);
        setMetric(basicUpliftMap, "purchaseIntent", basicUplift::setPurchaseIntent, brandSurveyId);
        return basicUplift;
    }

    /**
     * Passes the {@code raise} value of the named metric to {@code setter}, if the metric and the value are present.
     * A malformed metric is logged and skipped so that the remaining metrics are still read.
     */
    private static void setMetric(Map<String, YTreeNode> basicUpliftMap, String metricName, Consumer<Double> setter,
                                  String brandSurveyId) {
        var metricNode = basicUpliftMap.get(metricName);
        if (metricNode == null) {
            return;
        }

        try {
            var metricValueNode = metricNode.asMap().get("raise");
            if (metricValueNode != null) {
                setter.accept(metricValueNode.doubleValue());
            }
        } catch (RuntimeException ex) {
            logger.error("Invalid value of metric: {} for brandLift : {}", metricName, brandSurveyId, ex);
        }
    }

    /**
     * Maps every element of a list node with {@code mapper}.
     *
     * @return the mapped elements; an empty list when {@code node} is not a {@link YTreeListNodeImpl}
     * (including {@code null}) or when mapping any element fails (the failure is logged)
     */
    private static <T> List<T> convertListNode(@Nullable YTreeNode node, Function<YTreeNode, T> mapper) {
        if (!(node instanceof YTreeListNodeImpl)) {
            return emptyList();
        }

        try {
            return node.asList().stream().map(mapper).collect(toList());
        } catch (RuntimeException ex) {
            logger.error("Error reading node for brand survey", ex);
        }

        return emptyList();
    }

    /**
     * Builds the schema used for modify requests; the column order here defines the value order of inserts.
     */
    private TableSchema getTableSchema() {
        return new TableSchema.Builder()
                .addKey(BRAND_SURVEY_ID.getName(), ColumnValueType.STRING)
                .addValue(STATE.getName(), ColumnValueType.STRING)
                .addValue(MODERATION_STATE.getName(), ColumnValueType.STRING)
                .addValue(MODERATION_REASONS.getName(), ColumnValueType.ANY)
                .addValue(BASIC_UPLIFT.getName(), ColumnValueType.ANY)
                .addValue(STOP_REASONS.getName(), ColumnValueType.ANY)
                .build();
    }
}
