package ru.yandex.market.graphouse.search;

import java.io.ByteArrayOutputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;

import org.apache.commons.io.Charsets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.server.ServerWebExchange;

import ru.yandex.market.graphouse.retention.RetentionManager;
import ru.yandex.market.graphouse.search.tree.MetricBaseZip;
import ru.yandex.market.graphouse.search.tree.MetricNameZip;
import ru.yandex.market.graphouse.stockpile.GraphouseStockpileClient;
import ru.yandex.market.graphouse.stockpile.MetricResponse;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.config.protobuf.graphite.storage.TMetricDataControllerConfig;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;
import ru.yandex.solomon.model.timeseries.GraphData;
import ru.yandex.solomon.model.timeseries.GraphDataArrayList;
import ru.yandex.solomon.model.timeseries.SortedOrCheck;
import ru.yandex.solomon.model.timeseries.Timeline;
import ru.yandex.solomon.util.collection.Nullables;
import ru.yandex.solomon.util.collection.array.DoubleArrayView;
import ru.yandex.solomon.util.collection.array.LongArrayView;

import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * HTTP endpoint that serves Graphite metric data read from Stockpile.
 *
 * <p>Response format (plain text, UTF-8):
 * <pre>
 * [protocol version]
 * [step in seconds]
 * for each metric returned by storage (omitted entirely when no metric has any points):
 *   [metric name] [number of points] 0
 *   [timestamp in seconds] [value]      (repeated [number of points] times)
 * # stats: resolve: +[ms] read: +[ms] write: +[ms]
 * </pre>
 * e.g.:
 * <pre>
 * 2
 * 1
 * one_sec.local.random.diceroll.1 1 0
 * 1463665893 1.0
 * one_sec.local.random.diceroll.5 2 0
 * 1463663878 5.0
 * 1463766459 5.0
 * # stats: resolve: +2 read: +15 write: +16
 * </pre>
 * The stats values are millisecond offsets from the moment metric name resolution started.
 *
 * @author Maksim Leonov (nohttp@)
 * @date 20/05/16
 */
@Controller
public class MetricDataController {
    private static final int PROTOCOL_VERSION = 2;
    private static final int MAX_METRICS_IN_REQUEST = 100_000;
    /** Step used when no requested metric could be resolved. */
    private static final long DEFAULT_STEP_MILLIS = 60_000;

    private static final Logger log = LoggerFactory.getLogger(MetricDataController.class);

    private final MetricSearch metricSearch;
    private final GraphouseStockpileClient client;
    private final RetentionManager retentionManager;
    private final ExecutorService executorService;
    private final MetricDataControllerMetrics metrics;

    @Autowired
    public MetricDataController(
        MetricSearch metricSearch,
        GraphouseStockpileClient client,
        RetentionManager retentionManager,
        TMetricDataControllerConfig config,
        ThreadPoolProvider threadPoolProvider)
    {
        this.metricSearch = metricSearch;
        this.client = client;
        this.retentionManager = retentionManager;
        // Response serialization runs on this pool (see readMetricsAsync).
        this.executorService = threadPoolProvider.getExecutorService(
            config.getThreadPoolName(), "MetricDataControllerConfig.ThreadPoolName");

        MetricRegistry metricRegistry = MetricRegistry.root();
        this.metrics = new MetricDataControllerMetrics(metricRegistry);
    }

    /**
     * Form-encoded POST variant of {@link #searchGet}.
     *
     * <p>In WebFlux {@code @RequestParam} cannot be used to map form data, so the form is read
     * from the exchange here and then delegated to {@link #searchGet}. The {@code metrics},
     * {@code from} and {@code until} fields are required; {@code user} and {@code req} default
     * to {@code "unknown"}.
     */
    @PostMapping(value = "/query", consumes = MediaType.APPLICATION_FORM_URLENCODED_VALUE)
    public CompletableFuture<ResponseEntity<byte[]>> searchPost(ServerWebExchange exchange) {
        return exchange.getFormData()
                .toFuture()
                .thenCompose(formData -> {
                    // Fail with a descriptive message instead of a bare NPE when a field is missing.
                    String[] rawMetrics = Objects.requireNonNull(
                            formData.get("metrics"), "form field 'metrics' is required")
                            .toArray(String[]::new);
                    long from = Long.parseLong(Objects.requireNonNull(
                            formData.getFirst("from"), "form field 'from' is required"));
                    long until = Long.parseLong(Objects.requireNonNull(
                            formData.getFirst("until"), "form field 'until' is required"));
                    String user = Nullables.orDefault(formData.getFirst("user"), "unknown");
                    String req = Nullables.orDefault(formData.getFirst("req"), "unknown");

                    return searchGet(rawMetrics, from, until, user, req);
                });
    }

