package ru.yandex.solomon.gateway.data;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.discovery.cluster.ClusterMapper;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.misc.dataSize.DataSize;
import ru.yandex.misc.thread.WhatThreadDoes;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.common.RequestProducer;
import ru.yandex.solomon.core.conf.watch.SolomonConfHolder;
import ru.yandex.solomon.core.exceptions.BadRequestException;
import ru.yandex.solomon.expression.NamedGraphData;
import ru.yandex.solomon.expression.analytics.GraphDataLoadRequest;
import ru.yandex.solomon.expression.expr.ProgramType;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.flags.FeatureFlagsHolder;
import ru.yandex.solomon.labels.LabelKeys;
import ru.yandex.solomon.labels.query.Selector;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.math.doubles.AggregateFunctionType;
import ru.yandex.solomon.math.operation.Metric;
import ru.yandex.solomon.math.protobuf.Aggregation;
import ru.yandex.solomon.math.protobuf.Operation;
import ru.yandex.solomon.math.protobuf.OperationDownsampling;
import ru.yandex.solomon.math.protobuf.OperationTop;
import ru.yandex.solomon.metrics.client.MetricsClient;
import ru.yandex.solomon.metrics.client.combined.CombinedCall;
import ru.yandex.solomon.metrics.client.combined.DataLimits;
import ru.yandex.solomon.metrics.client.combined.DataRequestType;
import ru.yandex.solomon.metrics.client.combined.FindAndReadManyRequest;
import ru.yandex.solomon.metrics.client.combined.OldModeResult;
import ru.yandex.solomon.metrics.client.exceptions.CannotLoadMetricsFromOtherProject;
import ru.yandex.solomon.metrics.client.exceptions.TooManyMetricsLoaded;
import ru.yandex.solomon.metrics.client.exceptions.TooManyMetricsReturnedPerGraph;
import ru.yandex.solomon.model.MetricKey;
import ru.yandex.solomon.model.protobuf.MetricType;
import ru.yandex.solomon.model.timeseries.AggrGraphDataArrayList;
import ru.yandex.solomon.model.timeseries.AggrGraphDataArrayListOrView;
import ru.yandex.solomon.model.timeseries.MetricTypeTransfers;
import ru.yandex.solomon.model.timeseries.MetricTypes;
import ru.yandex.solomon.model.timeseries.aggregation.TimeseriesSummary;
import ru.yandex.solomon.util.time.Interval;

import static ru.yandex.solomon.metrics.client.ResponseValidationUtils.ensureMetabaseStatusValid;
import static ru.yandex.solomon.metrics.client.ResponseValidationUtils.ensureStockpileStatusValid;

/**
 * Executes data requests against the metrics storage.
 *
 * <p>A request expression is compiled into a program; every load request of that program is
 * turned into a {@link MetricsClient#findAndReadMany} call, and the program is then evaluated
 * over the loaded series.
 *
 * @author Oleg Baryshnikov
 */
@ParametersAreNonnullByDefault
public class DataClient {
    private static final Logger logger = LoggerFactory.getLogger(DataClient.class);

    private final DataClientMetrics metrics;
    private final MetricsClient metricsClient;
    private final ClusterMapper clusterMapper;
    private final ExpressionCompiler compiler;
    private final DataClientRequestCustomizer customizer;

    /**
     * Creates a data client.
     *
     * @param metrics            per-project call metrics sink
     * @param metricsClient      client used to find and read metrics
     * @param clusterMapper      resolves the optional forced cluster of a request into a destination
     * @param metricRegistry     registry handed to the expression compiler
     * @param solomonConfHolder  configuration used to customize downsampling options
     * @param featureFlagsHolder feature flags used by the compiler and the request customizer
     */
    public DataClient(
        DataClientMetrics metrics,
        MetricsClient metricsClient,
        ClusterMapper clusterMapper,
        MetricRegistry metricRegistry,
        SolomonConfHolder solomonConfHolder,
        FeatureFlagsHolder featureFlagsHolder)
    {
        this.metrics = metrics;
        this.metricsClient = metricsClient;
        this.clusterMapper = clusterMapper;
        this.customizer = new DataClientRequestCustomizer(solomonConfHolder, featureFlagsHolder);
        this.compiler = new ExpressionCompilerImpl(featureFlagsHolder, metricRegistry);
    }

