package ru.yandex.solomon.alert.rule.threshold;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.DoublePredicate;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.math.LongMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.alert.EvaluationStatus;
import ru.yandex.solomon.alert.domain.Alert;
import ru.yandex.solomon.alert.domain.SubAlert;
import ru.yandex.solomon.alert.domain.threshold.PredicateRule;
import ru.yandex.solomon.alert.domain.threshold.TargetStatus;
import ru.yandex.solomon.alert.domain.threshold.ThresholdAlert;
import ru.yandex.solomon.alert.domain.threshold.ThresholdType;
import ru.yandex.solomon.alert.rule.AlertRuleDeadlines;
import ru.yandex.solomon.alert.rule.AlertTimeSeries;
import ru.yandex.solomon.alert.rule.ResultOrProceed;
import ru.yandex.solomon.alert.rule.SimulationResult;
import ru.yandex.solomon.alert.rule.SimulationStatus;
import ru.yandex.solomon.alert.rule.SingleTimeSeries;
import ru.yandex.solomon.common.RequestProducer;
import ru.yandex.solomon.expression.analytics.GraphDataLoadRequest;
import ru.yandex.solomon.expression.exceptions.SelException;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.math.protobuf.Aggregation;
import ru.yandex.solomon.math.protobuf.Operation;
import ru.yandex.solomon.math.protobuf.OperationDownsampling;
import ru.yandex.solomon.metrics.client.MetabaseStatus;
import ru.yandex.solomon.metrics.client.MetricsClient;
import ru.yandex.solomon.metrics.client.StockpileStatus;
import ru.yandex.solomon.metrics.client.combined.FindAndReadManyRequest;
import ru.yandex.solomon.metrics.client.combined.FindAndReadManyResponse;
import ru.yandex.solomon.metrics.client.exceptions.TooManyMetricsLoadedBySelectors;
import ru.yandex.solomon.model.point.AggrPoint;
import ru.yandex.solomon.model.point.RecyclableAggrPoint;
import ru.yandex.solomon.model.point.column.StockpileColumn;
import ru.yandex.solomon.model.point.column.StockpileColumnSet;
import ru.yandex.solomon.model.protobuf.MetricType;
import ru.yandex.solomon.model.timeseries.AggrGraphDataArrayList;
import ru.yandex.solomon.model.timeseries.AggrGraphDataIterable;
import ru.yandex.solomon.model.timeseries.AggrGraphDataListIterator;
import ru.yandex.solomon.util.time.InstantUtils;
import ru.yandex.solomon.util.time.Interval;
import ru.yandex.stockpile.api.EStockpileStatusCode;

import static java.util.Comparator.comparing;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.toList;