    /**
     * Reads data points for the given Graphite metric names.
     *
     * @param rawMetrics metric names; they are sorted and de-duplicated before reading
     * @param startTime  start of the requested range, in seconds
     * @param endTime    end of the requested range, in seconds
     * @param userName   caller name, used as a label on read-rate metrics
     * @param reqKey     opaque request key; only logged
     * @return future with a text/plain body in the format described on the class, or an
     *         HTTP 500 response when more than {@value #MAX_METRICS_IN_REQUEST} distinct metrics
     *         are requested
     */
    @GetMapping(value = "/query")
    public CompletableFuture<ResponseEntity<byte[]>> searchGet(
        @RequestParam(name = "metrics") String[] rawMetrics,
        @RequestParam(name = "from") long startTime,
        @RequestParam(name = "until") long endTime,
        @RequestParam(name = "user", defaultValue = "unknown") String userName,
        @RequestParam(name = "req", defaultValue = "unknown") String reqKey)
    {
        long queryProcessBeginMillis = System.currentTimeMillis();
        String[] uniqueSortedMetrics = Arrays.stream(rawMetrics).sorted().distinct().toArray(String[]::new);
        MetricRegistry
            .root()
            .rate("metric.reads", Labels.of("user", userName, "method", "query"))
            .add(uniqueSortedMetrics.length);

        metrics.recordMetricsInRequestHistogram(uniqueSortedMetrics.length);
        if (uniqueSortedMetrics.length > MAX_METRICS_IN_REQUEST) {
            // NOTE(review): this is a client error, so 400/413 would be more accurate than 500;
            // the status is kept unchanged for compatibility with existing clients.
            var response = ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .contentType(MediaType.TEXT_PLAIN)
                    .body("Too many metrics requested".getBytes(Charsets.UTF_8));
            return completedFuture(response);
        }

        log.info("Req key {} matches", reqKey);
        return readMetricsAsync(uniqueSortedMetrics, startTime, endTime, queryProcessBeginMillis, userName)
                .thenApply(body -> ResponseEntity.ok()
                        .contentType(MediaType.TEXT_PLAIN)
                        .body(body));
    }