    /**
     * Asynchronously executes a data request and records per-project call metrics.
     *
     * <p>The work is started via {@link CompletableFutures#safeCall}, presumably so that exceptions
     * thrown synchronously while preparing the request end up in the returned future
     * (NOTE(review): confirm safeCall semantics). On completion the call duration is recorded as
     * completed or failed, and failures are logged.
     *
     * @param request data request to execute
     * @return future completed with the evaluated response, or exceptionally on failure
     */
    public CompletableFuture<DataResponse> readData(DataRequest request) {
        DataClientMetrics.ProjectMetrics projectMetrics =
            metrics.getProjectMetrics(request.getProjectId());

        projectMetrics.callStarted();

        long startTimeNanos = System.nanoTime();

        CompletableFuture<DataResponse> future = CompletableFutures.safeCall(
                () -> loadDataByExpression(request));

        return future.whenComplete((response, throwable) -> {
            long endTimeNanos = System.nanoTime();
            long durationNanos = endTimeNanos - startTimeNanos;

            if (throwable != null) {
                projectMetrics.callFailed(durationNanos);
                logger.warn("failed to implement data request {}", request, throwable);
            } else {
                projectMetrics.callCompleted(durationNanos);
            }
        });
    }

    /**
     * Compiles the request, loads the data for every load request of the program in parallel and
     * evaluates the program once all loads complete.
     *
     * @throws BadRequestException if only a summary is requested for a complex program
     */
    private CompletableFuture<DataResponse> loadDataByExpression(DataRequest request) {
        String destination = clusterMapper.byParamOrNull(request.getForceCluster());

        var program = compiler.compile(request);

        // Downsampling options are customized once for the whole program, based on all
        // (project/cluster-overridden) selectors and the earliest begin time of any load request.
        var selectors = new ArrayList<Selectors>(program.loadRequests().size());
        long minBeginMillis = request.getInterval().getBeginMillis();
        for (var loadRequest : program.loadRequests()) {
            if (loadRequest.getInterval().getBeginMillis() < minBeginMillis) {
                minBeginMillis = loadRequest.getInterval().getBeginMillis();
            }
            selectors.add(overrideSelectors(loadRequest.getSelectors(), request.getProjectId(), request.getClusterName()));
        }
        var downsamplingOpts = customizer.customizeDownsamplingOpts(request.getDownsampling(), request.getInterval(), selectors, minBeginMillis);

        if (request.isSummaryOnly() && program.type() == ProgramType.COMPLEX) {
            throw new BadRequestException("Cannot show summary for complex expressions");
        }

        var loadFuture = program.loadRequests()
                .stream()
                .map(loadRequest -> loadDataByGraphDataLoadRequest(
                        request.getProjectId(),
                        request.getClusterName(),
                        loadRequest,
                        program.type(),
                        downsamplingOpts,
                        destination,
                        request.isUseNewFormat(),
                        request.isOldMode(),
                        request.isSummaryOnly(),
                        request.getDeadline(),
                        request.getProducer(),
                        request.getSubjectId()))
                .collect(Collectors.collectingAndThen(Collectors.toList(), CompletableFutures::allOf));

        return loadFuture.thenApply(results -> {
            // Publish what this thread is doing for diagnostics while the program is evaluated.
            var h = WhatThreadDoes.push("Evaluate program " + request.getProjectId() + " " +
                    request.getClusterName() + " req " + request);
            try {
                return processLoadResults(program, downsamplingOpts, results);
            } finally {
                h.popSafely();
            }
        });
    }

