package ru.yandex.solomon.alert.cost;

import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.WillNotClose;
import javax.annotation.concurrent.ThreadSafe;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.metabase.client.MetabaseClient;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.alert.domain.Alert;
import ru.yandex.solomon.alert.domain.expression.ExpressionAlert;
import ru.yandex.solomon.alert.domain.threshold.ThresholdAlert;
import ru.yandex.solomon.alert.rule.ProgramCompiler;
import ru.yandex.solomon.codec.serializer.StockpileFormat;
import ru.yandex.solomon.expression.analytics.PreparedProgram;
import ru.yandex.solomon.expression.analytics.Program;
import ru.yandex.solomon.expression.compile.DeprOpts;
import ru.yandex.solomon.labels.protobuf.LabelSelectorConverter;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.math.protobuf.Aggregation;
import ru.yandex.solomon.metabase.api.protobuf.EMetabaseStatusCode;
import ru.yandex.solomon.metabase.api.protobuf.FindRequest;
import ru.yandex.solomon.metabase.api.protobuf.Metric;
import ru.yandex.solomon.metrics.client.TimeSeriesCodec;
import ru.yandex.solomon.model.point.AggrPoint;
import ru.yandex.solomon.model.protobuf.MetricId;
import ru.yandex.solomon.model.timeseries.AggrGraphDataListIterator;
import ru.yandex.solomon.util.time.Interval;
import ru.yandex.stockpile.api.EStockpileStatusCode;
import ru.yandex.stockpile.api.TCompressedReadResponse;
import ru.yandex.stockpile.api.TReadRequest;
import ru.yandex.stockpile.client.StockpileClient;

/**
 * Estimates the evaluation cost of an alert.
 *
 * <p>The metrics an alert reads are resolved through metabase, and the number of points each metric
 * holds over the alert period is read from stockpile. When either lookup yields nothing usable, the
 * values are replaced by approximations taken from {@link EstimationOptions}.
 *
 * @author Vladimir Gordiychuk
 */
@ThreadSafe
@ParametersAreNonnullByDefault
public class AlertCostEstimatorImpl implements AlertCostEstimator {
    private static final Logger logger = LoggerFactory.getLogger(AlertCostEstimatorImpl.class);

    /**
     * Multiplier applied to the summed metric cost of an expression alert.
     * NOTE(review): the rationale for the factor of 2 is not recorded in code; presumably it accounts
     * for the extra work of evaluating the expression program itself — confirm.
     */
    private static final double EXPRESSION_COST_MULTIPLIER = 2;

    @WillNotClose
    private final MetabaseClient metabase;
    // TODO: stop using stockpile for cost estimation; all required information should be extracted from metabase
    @WillNotClose
    private final StockpileClient stockpile;
    private final EstimationOptions options;

    public AlertCostEstimatorImpl(
            @WillNotClose MetabaseClient metabase,
            @WillNotClose StockpileClient stockpile,
            EstimationOptions options) {
        this.options = options;
        this.metabase = metabase;
        this.stockpile = stockpile;
    }

    /**
     * Estimates the cost of evaluating {@code alert} at {@code now}.
     *
     * <p>Failures of the returned future are logged at error level and propagated unchanged to the
     * caller; successful estimations are logged at debug level.
     *
     * @param now      evaluation time; the alert window ends here
     * @param alert    alert to estimate, must be a threshold or expression alert
     * @param deadline deadline forwarded to metabase and stockpile requests
     * @return future completed with the estimated cost
     * @throws UnsupportedOperationException (synchronously) for any other alert type
     */
    @Nonnull
    @Override
    public CompletableFuture<Double> estimateCost(Instant now, Alert alert, Instant deadline) {
        final CompletableFuture<Double> result;
        switch (alert.getAlertType()) {
            case THRESHOLD:
                result = estimateCostOfThresholdAlert(now, (ThresholdAlert) alert, deadline);
                break;
            case EXPRESSION:
                result = estimateCostOfExpressionAlert(now, (ExpressionAlert) alert, deadline);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported alert type: " + alert);
        }

        return result.whenComplete((cost, throwable) -> {
            if (throwable != null) {
                logger.error("Failed complexity estimation for alert {} with cause {}", alert, throwable.getMessage(), throwable);
                return;
            }

            logger.debug("Estimated cost {} for alert {}", cost, alert);
        });
    }

    /**
     * Estimates the cost of an expression alert.
     *
     * <p>The alert program is compiled and prepared for the window {@code [now - period, now]}. The
     * metric cost of every load request is estimated independently, and the sum is scaled by
     * {@link #EXPRESSION_COST_MULTIPLIER}. The expression's own evaluation complexity is not yet
     * estimated (see TODO below).
     *
     * @param now      evaluation time; the program window ends here
     * @param alert    expression alert to estimate
     * @param deadline deadline forwarded to metabase and stockpile requests
     * @return future completed with the estimated cost
     */
    @Nonnull
    public CompletableFuture<Double> estimateCostOfExpressionAlert(Instant now, ExpressionAlert alert, Instant deadline) {
        Program program = Program.fromSource(ProgramCompiler.ALERTING_SEL_VERSION, alert.getCombinedSource())
                .withDeprOpts(DeprOpts.ALERTING)
                .compile();

        // TODO: estimate expression complexity (gordiychuk@)
        PreparedProgram preparedProgram = program.prepare(new Interval(now.minus(alert.getPeriod()), now));
        return CompletableFutures.allOf(preparedProgram.getLoadRequests()
                .stream()
                .map(loadRequest -> estimateCostOfMetrics(now, alert.getPeriod(), deadline, loadRequest.getSelectors()))
                .collect(Collectors.toList()))
                .thenApply(costs -> costs.stream()
                        .mapToDouble(value -> value)
                        .sum() * EXPRESSION_COST_MULTIPLIER);
    }