    /**
     * Resolves metric names, reads their points from storage and serializes the response body.
     *
     * <p>Name resolution runs on the calling thread; serialization runs on
     * {@link #executorService}. Names that {@code metricSearch} does not find are skipped.
     *
     * @param uniqueSortedMetrics     sorted, de-duplicated metric names
     * @param startTime               requested range start, in seconds
     * @param endTime                 requested range end, in seconds
     * @param queryProcessBeginMillis wall-clock time the request started, used for reply timing
     * @param userName                caller name, used as a metric label
     * @return future with the UTF-8 encoded response body
     */
    private CompletableFuture<byte[]> readMetricsAsync(
        String[] uniqueSortedMetrics,
        long startTime,
        long endTime,
        long queryProcessBeginMillis,
        String userName)
    {
        log.info("Began reading data");
        long queryStartTime = System.currentTimeMillis();
        MetricNameZip[] resolvedMetrics = Arrays.stream(uniqueSortedMetrics)
                .map(metricSearch::findExistingMetric)
                .filter(Objects::nonNull)
                .toArray(MetricNameZip[]::new);
        long resolveTime = System.currentTimeMillis() - queryStartTime;
        metrics.recordResolveTimeHistogram(resolveTime);

        metrics.incrementConnectionsAccepted();
        metrics.addMetricRequests(resolvedMetrics.length);

        long startTimeMillis = startTime * 1000;
        long endTimeMillis = endTime * 1000;

        Stream<Short> retentionIds =
            Arrays.stream(resolvedMetrics).map(MetricBaseZip::getRetentionId);
        long stepMillis = getStepMillis(retentionIds, startTimeMillis);

        // Align both bounds down to the step and widen them by one extra step on each side so that
        // the window-based resampler (see windowResampleOnNet) has neighbours at the edges.
        // floorMod keeps the alignment correct for negative values, where % would round up.
        long readStartTimeMillis = startTimeMillis - Math.floorMod(startTimeMillis, stepMillis) - stepMillis;
        long readEndTimeMillis = endTimeMillis - Math.floorMod(endTimeMillis, stepMillis) + stepMillis;

        return client.readMetrics(resolvedMetrics, readStartTimeMillis, readEndTimeMillis)
            .thenApplyAsync(metricResponses -> {
                log.info("Began serving {} metrics", metricResponses.size());

                ByteArrayOutputStream output = new ByteArrayOutputStream(128);
                // Encode explicitly as UTF-8: PrintWriter(OutputStream) would use the platform
                // default charset, which can mangle non-ASCII metric names.
                try (PrintWriter out = new PrintWriter(output, false, Charsets.UTF_8)) {
                    long readTime = System.currentTimeMillis() - queryStartTime;
                    metrics.recordReadTimingsHistogram(readTime);
                    writeMetrics(metricResponses, readStartTimeMillis, readEndTimeMillis, stepMillis, out, userName);

                    long writeTime = System.currentTimeMillis() - queryStartTime;
                    out.println("# stats:" +
                            " resolve: +" + resolveTime +
                            " read: +" + readTime +
                            " write: +" + writeTime);
                }

                long dt = System.currentTimeMillis() - queryProcessBeginMillis;
                metrics.recordReplyTimingsHistogram(dt);

                log.info("Request complete in {}", dt);
                return output.toByteArray();
            }, executorService);
    }

    /**
     * Returns the largest step among the given retentions at {@code startTimeMillis}, or
     * {@value #DEFAULT_STEP_MILLIS} ms when the stream is empty.
     */
    long getStepMillis(Stream<Short> retentionIds, long startTimeMillis) {
        return retentionIds
            .mapToLong(id -> retentionManager.getRetentionById(id).getStepMillis(startTimeMillis))
            .max().orElse(DEFAULT_STEP_MILLIS);
    }

    /**
     * Writes the protocol header and every metric's points, resampled onto one common regular grid.
     *
     * <p>The grid spans the intersection of the read window with the observed data range,
     * with its start aligned down to {@code stepMillis}. When no response contains any point,
     * only the header is written.
     */
    private void writeMetrics(
        List<MetricResponse> metricResponses,
        long readStartTimeMillis,
        long readEndTimeMillis,
        long stepMillis,
        PrintWriter out,
        String userName)
    {
        out.println(PROTOCOL_VERSION);
        out.println(stepMillis / 1000);

        boolean hasData = false;
        long minObservedStart = Long.MAX_VALUE;
        long maxObservedEnd = Long.MIN_VALUE;

        for (MetricResponse metric : metricResponses) {
            GraphData gd = metric.payload;
            if (!gd.isEmpty()) {
                hasData = true;
                LongArrayView timestamps = gd.getTimestamps();
                minObservedStart = Math.min(minObservedStart, timestamps.first());
                maxObservedEnd = Math.max(maxObservedEnd, timestamps.last());
            }
        }
        if (!hasData) {
            return;
        }

        long gridStartMillis = Math.max(readStartTimeMillis, minObservedStart);
        long gridEndMillis = Math.min(readEndTimeMillis, maxObservedEnd);
        gridStartMillis -= Math.floorMod(gridStartMillis, stepMillis);

        Rate pointReads = MetricRegistry.root()
                .rate("point.reads", Labels.of("user", userName, "method", "query"));

        for (MetricResponse metric : metricResponses) {
            int pointsRead = metric.payload.length();
            pointReads.add(pointsRead);

            GraphData normalized = normalize(metric.payload, gridStartMillis, gridEndMillis, stepMillis);
            MetricDescription description = metric.description;

            metrics.addPointsRead(pointsRead);
            metrics.addPointsServed(normalized.length());

            String metricName = description.getName();
            int numPoints = normalized.length();

            // The third header field is always written as 0.
            out.println(metricName + " " + numPoints + " " + 0);
            normalized.visit((tsMillis, value) -> {
                out.print(tsMillis / 1000);
                out.print(" ");
                out.print(value);
                out.println();
            });
        }
    }