    /**
     * Evaluates the compiled program over already loaded series.
     *
     * <p>The truncated/summary flags of all load results are OR-ed into a single
     * {@link OldModeResult}. NOTE(review): if two load results carry equal requests, the later one
     * wins in the lookup map while both still count toward the loaded-metrics limit.
     *
     * @throws TooManyMetricsLoaded if the total number of loaded series exceeds
     *         {@link DataLimits#MAX_METRICS_FOR_AGGR_COUNT}
     * @throws TooManyMetricsReturnedPerGraph if, outside summary mode, the evaluated result
     *         contains more than {@link DataLimits#MAX_METRICS_COUNT} series
     */
    private DataResponse processLoadResults(ExpressionProgram program, DownsamplingOptions downsamplingOpts, List<LoadRequestAndResult> loadResults) {
        Map<GraphDataLoadRequest, List<NamedGraphData>> resultByRequest = new HashMap<>(loadResults.size());

        OldModeResult oldModeResult = loadResults.stream()
                .map(LoadRequestAndResult::getOldModeResult)
                .reduce(OldModeResult.DEFAULT, (l, r) -> new OldModeResult(
                        l.isTruncated() || r.isTruncated(),
                        l.isSummary() || r.isSummary()));

        for (LoadRequestAndResult loadResult : loadResults) {
            resultByRequest.put(loadResult.getRequest(), loadResult.getSeries());
        }

        int metricsSum = loadResults.stream()
                .mapToInt(result -> result.getSeries().size())
                .sum();

        if (metricsSum > DataLimits.MAX_METRICS_FOR_AGGR_COUNT) {
            throw new TooManyMetricsLoaded(DataLimits.MAX_METRICS_FOR_AGGR_COUNT);
        }

        PreloadedGraphDataLoader loader = new PreloadedGraphDataLoader(resultByRequest, downsamplingOpts);
        SelValue evalResult = program.evaluate(loader);

        if (!oldModeResult.isSummary()) {
            int resultMetricsCount = computeSelValueMetricsCount(evalResult);

            if (resultMetricsCount > DataLimits.MAX_METRICS_COUNT) {
                throw new TooManyMetricsReturnedPerGraph(resultMetricsCount, DataLimits.MAX_METRICS_COUNT);
            }
        }

        return new DataResponse(evalResult, oldModeResult);
    }

    /**
     * Counts the graph series contained in an evaluation result.
     *
     * @return 1 for a single graph, the vector length for a vector of graphs, the recursive sum
     *         over all fields for an object, and 0 for anything else
     */
    private static int computeSelValueMetricsCount(SelValue result) {
        if (result.type().isGraphData()) {
            return 1;
        }

        if (result.type().isVector() && result.type().vector().elementType.isGraphData()) {
            SelValue[] series = result.castToVector().valueArray();
            return series.length;
        }

        if (result.type().isObject()) {
            Map<String, SelValue> object = result.castToObject().getObject();

            return object.values().stream()
                .mapToInt(DataClient::computeSelValueMetricsCount)
                .sum();
        }

        return 0;
    }

    /**
     * Issues a single {@code findAndReadMany} call for the given selectors and converts the
     * returned metrics into named graph data.
     *
     * <p>The metabase limit depends on the program type (TOP programs use
     * {@link DataLimits#MAX_METRICS_FOR_TOP_COUNT}). Operations sent to the storage are, in order:
     * the optional rank filter, then either the summary-only operations (when {@code summaryOnly})
     * or the downsampling operation (when the options produce one).
     *
     * @return future with the loaded series and the old-mode flags reported by the response
     */
    private CompletableFuture<Pair<List<NamedGraphData>, OldModeResult>> loadDataBySelectors(
            Selectors selectors,
            Interval interval,
            DownsamplingOptions downsamplingOpts,
            boolean isOldMode,
            boolean summaryOnly,
            @Nullable OperationTop rankFilter,
            @Nullable String destination,
            boolean useNewFormat,
            ProgramType programType,
            Instant deadline,
            RequestProducer producer,
            String subjectId)
    {
        final int metabaseReqLimit;
        if (programType == ProgramType.TOP) {
            metabaseReqLimit = DataLimits.MAX_METRICS_FOR_TOP_COUNT;
        } else {
            metabaseReqLimit = DataLimits.MAX_METRICS_FOR_AGGR_COUNT;
        }

        DataRequestType requestType = of(programType);

        List<Operation> operations = new ArrayList<>(2);
        if (rankFilter != null) {
            operations.add(Operation.newBuilder().setTop(rankFilter).build());
        }

        if (summaryOnly) {
            CombinedCall.addOnlySummaryOperation(operations::add);
        } else {
            @Nullable OperationDownsampling downsampling = downsamplingOpts.toOperation(interval.duration().toMillis());
            if (downsampling != null) {
                operations.add(Operation.newBuilder().setDownsampling(downsampling).build());
            }
        }

        FindAndReadManyRequest request = FindAndReadManyRequest.newBuilder()
                .setSelectors(selectors)
                .setUseNewFormat(useNewFormat)
                .setMetabaseLimit(metabaseReqLimit)
                .setFromMillis(interval.getBeginMillis())
                .setToMillis(interval.getEndMillis())
                .setOperations(operations)
                .setOldMode(isOldMode)
                .setRequestType(requestType)
                .setDeadline(deadline)
                .setDestination(destination)
                .setProducer(producer)
                .setSubjectId(subjectId)
                .build();

        return metricsClient.findAndReadMany(request)
            .thenApply(response -> {
                ensureMetabaseStatusValid(response.getMetaStatus());
                ensureStockpileStatusValid(response.getStorageStatus());

                List<Metric<MetricKey>> result = response.getMetrics();

                var timeseriesList = new ArrayList<NamedGraphData>(result.size());
                // NOTE(review): int accumulator; could overflow for extremely large responses.
                int bytes = 0;
                for (var metric : result) {
                    var timeseries = namedGraphDataForVisibleMetric(metric);
                    timeseriesList.add(timeseries);
                    bytes += timeseries.memorySizeIncludingSelfInt();
                }

                if (bytes > DataLimits.HUGE_RESPONSE_THRESHOLD) {
                    logger.info("Huge load {} by selector {}", DataSize.shortString(bytes), selectors);
                }
                return Pair.of(timeseriesList, response.getModeResult());
            });
    }