    /**
     * Estimates the cost of a threshold alert: the cost of the metrics matched by its selectors over
     * its period.
     */
    @Nonnull
    private CompletableFuture<Double> estimateCostOfThresholdAlert(Instant now, ThresholdAlert alert, Instant deadline) {
        return estimateCostOfMetrics(now, alert.getPeriod(), deadline, alert.getSelectors());
    }

    /**
     * Estimates the cost of loading every metric matched by {@code selectors} over
     * {@code [now - period, now]}.
     */
    @Nonnull
    private CompletableFuture<Double> estimateCostOfMetrics(Instant now, Duration period, Instant deadline, Selectors selectors) {
        return resolveMetrics(selectors, deadline)
                .thenCompose(metrics -> CompletableFutures.allOf(metrics.stream()
                        .map(metric -> estimateCountPointsOnPeriod(now, period, deadline, metric))
                        .collect(Collectors.toList())))
                .thenApply(pointCounts -> computeMetricsCost(period, pointCounts));
    }

    /**
     * Computes cost from per-metric point counts: a fixed load cost per metric plus a cost per point.
     *
     * <p>If no metric was resolved, {@code options.getApproximateCountMetricsBySelector()} metrics are
     * assumed. A metric whose point count is unknown (index past the list) or reported as {@code 0} is
     * charged the approximate point count for {@code period}.
     */
    private double computeMetricsCost(Duration period, List<Integer> pointCounts) {
        int metricCount = pointCounts.isEmpty()
                ? options.getApproximateCountMetricsBySelector()
                : pointCounts.size();

        double pointsCost = 0;
        for (int index = 0; index < metricCount; index++) {
            int count;
            if (index >= pointCounts.size() || pointCounts.get(index) == 0) {
                // Computed only when needed, so a misconfigured step cannot fail estimations that have real counts.
                count = approximateCountPointsOnPeriod(period);
            } else {
                count = pointCounts.get(index);
            }

            // Widen before multiplying so large counts cannot overflow integer arithmetic.
            pointsCost += (double) count * options.getPointCost();
        }

        return (double) metricCount * options.getMetricLoadCost() + pointsCost;
    }

    /**
     * Approximates how many points a metric has over {@code period}, assuming one point every
     * {@code options.getApproximatePointStepMillis()} milliseconds.
     *
     * @throws ArithmeticException if the rounded result does not fit into an {@code int}
     */
    private int approximateCountPointsOnPeriod(Duration period) {
        return Math.toIntExact(Math.round(
                (double) period.toMillis() / options.getApproximatePointStepMillis()));
    }

    /**
     * Finds the metrics matching {@code selectors} in metabase. A non-OK metabase status is logged
     * and yields an empty list; an exceptionally completed request propagates its failure.
     */
    @Nonnull
    private CompletableFuture<List<Metric>> resolveMetrics(Selectors selectors, Instant deadline) {
        // TODO: retry until deadline (gordiychuk@)
        FindRequest.Builder request = FindRequest.newBuilder()
            .addAllSelectors(LabelSelectorConverter.selectorsToProto(selectors))
            .setDeadlineMillis(deadline.toEpochMilli());
        return metabase.find(request.build())
                .thenApply((response) -> {
                    if (response.getStatus() != EMetabaseStatusCode.OK) {
                        logger.error("Not able estimate count metrics for selector {} because request to metabase failed {}: {}",
                                selectors, response.getStatus(), response.getStatusMessage()
                        );

                        return Collections.emptyList();
                    }

                    return response.getMetricsList();
                });
    }

    /**
     * Reads the number of points {@code metric} holds in {@code [now - period, now]} from stockpile,
     * using COUNT aggregation with a grid equal to the whole period.
     */
    private CompletableFuture<Integer> estimateCountPointsOnPeriod(Instant now, Duration period, Instant deadline, Metric metric) {
        return stockpile.readCompressedOne(TReadRequest.newBuilder()
                .setMetricId(metric.getMetricId())
                .setBinaryVersion(StockpileFormat.CURRENT.getFormat())
                .setFromMillis(now.minus(period).toEpochMilli())
                .setToMillis(now.toEpochMilli())
                .setGridMillis(period.toMillis())
                .setDeadline(deadline.toEpochMilli())
                .setAggregation(Aggregation.COUNT)
                .build())
                .thenApply(this::estimateCountPointsFromResponse);
    }

    /**
     * Sums the COUNT values in a stockpile response. A non-OK status is logged and yields {@code 0},
     * which the caller treats as "unknown" and replaces with an approximation.
     */
    private int estimateCountPointsFromResponse(TCompressedReadResponse response) {
        if (response.getStatus() != EStockpileStatusCode.OK) {
            MetricId metricId = response.getMetricId();
            logger.error("Not able estimate count points for metric {}/{} because request failed {}: {}",
                    metricId.getShardId(), metricId.getLocalId(),
                    response.getStatus(), response.getStatusMessage()
            );

            return 0;
        }

        double countSum = 0;
        AggrPoint tempPoint = new AggrPoint();
        AggrGraphDataListIterator iterator = TimeSeriesCodec.sequenceDecode(response).iterator();
        // After truncating the time range, stockpile may return more than one point as a downsampling result
        while (iterator.next(tempPoint)) {
            countSum += tempPoint.getValueDivided();
        }

        return (int) countSum;
    }
}