    /**
     * Resamples {@code gd} onto a regular grid from {@code fromMillis} to {@code toMillis}.
     * A bound equal to 0 falls back to the series' own first/last timestamp; an empty series is
     * returned unchanged.
     */
    private GraphData normalize(GraphData gd, long fromMillis, long toMillis, long step) {
        if (gd.isEmpty()) {
            return gd;
        }
        Timeline original = gd.getTimeline();
        long resampledFromMillis = (fromMillis == 0) ? original.first() : fromMillis;
        long resampledToMillis = (toMillis == 0) ? original.last() : toMillis;

        Timeline timeline = makeRegularTimeline(resampledFromMillis, resampledToMillis, step);
        return windowResampleOnNet(gd, timeline, step);
    }

    /**
     * Builds the timeline {@code startMillis, startMillis + step, ...}. It gets one extra
     * trailing point when {@code endMillis} is not on the grid, so the last point is at or past
     * {@code endMillis}.
     *
     * @throws IllegalArgumentException if {@code stepMillis} is not positive
     */
    private Timeline makeRegularTimeline(long startMillis, long endMillis, long stepMillis) {
        if (stepMillis <= 0) {
            throw new IllegalArgumentException("stepMillis must be positive: " + stepMillis);
        }
        long duration = endMillis - startMillis;
        int numPoints = (int) (duration / stepMillis) + 1;
        if (duration % stepMillis != 0) {
            // endMillis falls between grid nodes: add one more node so the grid covers it
            numPoints++;
        }

        long[] array = new long[numPoints];

        array[0] = startMillis;
        for (int i = 1; i < numPoints; i++) {
            array[i] = array[i - 1] + stepMillis;
        }

        return new Timeline(array, SortedOrCheck.SORTED_UNIQUE);
    }

    /**
     * Resamples {@code gd} onto {@code timeline}.
     *
     * <p>Each output point is the mean of the input points within roughly
     * {@code stepMillis / 2} of the grid timestamp. If that window is empty, the nearest
     * neighbouring input point within 1.25x that radius is used; the left neighbour wins ties.
     * If there is no such neighbour, the grid point is omitted.
     */
    private GraphData windowResampleOnNet(GraphData gd, Timeline timeline, long stepMillis) {
        long radiusMillis = stepMillis / 2; // '2' affects windowBeginInclIdx logic
        long optimisticRadiusMillis = (long) (radiusMillis * 1.25);

        int windowBeginInclIdx = 0;
        int windowEndExclIdx = 0;

        LongArrayView inputTs = gd.getTimestamps();
        DoubleArrayView inputPoints = gd.getValues();

        GraphDataArrayList output = new GraphDataArrayList(timeline.length());
        for (int i = 0; i < timeline.length(); i++) {
            long pointTsMillis = timeline.getPointMillisAt(i);

            double windowSum = 0;
            windowBeginInclIdx = windowEndExclIdx; // because intervals do not intersect (iff radius == step / 2)

            while (windowEndExclIdx < inputTs.length() && inputTs.at(windowEndExclIdx) <= pointTsMillis + radiusMillis) {
                windowSum += inputPoints.at(windowEndExclIdx);
                windowEndExclIdx++;
            }
            while (windowBeginInclIdx < windowEndExclIdx && inputTs.at(windowBeginInclIdx) < pointTsMillis - radiusMillis) {
                windowSum -= inputPoints.at(windowBeginInclIdx);
                windowBeginInclIdx++;
            }
            if (windowBeginInclIdx != windowEndExclIdx) {
                output.add(pointTsMillis, windowSum / (windowEndExclIdx - windowBeginInclIdx));
            } else {
                // quite possible situation:
                // * a metric with a five-second interval (with nodes, say, at hh:00:0[50])
                // * a point at 11:05:17 and a point at 11:05:23 (graphite timestamps are rounded to seconds)
                // => we do not want to draw a gap at 11:05:20.
                // so we expand the window by one point on each side and check that the neighbour
                // is close enough (1.25 radius is a good guess)

                double value = 0;
                long distToValuePoint = 0;
                boolean foundValue = false;
                if (windowBeginInclIdx != 0) {
                    long leftPointTsMillis = inputTs.at(windowBeginInclIdx - 1);
                    long distToLeftPoint = pointTsMillis - leftPointTsMillis;
                    if (distToLeftPoint <= optimisticRadiusMillis) {
                        foundValue = true;

                        value = inputPoints.at(windowBeginInclIdx - 1);
                        distToValuePoint = distToLeftPoint;
                    }
                }

                if (windowEndExclIdx != inputTs.length()) {
                    long rightPointTsMillis = inputTs.at(windowEndExclIdx);
                    long distToRightPoint = rightPointTsMillis - pointTsMillis;
                    if (distToRightPoint <= optimisticRadiusMillis) {
                        // the right neighbour wins only if strictly closer than the left one
                        if ((!foundValue) || (distToRightPoint < distToValuePoint)) {
                            foundValue = true;

                            value = inputPoints.at(windowEndExclIdx);
                            distToValuePoint = distToRightPoint;
                        }
                    }
                }

                if (foundValue) {
                    output.add(pointTsMillis, value);
                }
            }
        }

        return output.buildGraphData();
    }