    /**
     * Maps a compiled program type to the request type sent to the metrics client.
     *
     * @throws IllegalArgumentException for program types without a mapping
     */
    public static DataRequestType of(ProgramType type) {
        return switch (type) {
            case SELECTORS, SIMPLE -> DataRequestType.SIMPLE;
            case COMPLEX -> DataRequestType.COMPLEX;
            case TOP -> DataRequestType.TOP;
            default -> throw new IllegalArgumentException("Unknown program type: " + type);
        };
    }

    /**
     * Loads the data for one load request of a program: applies project/cluster selector
     * overrides and per-request downsampling overrides, converts primitive series to DGAUGE and
     * registers the old-mode flags in the client metrics.
     */
    private CompletableFuture<LoadRequestAndResult> loadDataByGraphDataLoadRequest(
            String projectId,
            String clusterName,
            GraphDataLoadRequest loadRequest,
            ProgramType programType,
            DownsamplingOptions downsamplingOptions,
            @Nullable String destination,
            boolean useNewFormat,
            boolean oldMode,
            boolean summaryOnly,
            Instant deadline,
            RequestProducer producer,
            String subjectId)
    {
        return loadDataBySelectors(
                overrideSelectors(loadRequest.getSelectors(), projectId, clusterName),
                loadRequest.getInterval(),
                overrideDownsamplingOpts(downsamplingOptions, loadRequest),
                oldMode,
                summaryOnly,
                loadRequest.getRankFilter(),
                destination,
                useNewFormat,
                programType,
                deadline,
                producer,
                subjectId)
            .thenApply(data -> {
                List<NamedGraphData> series = data.getLeft();
                List<NamedGraphData> convertedSeries = convertSeriesIfNeeded(series);
                OldModeResult oldModeResult = data.getRight();
                metrics.register(oldModeResult);
                return new LoadRequestAndResult(loadRequest, convertedSeries, oldModeResult);
            });
    }

    /**
     * Rejects selectors that could match metrics outside {@code projectId}: every selector on the
     * project label must be an exact match on that project.
     *
     * @throws CannotLoadMetricsFromOtherProject if any project selector is not an exact match on
     *         {@code projectId}
     */
    private static void ensureSelectorInsideProject(String projectId, Selectors selectors) {
        boolean escapesProject = selectors.stream()
                .filter(s -> s.getKey().equals(LabelKeys.PROJECT))
                .anyMatch(s -> !s.isExact() || !s.getValue().equals(projectId));

        if (escapesProject) {
            throw new CannotLoadMetricsFromOtherProject();
        }
    }

    /**
     * Applies per-load-request overrides to the program-wide downsampling options: the request's
     * aggregate function (if any) and its grid (if non-zero, switching to BY_INTERVAL).
     */
    private static DownsamplingOptions overrideDownsamplingOpts(DownsamplingOptions opts, GraphDataLoadRequest request) {
        var result = opts.toBuilder();
        var aggr = request.getAggregateFunction();
        if (aggr != null) {
            result.setDownsamplingAggr(convertAggrToDownsamplingAggr(aggr));
        }

        if (request.getGridMillis() != 0) {
            result.setDownsamplingType(DownsamplingType.BY_INTERVAL);
            result.setGridMillis(request.getGridMillis());
        }

        return result.build();
    }

