package ru.yandex.infra.sidecars_updater;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import ru.yandex.bolts.function.Function;
import ru.yandex.infra.controller.dto.StageMeta;
import ru.yandex.infra.controller.yp.YpObject;
import ru.yandex.infra.sidecars_updater.statistics.LabelGroupStatistics;
import ru.yandex.infra.sidecars_updater.statistics.LabelStatistics;
import ru.yandex.infra.sidecars_updater.statistics.StaticStatistics;
import ru.yandex.infra.sidecars_updater.statistics.Statistics;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yp.client.api.TStageSpec;
import ru.yandex.yp.client.api.TStageStatus;

public class LabelStatisticsUpdater {
    private static final Set<String> deployUnitLabelsWithInnerLabels = Set.of(
            "du_sidecar_target_revision",
            "du_sidecar_autoupdate_revision"
    );

    private static final Set<String> deployUnitRevisionLabels = Set.of(
            "du_patchers_target_revision",
            "du_patchers_autoupdate_revision"
    );

    private static final Set<String> stageLabels = Set.of(
            "sidecar_revision_update"
    );

    private static final Set<LabelWithValue> oldLabelGlobalStatistics = new HashSet<>();
    private static final Set<String> oldLabelGroupStatistics = new HashSet<>();

    public List<Statistics> getNewLabelStatistics(Map<String, YpObject<StageMeta, TStageSpec, TStageStatus>> stages) {
        List<Statistics> newLabelStatistics = new ArrayList<>();

        stages.forEach((stageName, stage) ->
                stage.getLabels().forEach((labelName, labelMap) -> {
                    if (deployUnitLabelsWithInnerLabels.contains(labelName) && labelMap.isMapNode()) {
//                      labelMap : Map[duName -> Map[innerLabelName -> labelLongVal]]
                        labelMap.mapNode().values().forEach(
                                innerLabelMap -> innerLabelMap.mapNode().forEach(
                                        innerLabelEntry -> updateMetrics(
                                                labelName,
                                                innerLabelEntry.getValue().longValue(),
                                                Optional.of(innerLabelEntry.getKey()),
                                                stageNameLabelsEntry -> stageNameLabelsEntry.getValue().get(labelName).isMapNode() ?
                                                        stageNameLabelsEntry.getValue().get(labelName).asMap().entrySet().stream()
                                                                .filter(entry -> entry.getValue().asMap().containsKey(innerLabelEntry.getKey()))
                                                                .collect(Collectors.toMap(
                                                                        entry -> stageNameLabelsEntry.getKey() + "." + entry.getKey(),
                                                                        entry -> entry.getValue().mapNode().getLong(innerLabelEntry.getKey())
                                                                )) : new HashMap<>(),
                                                StaticStatistics.StatisticsMode.ONLY_DU,
                                                newLabelStatistics
                                        )
                                )
                        );
                    }
                    if (deployUnitRevisionLabels.contains(labelName) && labelMap.isMapNode()) {
//                      labelMap : Map[duName -> labelLongVal]
                        labelMap.mapNode().forEach(
                                labelEntry -> updateMetrics(
                                        labelName,
                                        labelEntry.getValue().longValue(),
                                        Optional.empty(),
                                        stageNameLabelsEntry -> stageNameLabelsEntry.getValue().get(labelName).isMapNode() ?
                                                stageNameLabelsEntry.getValue().get(labelName).asMap().entrySet().stream()
                                                        .collect(Collectors.toMap(
                                                                entry -> stageNameLabelsEntry.getKey() + "." + entry.getKey(),
                                                                entry -> entry.getValue().longValue()
                                                        )) : new HashMap<>(),
                                        StaticStatistics.StatisticsMode.ONLY_DU,
                                        newLabelStatistics
                                )
                        );
                    }
                    if (stageLabels.contains(labelName) && labelMap.isMapNode()) {
//                      labelMap : Map[innerLabelName -> labelLongVal]
                        labelMap.mapNode().forEach(
                                labelEntry -> updateMetrics(
                                        labelName,
                                        labelEntry.getValue().longValue(),
                                        Optional.of(labelEntry.getKey()),
                                        stageNameLabelsEntry -> stageNameLabelsEntry.getValue().get(labelName).isMapNode() &&
                                                stageNameLabelsEntry.getValue().get(labelName).asMap().containsKey(labelEntry.getKey()) ?
                                                new HashMap<>(Map.of(
                                                        stageNameLabelsEntry.getKey(),
                                                        stageNameLabelsEntry.getValue().get(labelName).asMap().get(labelEntry.getKey()).longValue()
                                                )) : new HashMap<>(),
                                        StaticStatistics.StatisticsMode.ONLY_STAGES,
                                        newLabelStatistics
                                )
                        );
                    }
                }));

        return newLabelStatistics;
    }

    private void updateMetrics(String labelName,
                               long labelValue,
                               Optional<String> innerLabelName,
                               Function<Map.Entry<String, Map<String, YTreeNode>>, Map<String, Long>>
                                       stageNameLabelsEntryToStageDuNameLabelValueMapFunction,
                               StaticStatistics.StatisticsMode statisticsMode,
                               List<Statistics> newLabelStatistics) {

        var fullLabelName = innerLabelName.map(s -> labelName + "_" + s).orElse(labelName);
        var labelWithValue = new LabelWithValue(fullLabelName, labelValue);

        if (!oldLabelGlobalStatistics.contains(labelWithValue)) {
            oldLabelGlobalStatistics.add(labelWithValue);
            newLabelStatistics.add(new LabelStatistics(
                    labelWithValue.name + "_" + labelWithValue.value,
                    labelsWithStageName -> stageNameLabelsEntryToStageDuNameLabelValueMapFunction
                            .apply(labelsWithStageName).entrySet().stream()
                            .collect(Collectors.toMap(
                                    Map.Entry::getKey,
                                    entry -> entry.getValue() == labelValue ? 1 : 0
                            )),
                    statisticsMode
            ));
        }

        if (!oldLabelGroupStatistics.contains(fullLabelName)) {
            oldLabelGroupStatistics.add(fullLabelName);
            newLabelStatistics.add(new LabelGroupStatistics(
                    fullLabelName,
                    labelsWithStageName -> stageNameLabelsEntryToStageDuNameLabelValueMapFunction
                            .apply(labelsWithStageName).entrySet().stream()
                            .collect(Collectors.toMap(
                                    Map.Entry::getKey,
                                    Map.Entry::getValue
                            )),
                    statisticsMode
            ));
        }
    }

    private static class LabelWithValue {
        final String name;
        final long value;

        public LabelWithValue(String name, long value) {
            this.name = name;
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LabelWithValue that = (LabelWithValue) o;
            return value == that.value && Objects.equals(name, that.name);
        }

        @Override
        public int hashCode() {
            return Objects.hash(name, value);
        }
    }
}