    /**
     * Controller-level counters and histograms registered in the given registry.
     * Static so it does not hold a hidden reference to the enclosing controller.
     */
    private static final class MetricDataControllerMetrics {

        private final Rate connectionsAccepted;
        private final Rate readMetrics;
        private final Rate readPoints;
        private final Rate readPointsServed;
        private final Histogram metricsInRequestHistogram;
        private final Histogram readTimingsHistogram;
        private final Histogram replyTimingsHistogram;
        private final Histogram resolveTimeHistogram;

        MetricDataControllerMetrics(MetricRegistry metricRegistry) {
            this.connectionsAccepted = metricRegistry.rate("MetricDataController.connectionsAccepted");
            this.readMetrics = metricRegistry.rate("MetricDataController.readMetrics");
            this.readPoints = metricRegistry.rate("MetricDataController.readPoints");
            this.readPointsServed = metricRegistry.rate("MetricDataController.readPointsServed");
            this.metricsInRequestHistogram = metricRegistry.histogramRate(
                "MetricDataController.metricsInRequestHistogram", Histograms.exponential(13, 2, 16));
            this.readTimingsHistogram = metricRegistry.histogramRate(
                "MetricDataController.readTimingsHistogram", Histograms.exponential(13, 2, 16));
            this.replyTimingsHistogram = metricRegistry.histogramRate(
                "MetricDataController.replyTimingsHistogram", Histograms.exponential(13, 2, 16));
            this.resolveTimeHistogram = metricRegistry
                .histogramRate("MetricDataController.resolveTime", Histograms.exponential(13, 2, 16));
        }

        void incrementConnectionsAccepted() {
            connectionsAccepted.inc();
        }

        void addMetricRequests(int metricCount) {
            readMetrics.add(metricCount);
        }

        void addPointsRead(int points) {
            readPoints.add(points);
        }

        void addPointsServed(int servedPoints) {
            readPointsServed.add(servedPoints);
        }

        void recordMetricsInRequestHistogram(int lengthMetrics) {
            metricsInRequestHistogram.record(lengthMetrics);
        }

        void recordReadTimingsHistogram(long time) {
            readTimingsHistogram.record(time);
        }

        void recordReplyTimingsHistogram(long time) {
            replyTimingsHistogram.record(time);
        }

        void recordResolveTimeHistogram(long time) {
            resolveTimeHistogram.record(time);
        }
    }
}