    /**
     * Scopes selectors to the request: adds a project override when no project selector is
     * present, adds a cluster override when {@code clusterName} is non-empty, and then validates
     * that the result cannot escape the project.
     *
     * @throws CannotLoadMetricsFromOtherProject if the selectors reference another project
     */
    private static Selectors overrideSelectors(Selectors selectors, String projectId, String clusterName) {
        @Nullable Selector projectSelector = selectors.findByKey(LabelKeys.PROJECT);
        var builder = selectors.toBuilder();
        if (projectSelector == null) {
            builder.addOverride(LabelKeys.PROJECT, projectId);
        }
        if (StringUtils.isNotEmpty(clusterName)) {
            builder.addOverride(LabelKeys.CLUSTER, clusterName);
        }
        var result = builder.build();
        ensureSelectorInsideProject(projectId, result);
        return result;
    }

    /** Returns a new list where every series has been passed through {@link #convertToDGaugeIfNeeded}. */
    private static List<NamedGraphData> convertSeriesIfNeeded(List<NamedGraphData> series) {
        return series.stream()
            .map(DataClient::convertToDGaugeIfNeeded)
            .collect(Collectors.toList());
    }

    /**
     * Converts a non-empty series of a primitive, non-DGAUGE data type into DGAUGE.
     *
     * @return a converted copy, or the original instance when no conversion applies
     */
    private static NamedGraphData convertToDGaugeIfNeeded(NamedGraphData namedGraphData) {
        AggrGraphDataArrayListOrView graphData = namedGraphData.getAggrGraphDataArrayList();
        MetricType type = namedGraphData.getDataType();
        if (graphData.isEmpty() || type == MetricType.DGAUGE || !MetricTypes.isPrimitive(type)) {
            return namedGraphData;
        }

        // toBuilder() yields a separate builder: the converted value only exists once build()
        // is called and must be returned instead of the original.
        return namedGraphData.toBuilder()
            .setGraphData(MetricType.DGAUGE,
                AggrGraphDataArrayList.of(MetricTypeTransfers.of(type, MetricType.DGAUGE, graphData.iterator())))
            .build();
    }

    /**
     * Maps an expression aggregate function to the storage downsampling aggregation.
     *
     * @throws IllegalArgumentException for aggregate functions without a mapping
     */
    private static Aggregation convertAggrToDownsamplingAggr(AggregateFunctionType aggr) {
        return switch (aggr) {
            case DEFAULT -> Aggregation.DEFAULT_AGGREGATION;
            case MAX -> Aggregation.MAX;
            case MIN -> Aggregation.MIN;
            case AVG -> Aggregation.AVG;
            case SUM -> Aggregation.SUM;
            case LAST -> Aggregation.LAST;
            case COUNT -> Aggregation.COUNT;
            default -> throw new IllegalArgumentException("unknown downsampling aggregation type:" + aggr);
        };
    }

    /**
     * Builds named graph data from a loaded metric. A missing key yields UNKNOWN type, an empty
     * name and empty labels; missing timeseries yield an empty series.
     */
    private static NamedGraphData namedGraphDataForVisibleMetric(Metric<MetricKey> metric) {
        MetricKey metricKey = metric.getKey();
        var type = metricKey != null ? metricKey.getType() : ru.yandex.monlib.metrics.MetricType.UNKNOWN;
        String metricName = metricKey != null ? metricKey.getName() : "";
        Labels labels = metricKey != null ? metricKey.getLabels() : Labels.empty();

        final AggrGraphDataArrayList dataTimeseries;

        var timeseries = metric.getTimeseries();
        if (timeseries == null) {
            // TODO: should be null, not empty list. But probably no one expects series to be null, so ...
            dataTimeseries = AggrGraphDataArrayList.empty();
        } else {
            dataTimeseries = AggrGraphDataArrayList.of(timeseries);
        }

        TimeseriesSummary dataSummary = metric.getSummary();

        return NamedGraphData.newBuilder()
            .setType(type)
            .setMetricName(metricName)
            .setLabels(labels)
            .setGraphData(metric.getType(), dataTimeseries)
            .setSummary(dataSummary)
            .build();
    }

    /** Pairs one program load request with the series loaded for it and its old-mode flags. */
    private static final class LoadRequestAndResult {
        private final GraphDataLoadRequest request;
        private final List<NamedGraphData> series;
        private final OldModeResult oldModeResult;

        private LoadRequestAndResult(
            GraphDataLoadRequest request,
            List<NamedGraphData> series,
            OldModeResult oldModeResult)
        {
            this.request = request;
            this.series = series;
            this.oldModeResult = oldModeResult;
        }

        GraphDataLoadRequest getRequest() {
            return request;
        }

        List<NamedGraphData> getSeries() {
            return series;
        }

        public OldModeResult getOldModeResult() {
            return oldModeResult;
        }
    }
}