/**
 * Simulates the evaluation of a threshold alert over a historical time range.
 *
 * <p>The simulation loads the metrics matched by {@code selectors} through a single
 * find-and-read request with server-side downsampling. It then applies the alert's
 * {@code Transformer} and smooths every resulting series with a sliding window of the
 * alert period, shifted back by the alert delay. Finally it evaluates the alert's
 * predicate rules at every point of the simulation grid.
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public class ThresholdAlertSimulator {
    private static final Logger logger = LoggerFactory.getLogger(ThresholdAlertSimulator.class);
    private static final StockpileColumnSet SIMULATION_RESULT_COLUMN_SET = StockpileColumnSet.fromColumnsVa(StockpileColumn.TS, StockpileColumn.VALUE);

    private final Alert alert;
    private final Selectors selectors;
    private final Transformer transformer;
    private final MetricsClient cachingMetricsClient;

    /**
     * Fluent helper so that the {@code SUPPORTED_AGGREGATIONS} table reads as
     * {@code downsampling(X).aggregating(Y)}.
     */
    private static class DownsamplingWrapper {
        private final Aggregation downsampling;

        DownsamplingWrapper(Aggregation downsampling) {
            this.downsampling = downsampling;
        }

        DownsamplingAndAggregation aggregating(ThresholdType aggr) {
            return new DownsamplingAndAggregation(downsampling, aggr);
        }
    }

    private static DownsamplingWrapper downsampling(Aggregation downsampling) {
        return new DownsamplingWrapper(downsampling);
    }

    /**
     * Threshold types the simulator can handle. Each is mapped to the aggregation used for
     * server-side downsampling and to the aggregation applied over the sliding window of
     * downsampled points. COUNT is downsampled with COUNT and then window-aggregated with SUM,
     * so the per-bucket counts are added together.
     */
    private static final Map<ThresholdType, DownsamplingAndAggregation> SUPPORTED_AGGREGATIONS = Map.of(
        ThresholdType.LAST_NON_NAN, downsampling(Aggregation.LAST).aggregating(ThresholdType.LAST_NON_NAN),
        ThresholdType.AVG, downsampling(Aggregation.AVG).aggregating(ThresholdType.AVG),
        ThresholdType.MIN, downsampling(Aggregation.MIN).aggregating(ThresholdType.MIN),
        ThresholdType.MAX, downsampling(Aggregation.MAX).aggregating(ThresholdType.MAX),
        ThresholdType.SUM, downsampling(Aggregation.SUM).aggregating(ThresholdType.SUM),
        ThresholdType.COUNT, downsampling(Aggregation.COUNT).aggregating(ThresholdType.SUM));


    /**
     * @param alert                threshold alert (or sub-alert of one) whose rules, period, delay
     *                             and metrics limit drive the simulation
     * @param selectors            selectors of the metrics to load
     * @param transformer          transformation applied to the loaded series before smoothing
     * @param cachingMetricsClient client used to find and read the metrics
     */
    public ThresholdAlertSimulator(Alert alert, Selectors selectors, Transformer transformer, MetricsClient cachingMetricsClient) {
        this.alert = alert;
        this.selectors = selectors;
        this.transformer = transformer;
        this.cachingMetricsClient = cachingMetricsClient;
    }

    /**
     * Runs the simulation on a grid between {@code from} and {@code to}.
     *
     * <p>If the predicate rules are not supported, the returned future is already completed
     * with an UNSUPPORTED result. Errors raised while loading or processing the data are
     * converted into a {@link SimulationResult} by {@code classifySimulationError} and logged.
     *
     * @param from      start of the simulated range, aligned with {@code InstantUtils.truncate} to the grid step
     * @param to        end of the simulated range, aligned with {@code InstantUtils.ceil} to the grid step
     * @param gridStep  distance between evaluation points, rounded down to whole 15s quants (at least one)
     * @param deadlines deadlines passed to the metrics read request
     * @return future with the simulation result
     * @throws IllegalArgumentException if the alert is neither a threshold alert nor a sub-alert of one
     */
    public CompletableFuture<SimulationResult> simulate(Instant from, Instant to, Duration gridStep, AlertRuleDeadlines deadlines) {
        List<PredicateRule> predicateRules = resolvePredicateRules(alert);

        SimulationResult simulationResult = validatePredicateRules(predicateRules);
        if (simulationResult != null) {
            return completedFuture(simulationResult);
        }
        // Validation guarantees that the first rule exists and that its threshold type is supported.
        DownsamplingAndAggregation downsamplingAndAggregation = SUPPORTED_AGGREGATIONS.get(predicateRules.get(0).getThresholdType());

        SimulationParams params = new SimulationParams(alert, downsamplingAndAggregation, gridStep, from, to);

        return CompletableFutures.safeCall(() -> simulate(params, deadlines))
            .exceptionally(e -> {
                SimulationResult result = classifySimulationError(e);
                logger.error("{} - for alert {} caused by {}", result.getCode(), alert.getId(), e.getMessage(), e);
                return result;
            });
    }

    /**
     * Grid and interval parameters derived from the alert settings and the requested range.
     * All durations are expressed in whole 15-second quants.
     */
    private static class SimulationParams {
        /** Raw data interval to load: [from - window - delay, to - delay). */
        private final Interval loadInterval;
        /** Server-side downsampling aggregation. */
        private final Aggregation aggregation;
        /** Downsampling step: gcd of the window, grid and delay, measured in quants. */
        private final long downsampleMillis;
        private final long gridMillis;
        private final long delayMillis;
        private final long fromMillis;
        private final long toMillis;
        private final long windowMillis;
        /** Aggregation applied to the points inside each sliding window. */
        private final ToDoubleFunction<Collection<? extends AggrPoint>> windowAggr;

        private static long gcd3(long a, long b, long c) {
            return LongMath.gcd(a, LongMath.gcd(b, c));
        }

        public SimulationParams(
                Alert alert,
                DownsamplingAndAggregation downsamplingAndAggregation,
                Duration gridStep,
                Instant from,
                Instant to)
        {
            /* All markers are aligned by multiple of quant (15s at the moment)
             * Downsampling is always done in multiple of quant units (never less than 1 quant)
             *
             * E.g. gridStep = 4m, period = 3m, delay = 2m. Then downsampling interval will be 1 minute (4 quants)
             *
             *       nowDelayed - gridStep                                     nowDelayed                 now
             * ----------------|----------------------------------------------------|----------------------|--
             *                      nowDelayed - period
             * ------------------------------|--------------------------------------|-------------------------
             *                                    Downsampling grid
             * ----------------|-------------|------------|------------|------------|-------------------------
             *     Aggregating               ^------------^------------^                                   ^
             *                                            |                                                |
             *     Writing to                             \------------------------------------------------/
             */
            Duration quant = Duration.ofSeconds(15);

            long windowQuants = Math.max(alert.getPeriod().dividedBy(quant), 1L);
            long gridQuants = Math.max(gridStep.dividedBy(quant), 1L);
            long delayQuants = Duration.ofSeconds(alert.getDelaySeconds()).dividedBy(quant);
            long quantMillis = quant.toMillis();

            // A zero delay is harmless here: gcd(x, 0) == x.
            long commonIntervalQuants = gcd3(windowQuants, gridQuants, delayQuants);
            this.downsampleMillis = commonIntervalQuants * quantMillis;

            this.aggregation = downsamplingAndAggregation.getDownsampling();
            this.windowAggr = downsamplingAndAggregation.getWindowAggregation();

            this.windowMillis = windowQuants * quantMillis;
            this.gridMillis = gridQuants * quantMillis;
            this.delayMillis = delayQuants * quantMillis;
            this.fromMillis = InstantUtils.truncate(from.toEpochMilli(), gridMillis);
            this.toMillis = InstantUtils.ceil(to.toEpochMilli(), gridMillis);

            Instant intervalFrom = Instant.ofEpochMilli(fromMillis - windowMillis - delayMillis);
            Instant intervalTo = Instant.ofEpochMilli(toMillis - delayMillis);
            this.loadInterval = new Interval(intervalFrom, intervalTo);
        }
    }

    /**
     * Prepares the transformer for the load interval, loads the data and, if loading produced
     * no terminal result, runs the smoothing and evaluation on the loaded series.
     */
    private CompletableFuture<SimulationResult> simulate(
            SimulationParams params,
            AlertRuleDeadlines deadlines)
    {
        final PreparedTransformer prepared = transformer.prepare(params.loadInterval);

        GraphDataLoadRequest loadRequest = prepared.getLoadRequest().toBuilder()
                .setSelectors(selectors)
                .build();

        return loadManyTimeSeries(loadRequest, params, deadlines)
                .thenApply(resultOrProceed -> resultOrProceed.terminate(loadedSeries ->
                        simulateLoaded(prepared, params, loadedSeries)
                ));
    }

    /**
     * Transforms the loaded series, smooths each one on the simulation grid and computes the
     * per-grid-point statuses.
     */
    private SimulationResult simulateLoaded(
            PreparedTransformer prepared,
            SimulationParams params,
            List<SingleTimeSeries> vectorOfTimeSeries)
    {
        List<AlertTimeSeries> transformedAndSmoothed = prepared.transform(vectorOfTimeSeries).stream()
                .map(transformedTimeseries -> windowSmooth(
                        transformedTimeseries,
                        params.windowAggr,
                        params.windowMillis, params.delayMillis,
                        params.fromMillis, params.toMillis, params.gridMillis))
                .collect(Collectors.toList());
        var statuses = computeStatuses(transformedAndSmoothed);
        return new SimulationResult(SimulationStatus.OK, "", transformedAndSmoothed, statuses);
    }

    /**
     * Issues a single find-and-read request with server-side downsampling (no fill) over the
     * load request interval.
     *
     * @return a future with either a terminal {@link SimulationResult} (load error, no metrics,
     *         too many metrics) or the loaded series to continue with
     */
    private CompletableFuture<ResultOrProceed<SimulationResult, List<SingleTimeSeries>>> loadManyTimeSeries(
            GraphDataLoadRequest loadRequest,
            SimulationParams params,
            AlertRuleDeadlines deadlines)
    {
        FindAndReadManyRequest request = FindAndReadManyRequest.newBuilder()
                .setSelectors(selectors)
                .setMetabaseLimit(alert.getMetricsLimit())
                .setFromMillis(loadRequest.getInterval().getBeginMillis())
                .setToMillis(loadRequest.getInterval().getEndMillis())
                .addOperation(Operation.newBuilder()
                        .setDownsampling(OperationDownsampling.newBuilder()
                                .setFillOption(OperationDownsampling.FillOption.NONE)
                                .setAggregation(params.aggregation)
                                .setGridMillis(params.downsampleMillis)
                        )
                        .build())
                .setProducer(RequestProducer.SYSTEM)
                .setSoftDeadline(deadlines.softResolveDeadline())
                .setSoftReadDeadline(deadlines.softReadDeadline())
                .setDeadline(deadlines.hardDeadline())
                .build();

        return cachingMetricsClient.findAndReadMany(request)
                .handle((response, ex) -> {
                    SimulationResult checkResponseResult = checkReadManyResponse(response, ex);

                    if (checkResponseResult != null) {
                        return ResultOrProceed.ready(checkResponseResult);
                    }

                    // A null check result implies ex == null, so response is present here.
                    return response.getMetrics().stream()
                            .map(metric -> new SingleTimeSeries(metric.getKey(), metric.getType(), metric.getTimeseries()))
                            .collect(Collectors.collectingAndThen(toList(), ResultOrProceed::proceed));
                });
    }

    /**
     * Evaluates the alarm rule (first) and the optional warn rule (second) against values.
     */
    private static class ThresholdComparator {
        private final DoublePredicate alarmPredicate;
        private final DoublePredicate warnPredicate;

        ThresholdComparator(List<PredicateRule> rules) {
            alarmPredicate = WindowCheckFunctionFactory.makeDoublePredicate(rules.get(0));
            if (rules.size() > 1) {
                warnPredicate = WindowCheckFunctionFactory.makeDoublePredicate(rules.get(1));
            } else {
                warnPredicate = ignore -> false;
            }
        }

        /** NaN maps to NO_DATA; otherwise ALARM takes precedence over WARN, and OK is the fallback. */
        private EvaluationStatus.Code check(double value) {
            if (Double.isNaN(value)) {
                // TODO(uranix): honor no points policy
                return EvaluationStatus.Code.NO_DATA;
            }
            if (alarmPredicate.test(value)) {
                return EvaluationStatus.Code.ALARM;
            }
            if (warnPredicate.test(value)) {
                return EvaluationStatus.Code.WARN;
            }
            return EvaluationStatus.Code.OK;
        }

        /**
         * Walks all series in lockstep and reduces the per-series statuses at each position to
         * the one with the highest {@code ordinal()}.
         * NOTE(review): this relies on the declaration order of EvaluationStatus.Code reflecting
         * severity; confirm against the enum.
         *
         * @return one status per position, stopping as soon as any series is exhausted
         */
        List<EvaluationStatus.Code> check(List<AggrGraphDataIterable> iterables) {
            if (iterables.isEmpty()) {
                return List.of();
            }

            int recordCount = iterables.stream().map(AggrGraphDataIterable::getRecordCount).max(Integer::compareTo).orElse(0);
            ArrayList<EvaluationStatus.Code> statuses = new ArrayList<>(recordCount);
            List<RecyclableAggrPoint> points = iterables.stream().map(ignore -> RecyclableAggrPoint.newInstance()).collect(toList());
            List<AggrGraphDataListIterator> iterators = iterables.stream().map(AggrGraphDataIterable::iterator).collect(toList());

            try {
                while (true) {
                    boolean ok = true;
                    for (int i = 0; i < iterables.size(); i++) {
                        // Non-short-circuit '&=' so that every iterator advances on each step.
                        ok &= iterators.get(i).next(points.get(i));
                    }
                    if (!ok) {
                        break;
                    }
                    var reducedStatus = points.stream()
                        .map(point -> check(point.getValueDivided()))
                        .max(comparing(EvaluationStatus.Code::ordinal))
                        .orElse(EvaluationStatus.Code.OK);

                    statuses.add(reducedStatus);
                }
            } finally {
                // Return pooled points even if a predicate or iterator throws.
                for (var point : points) {
                    point.recycle();
                }
            }

            return statuses;
        }
    }

    private List<EvaluationStatus.Code> computeStatuses(List<AlertTimeSeries> transformedAndSmoothed) {
        ThresholdComparator comparator = new ThresholdComparator(resolvePredicateRules(alert));

        return comparator.check(transformedAndSmoothed.stream().map(AlertTimeSeries::getSource).collect(toList()));
    }

    /**
     * One-element look-ahead over an {@link AggrGraphDataListIterator}, backed by a single
     * pooled point. Close it to return that point to the pool if iteration stops early.
     */
    private static class PeekableIterator implements AutoCloseable {
        private final AggrGraphDataListIterator iter;
        /** Current point; null once the underlying iterator is exhausted or this is closed. */
        @Nullable
        private RecyclableAggrPoint point;
        /**
         * Whether the first point has been read. This is an explicit flag instead of a
         * tsMillis == 0 sentinel, so a real point at epoch 0 is not skipped.
         */
        private boolean primed;

        PeekableIterator(AggrGraphDataListIterator iterator) {
            this.iter = iterator;
            this.point = RecyclableAggrPoint.newInstance();
            this.primed = false;
        }

        /** Returns the current point without consuming it, or null when the end is reached. */
        @Nullable
        AggrPoint peek() {
            if (!primed) {
                advance();
            }
            return point;
        }

        /**
         * Reads the next point into the held instance.
         *
         * @throws IndexOutOfBoundsException if the end was already reached
         */
        void advance() {
            if (point == null) {
                throw new IndexOutOfBoundsException("Iterator end is reached");
            }
            primed = true;
            boolean ok = iter.next(point);
            if (!ok) {
                point.recycle();
                point = null;
            }
        }

        @Override
        public void close() {
            if (point != null) {
                point.recycle();
                point = null;
            }
        }
    }

    /**
     * Produces a DGAUGE series with one value per grid point ts in [fromMillis, toMillis).
     * The value is {@code aggregate} applied to the source points with timestamps in
     * [ts - windowMillis - delayMillis, ts - delayMillis).
     * NOTE(review): window eviction only pops from the front, so this assumes the source
     * yields points in ascending timestamp order.
     */
    private AlertTimeSeries windowSmooth(
            AlertTimeSeries transformedTimeseries,
            ToDoubleFunction<Collection<? extends AggrPoint>> aggregate,
            long windowMillis, long delayMillis,
            long fromMillis, long toMillis, long gridMillis)
    {
        AggrGraphDataIterable source = transformedTimeseries.getSource();

        int capacity = Math.toIntExact((toMillis - fromMillis + gridMillis - 1) / gridMillis);
        AggrGraphDataArrayList result = new AggrGraphDataArrayList(SIMULATION_RESULT_COLUMN_SET.columnSetMask(), capacity);

        ArrayDeque<RecyclableAggrPoint> window = new ArrayDeque<>();

        // try-with-resources returns the iterator's pooled point even when the loop stops
        // before the source is exhausted (the common case).
        try (PeekableIterator iter = new PeekableIterator(source.iterator())) {
            for (long ts = fromMillis; ts < toMillis; ts += gridMillis) {
                long loadFrom = ts - windowMillis - delayMillis;
                long loadTo = ts - delayMillis;

                // Evict points that slid past the left edge of the window.
                while (!window.isEmpty() && window.getFirst().getTsMillis() < loadFrom) {
                    window.removeFirst().recycle();
                }

                // Pull in source points up to the (exclusive) right edge; points already older
                // than the left edge are consumed without being added.
                AggrPoint point;
                while ((point = iter.peek()) != null && point.getTsMillis() < loadTo) {
                    if (point.getTsMillis() >= loadFrom) {
                        RecyclableAggrPoint add = RecyclableAggrPoint.newInstance();
                        add.setTsMillis(point.tsMillis);
                        add.setValue(point.getValueDivided());
                        window.addLast(add);
                    }
                    iter.advance();
                }

                result.addRecordShort(ts, aggregate.applyAsDouble(window));
            }
        } finally {
            for (RecyclableAggrPoint windowPoint : window) {
                windowPoint.recycle();
            }
            window.clear();
        }

        return new AlertTimeSeries(
                transformedTimeseries.getAlias(),
                transformedTimeseries.getLabels(),
                MetricType.DGAUGE,
                result);
    }

    private SimulationResult tooManyMetrics() {
        return SimulationStatus.TOO_MANY_METRICS
                .withMessage("Max metrics to load limit exceeded: " + alert.getMetricsLimit());
    }

    /**
     * Maps the outcome of the find-and-read request to a terminal result.
     *
     * @param response  the response; null when {@code throwable} is non-null
     * @param throwable the failure, if any
     * @return a terminal result (too many metrics, load error, no metrics), or null when the
     *         response can be processed further
     * @throws RuntimeException for any failure other than {@code TooManyMetricsLoadedBySelectors};
     *         it is rethrown as is, or wrapped if it is checked
     */
    @Nullable
    private SimulationResult checkReadManyResponse(@Nullable FindAndReadManyResponse response, @Nullable Throwable throwable) {
        if (throwable != null) {
            Throwable ex = CompletableFutures.unwrapCompletionException(throwable);
            if (ex instanceof TooManyMetricsLoadedBySelectors) {
                return tooManyMetrics();
            }
            if (ex instanceof RuntimeException) {
                throw (RuntimeException) ex;
            }
            throw new RuntimeException(ex);
        }
        MetabaseStatus metaStatus = response.getMetaStatus();
        if (!metaStatus.isOkOrNoData()) {
            return SimulationStatus.DATA_LOAD_ERROR.withMessage(metaStatus.toString());
        }

        StockpileStatus storageStatus = response.getStorageStatus();
        if (storageStatus.getCode() != EStockpileStatusCode.OK) {
            return SimulationStatus.DATA_LOAD_ERROR.withMessage(storageStatus.toString());
        }

        // TODO(uranix): honor No metrics policy
        if (response.getMetrics().isEmpty()) {
            return SimulationStatus.NO_METRICS.withMessage("No metrics by selectors " + selectors);
        }

        if (response.getMetrics().size() > alert.getMetricsLimit()) {
            return tooManyMetrics();
        }

        return null;
    }

    /**
     * Checks that the rules have a shape the simulator supports.
     * There must be one or two rules. The first must target ALARM and the optional second
     * must target WARN. Both must share the comparison and threshold type, and the threshold
     * type must be in {@code SUPPORTED_AGGREGATIONS}.
     *
     * @return an UNSUPPORTED result describing the problem, or null if the rules are supported
     */
    // TODO(uranix): move to generic cloud alerting code
    @Nullable
    public static SimulationResult validatePredicateRules(List<PredicateRule> rules) {
        final SimulationStatus unsupported = SimulationStatus.UNSUPPORTED;

        if (rules.isEmpty() || rules.size() > 2) {
            return unsupported.withMessage("Incorrect number of predicate rules for Threshold alert");
        }

        PredicateRule mainRule = rules.get(0);

        if (mainRule.getTargetStatus() != TargetStatus.ALARM) {
            return unsupported.withMessage("First predicate rule must have ALARM target status");
        }

        if (rules.size() > 1) {
            PredicateRule extraRule = rules.get(1);
            if (extraRule.getTargetStatus() != TargetStatus.WARN) {
                return unsupported.withMessage("Second predicate rule must have WARN target status");
            }
            if (mainRule.getComparison() != extraRule.getComparison()) {
                return unsupported.withMessage("Rules differ in comparison functions");
            }
            if (mainRule.getThresholdType() != extraRule.getThresholdType()) {
                return unsupported.withMessage("Rules differ in aggregation function");
            }
        }

        if (!SUPPORTED_AGGREGATIONS.containsKey(mainRule.getThresholdType())) {
            return unsupported.withMessage("Aggregation function " + mainRule.getThresholdType() + " is not supported");
        }

        return null;
    }

    /**
     * Returns the predicate rules of a threshold alert, following sub-alerts up to their parent.
     *
     * @throws IllegalArgumentException for any other alert kind
     */
    private static List<PredicateRule> resolvePredicateRules(Alert alert) {
        if (alert instanceof ThresholdAlert) {
            return ((ThresholdAlert) alert).getPredicateRules();
        }
        if (alert instanceof SubAlert) {
            return resolvePredicateRules(((SubAlert) alert).getParent());
        }
        throw new IllegalArgumentException("Non-threshold alert in ThresholdAlertRule");
    }

    /**
     * Converts a simulation failure into a result. Expression errors become INVALID_REQUEST;
     * everything else becomes GENERIC_ERROR.
     */
    private SimulationResult classifySimulationError(Throwable e) {
        Throwable cause = CompletableFutures.unwrapCompletionException(e);
        if (cause instanceof SelException exception) {
            return SimulationStatus.INVALID_REQUEST.withMessage(exception.getErrorMessage());
        }
        // Not logged here: the only caller (the exceptionally handler in simulate) already logs
        // the failure with its stack trace, and logging here would duplicate it.
        return SimulationStatus.GENERIC_ERROR.withMessage(cause.getMessage());
    }

}
