package ru.yandex.solomon.codec.histogram;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.google.common.net.HostAndPort;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.Level;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.grpc.utils.DefaultClientOptions;
import ru.yandex.metabase.client.MetabaseClient;
import ru.yandex.metabase.client.MetabaseClientOptions;
import ru.yandex.metabase.client.MetabaseClients;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Label;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.solomon.codec.archive.MetricArchiveImmutable;
import ru.yandex.solomon.codec.archive.MetricArchiveMutable;
import ru.yandex.solomon.codec.archive.serializer.MetricArchiveNakedSerializer;
import ru.yandex.solomon.codec.serializer.StockpileDeserializer;
import ru.yandex.solomon.codec.serializer.StockpileFormat;
import ru.yandex.solomon.labels.protobuf.LabelConverter;
import ru.yandex.solomon.labels.protobuf.LabelSelectorConverter;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.main.logger.LoggerConfigurationUtils;
import ru.yandex.solomon.math.GraphDataTimePointIterator;
import ru.yandex.solomon.metabase.api.protobuf.EMetabaseStatusCode;
import ru.yandex.solomon.metabase.api.protobuf.FindRequest;
import ru.yandex.solomon.model.point.AggrPoint;
import ru.yandex.solomon.model.point.AggrPointData;
import ru.yandex.solomon.model.point.column.HistogramColumn;
import ru.yandex.solomon.model.point.column.StockpileColumn;
import ru.yandex.solomon.model.point.column.TsColumn;
import ru.yandex.solomon.model.point.column.ValueColumn;
import ru.yandex.solomon.model.protobuf.TimeSeries;
import ru.yandex.solomon.model.timeseries.AggrGraphDataArrayList;
import ru.yandex.solomon.model.timeseries.AggrGraphDataListIterator;
import ru.yandex.solomon.model.timeseries.SortedOrCheck;
import ru.yandex.solomon.model.timeseries.Timeline;
import ru.yandex.solomon.model.type.Histogram;
import ru.yandex.stockpile.api.EStockpileStatusCode;
import ru.yandex.stockpile.api.TCompressedReadResponse;
import ru.yandex.stockpile.api.TReadRequest;
import ru.yandex.stockpile.client.StockpileClient;
import ru.yandex.stockpile.client.StockpileClientOptions;
import ru.yandex.stockpile.client.StockpileClients;
import ru.yandex.stockpile.client.StopStrategies;

/*
Sample printSummary() output from an earlier run. Despite the header line, rows are ordered by
"Doubles Bytes" ascending (ties broken by "Histogram Bytes"). The storage_sas selector was measured
twice, so it appears twice and is counted twice in the totals. The 'sum' row has been recomputed
from the rows above: the run that produced this table passed the totals to MeasureResult in the
wrong argument order, which rotated its Points, Histogram Bytes and Doubles Bytes values.

Sorted by histogram compress size:
Labels                                                                                                                                                              Points  Histogram  Doubles  Diff  Histogram Bytes  Doubles Bytes  Diff Bytes
{cluster='production', host='cluster', project='solomon', projectId='total', sensor='engine.metabaseWrites.elapsedTimeMs', service='coremon', shardId='total'}      761658     263 KB     3 MB  3 MB           269331        4003238     3733907
{cluster='production', host='cluster', project='solomon', projectId='total', sensor='engine.process.parseTimeMillis', service='coremon', shardId='total'}           765683     629 KB     4 MB  4 MB           644978        5130900     4485922
{cluster='storage_sas', host='cluster', project='solomon', projectId='total', sensor='stockpile.write.records.elapsedTimeMs', service='stockpile'}                  817108       1 MB     4 MB  2 MB          2062304        5189995     3127691
{cluster='storage_sas', host='cluster', project='solomon', projectId='total', sensor='stockpile.write.records.elapsedTimeMs', service='stockpile'}                  817108       1 MB     4 MB  2 MB          2062306        5189995     3127689
{cluster='production', host='cluster', project='solomon', projectId='total', sensor='engine.process.writeTimeMillis', service='coremon', shardId='total'}           767764       1 MB     5 MB  4 MB          1183658        5673654     4489996
{cluster='production', host='cluster', project='solomon', projectId='total', sensor='engine.process.waitTimeMillis', service='coremon', shardId='total'}            761622       1 MB     5 MB  4 MB          1265338        5751289     4485951
{cluster='production', host='cluster', project='solomon', projectId='total', sensor='engine.process.totalTimeMillis', service='coremon', shardId='total'}           765499       1 MB     5 MB  4 MB          1261030        5901589     4640559
{sensor='sum'}                                                                                                                                                     5456442       8 MB    35 MB 26 MB          8748945       36840660    28091715
*/

/**
 * Benchmark that compares the encoded size of histogram data stored as one histogram time series
 * against the same data stored as one double time series per {@code bin} label.
 *
 * <p>Each {@code measure(...)} call finds the metrics matching a selector in Metabase (one metric per
 * {@code bin} label value), reads their points from Stockpile and records the resulting sizes;
 * {@code printSummary()} prints one row per measured selector followed by a {@code sum} row.
 *
 * <p>NOTE(review): {@code compressIntoHistogram} and {@code compressIntoSeparateDoubles} currently
 * return {@code 0} because their serialization code is commented out, so every byte column in the
 * summary is zero until that code is restored.
 *
 * <p>TODO: this benchmark must not depend on the Stockpile and Metabase clients.
 *
 * @author Vladimir Gordiychuk
 */
public class HistogramEncodeCompareBenchmark implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(HistogramEncodeCompareBenchmark.class);
    private final MetabaseClient metabase;
    private final StockpileClient stockpile;
    // Appended to from CompletableFuture callbacks, which may run on client threads.
    private final CopyOnWriteArrayList<MeasureResult> results = new CopyOnWriteArrayList<>();

    private HistogramEncodeCompareBenchmark(MetabaseClient metabase, StockpileClient stockpile) {
        this.metabase = metabase;
        this.stockpile = stockpile;
    }

    /**
     * Measures a fixed list of selectors, prints the summary table and exits the JVM.
     */
    public static void main(String[] args) throws InterruptedException {
        LoggerConfigurationUtils.simpleLogger(Level.INFO);

        // try-with-resources closes both clients even if a measurement throws synchronously.
        try (HistogramEncodeCompareBenchmark benchmark = new HistogramEncodeCompareBenchmark(metabaseClient(), stockpileClient())) {
            // NOTE(review): presumably gives the clients time to fetch cluster metadata before the first request — confirm.
            TimeUnit.MILLISECONDS.sleep(200);
            // NOTE(review): this selector is measured twice, so it produces two rows and is counted twice
            // in the 'sum' row — confirm this is intentional.
            benchmark.measure("project='solomon', cluster='storage_sas', service='stockpile', sensor='stockpile.write.records.elapsedTimeMs', projectId='total', bin='*', host='cluster'");
            benchmark.measure("project='solomon', cluster='storage_sas', service='stockpile', sensor='stockpile.write.records.elapsedTimeMs', projectId='total', bin='*', host='cluster'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.metabaseWrites.elapsedTimeMs', projectId='total', shardId='total', bin='*', host='cluster'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.metabaseReads.elapsedTimeMs', projectId='total', shardId='total', bin='*', host='cluster'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.process.totalTimeMillis', projectId='total', shardId='total', host='cluster', bin='*'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.process.waitTimeMillis', projectId='total', shardId='total', host='cluster', bin='*'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.process.writeTimeMillis', projectId='total', shardId='total', host='cluster', bin='*'");
            benchmark.measure("project='solomon', cluster='production', service='coremon', sensor='engine.process.parseTimeMillis', projectId='total', shardId='total', host='cluster', bin='*'");

            benchmark.printSummary();
        }
        // NOTE(review): explicit exit presumably because client threads would otherwise keep the JVM alive — confirm.
        System.exit(0);
    }

    /**
     * Creates a dynamic Stockpile client for the {@code solomon_pre_stockpile} conductor group.
     */
    private static StockpileClient stockpileClient() {
        var options = StockpileClientOptions.newBuilder(
                DefaultClientOptions.newBuilder()
                        .setRequestTimeOut(1, TimeUnit.MINUTES)
                        .setKeepAliveDelay(1, TimeUnit.MINUTES)
                        .setKeepAliveTimeout(1, TimeUnit.SECONDS))
                .setExpireClusterMetadata(30, TimeUnit.SECONDS)
                .setRetryStopStrategy(StopStrategies.stopAfterAttempt(20))
                .build();

        return StockpileClients.createDynamic(List.of("conductor_group://solomon_pre_stockpile:5700"), options);
    }

    /**
     * Creates a Metabase client for hosts {@code solomon-pre-fetcher-man-000..014} on port 5710.
     */
    private static MetabaseClient metabaseClient() {
        List<HostAndPort> addresses = IntStream.range(0, 15)
                .mapToObj(n -> String.format("solomon-pre-fetcher-man-%03d.search.yandex.net", n))
                .map(host -> HostAndPort.fromParts(host, /*SolomonPorts.COREMON_GRPC*/ 5710))
                .collect(Collectors.toList());

        var options = MetabaseClientOptions.newBuilder(
                DefaultClientOptions.newBuilder()
                        .setRequestTimeOut(1, TimeUnit.MINUTES)
                        .setKeepAliveDelay(1, TimeUnit.MINUTES)
                        .setKeepAliveTimeout(1, TimeUnit.SECONDS))
                .setExpireClusterMetadata(30, TimeUnit.MINUTES)
                .setMetaDataRequestTimeOut(5, TimeUnit.MINUTES)
                .build();

        return MetabaseClients.create(addresses, options);
    }

    /**
     * Parses {@code selectors} and measures the metrics they match; see {@link #measure(Selectors)}.
     */
    private void measure(String selectors) {
        measure(Selectors.parse(selectors));
    }

    /**
     * Downloads all metrics matching {@code selectors}, encodes them in both layouts and appends a
     * {@link MeasureResult} to {@link #results}. Blocks until done. Failures of the asynchronous
     * pipeline are logged rather than rethrown, and a selector that matches nothing is logged and skipped.
     */
    private void measure(Selectors selectors) {
        download(selectors)
                .thenApply(metrics -> {
                    if (metrics.isEmpty()) {
                        logger.warn("Metrics not found by selectors {}", selectors);
                        return null;
                    }

                    int totalCountPoints = metrics.stream()
                            .mapToInt(metric -> metric.getPoints().length())
                            .sum();
                    // Labels of the first (lowest-bound) metric; MeasureResult strips the 'bin' label.
                    Labels labels = metrics.get(0).getLabels();

                    int bytesSizeHistogram = compressIntoHistogram(metrics);
                    int bytesSizeDoubles = compressIntoSeparateDoubles(metrics);

                    return new MeasureResult(labels, bytesSizeHistogram, bytesSizeDoubles, totalCountPoints);
                })
                .handle((r, t) -> {
                    if (t != null) {
                        logger.error("Error occurs during process selector {}", selectors, t);
                    }
                    if (r != null) {
                        results.add(r);
                    }

                    return null;
                })
                .join();
    }

    /**
     * Prints all collected results ordered by doubles size (then histogram size), followed by a
     * {@code sum} row that totals every column.
     *
     * @throws ArithmeticException if a column total overflows {@code int}
     */
    private void printSummary() {
        // toCollection(ArrayList::new): the sum row is appended below, and Collectors.toList()
        // does not guarantee a mutable list.
        List<MeasureResult> sortedMeasure = results.stream()
                .sorted(Comparator.comparingInt(MeasureResult::getDoubleBytes)
                        .thenComparingInt(MeasureResult::getHistogramBytes))
                .collect(Collectors.toCollection(ArrayList::new));

        int totalPoints = 0;
        int totalHistogramBytes = 0;
        int totalDoublesBytes = 0;
        for (MeasureResult measure : sortedMeasure) {
            // addExact fails loudly instead of silently wrapping around.
            totalPoints = Math.addExact(totalPoints, measure.getCountPoints());
            totalHistogramBytes = Math.addExact(totalHistogramBytes, measure.getHistogramBytes());
            totalDoublesBytes = Math.addExact(totalDoublesBytes, measure.getDoubleBytes());
        }
        // Argument order must match MeasureResult(labels, histogramBytes, doubleBytes, countPoints).
        sortedMeasure.add(new MeasureResult(Labels.of("sensor", "sum"), totalHistogramBytes, totalDoublesBytes, totalPoints));

        LinkedHashMap<String, Function<MeasureResult, String>> columns = new LinkedHashMap<>();
        columns.put("Labels", measure -> measure.getLabels().toString());
        columns.put("Points", measure -> String.valueOf(measure.getCountPoints()));
        columns.put("Histogram", measure -> FileUtils.byteCountToDisplaySize(measure.getHistogramBytes()));
        columns.put("Doubles", measure -> FileUtils.byteCountToDisplaySize(measure.getDoubleBytes()));
        columns.put("Diff", measure -> FileUtils.byteCountToDisplaySize(measure.getDoubleBytes() - measure.getHistogramBytes()));
        columns.put("Histogram Bytes", measure -> String.valueOf(measure.getHistogramBytes()));
        columns.put("Doubles Bytes", measure -> String.valueOf(measure.getDoubleBytes()));
        columns.put("Diff Bytes", measure -> String.valueOf(measure.getDoubleBytes() - measure.getHistogramBytes()));

        // The header describes the comparator above, so it is printed here rather than in printTable.
        System.out.println("Sorted by doubles compressed size:");
        printTable(sortedMeasure, columns);
    }

    /**
     * Prints {@code source} as a fixed-width text table to stdout.
     *
     * <p>Each column is two characters wider than its longest value or header. The first column is
     * left-aligned and all others are right-aligned.
     *
     * @param source  rows to print
     * @param columns column header mapped to the function producing that column's cell text;
     *                iteration order defines column order and must be non-empty
     */
    private void printTable(List<MeasureResult> source, Map<String, Function<MeasureResult, String>> columns) {
        Map<String, Integer> columnSize = columns.entrySet()
                .stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                    int length = source.stream()
                            .map(e.getValue())
                            .mapToInt(String::length)
                            .max()
                            .orElse(0);
                    return Math.max(e.getKey().length(), length) + 2;
                }));

        // A negative printf width left-aligns the first column.
        String first = columns.keySet().iterator().next();
        columnSize.compute(first, (s, size) -> -size);
        for (String columnName : columns.keySet()) {
            int size = columnSize.get(columnName);
            System.out.printf("%" + size + "s", columnName);
        }
        System.out.println();

        for (MeasureResult measure : source) {
            for (Map.Entry<String, Function<MeasureResult, String>> entry : columns.entrySet()) {
                int size = columnSize.get(entry.getKey());
                String value = entry.getValue().apply(measure);
                System.out.printf("%" + size + "s", value);
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Merges the per-bin series into a single histogram series and returns its encoded size.
     *
     * <p>The timeline is the union of all bins' timestamps. At each timestamp a bucket is filled from
     * its bin's iterator when {@code iterator.hasNext(ts)} is true and otherwise stays {@code 0}.
     * Values are rounded with {@link Math#round(double)}.
     *
     * @param metrics per-bin metrics; their order defines the histogram bounds, so they are expected
     *                in ascending upper-bound order as returned by {@code download}
     * @return encoded size in bytes; currently always {@code 0} because the serialization below is disabled
     */
    private int compressIntoHistogram(List<Metric> metrics) {
        double[] bounds = metrics.stream()
                .mapToDouble(Metric::getUpperBound)
                .toArray();

        Timeline timeline = Timeline.union(metrics.stream()
                .map(metric -> new Timeline(metric.getPoints().getTimestamps(), SortedOrCheck.SORTED_UNIQUE))
                .collect(Collectors.toList()));

        List<GraphDataTimePointIterator> iterators = metrics.stream()
                .map(metric -> new GraphDataTimePointIterator(metric.getPoints().toGraphDataShort()))
                .collect(Collectors.toList());

        AggrPointData point = new AggrPointData();
        int mask = TsColumn.mask | HistogramColumn.mask;
        MetricArchiveMutable content = new MetricArchiveMutable();
        content.ensureCapacity(mask, timeline.length());
        for (int tIndex = 0; tIndex < timeline.length(); tIndex++) {
            long ts = timeline.getPointMillisAt(tIndex);
            long[] buckets = new long[bounds.length];
            for (int bIndex = 0; bIndex < iterators.size(); bIndex++) {
                GraphDataTimePointIterator iterator = iterators.get(bIndex);
                if (iterator.hasNext(ts)) {
                    buckets[bIndex] = Math.round(iterator.next());
                }
            }
            point.tsMillis = ts;
            point.histogram = Histogram.newInstance().copyFrom(bounds, buckets);
            content.addRecordData(mask, point);
        }

        // TODO: serialization disabled; restore it (or an equivalent size computation) to make the
        // benchmark report real numbers.
//        StockpileLogEntryContent log = new StockpileLogEntryContent();
//        log.addArchive(StockpileLocalId.random(), content);
//        return StockpileLogEntryContentSerializer.S.serializeToByteString(log).size();

        return 0;
    }

    /**
     * Returns the encoded size of storing each metric as its own double series.
     *
     * @param metrics per-bin metrics of one selector
     * @return encoded size in bytes; currently always {@code 0} because the serialization below is disabled
     */
    private int compressIntoSeparateDoubles(List<Metric> metrics) {
        // TODO: serialization disabled; restore it to make the benchmark report real numbers.
//        StockpileLogEntryContent log = new StockpileLogEntryContent();
//        for (Metric metric : metrics) {
//            log.addMetric(metric.localId, metric.getPoints());
//        }
//        return StockpileLogEntryContentSerializer.S.serializeToByteString(log).size();

        return 0;
    }

    /**
     * Finds the metrics matching {@code selectors} in Metabase and reads each one's points from Stockpile.
     *
     * @return future of the downloaded metrics sorted by ascending {@code bin} upper bound; completes
     *         exceptionally with {@link IllegalStateException} on a non-OK Metabase or Stockpile status
     */
    private CompletableFuture<List<Metric>> download(Selectors selectors) {
        return metabase.find(FindRequest.newBuilder()
                .addAllSelectors(LabelSelectorConverter.selectorsToProto(selectors))
                .build())
                .thenCompose(response -> {
                    if (response.getStatus() != EMetabaseStatusCode.OK) {
                        throw new IllegalStateException(response.getStatus() + ": " + response.getStatusMessage());
                    }

                    // One Stockpile read per metric; all are issued before any is awaited.
                    return response.getMetricsList()
                            .stream()
                            .map(metric -> stockpile.readCompressedOne(TReadRequest.newBuilder()
                                    .setMetricId(metric.getMetricId())
                                    .setBinaryVersion(StockpileFormat.CURRENT.getFormat())
                                    .build())
                                    .thenApply(this::decode)
                                    .thenApply(points -> new Metric(metric, points)))
                            .collect(Collectors.collectingAndThen(Collectors.toList(), CompletableFutures::allOf));
                })
                .thenApply(list -> list.stream()
                        .sorted(Comparator.comparingDouble(Metric::getUpperBound))
                        .collect(Collectors.toList()));
    }

    /**
     * Decodes every chunk of a compressed read response, concatenates them and calls
     * {@code sortAndMerge()} on the result.
     *
     * @throws IllegalStateException if the response status is not {@code OK}
     */
    private AggrGraphDataArrayList decode(TCompressedReadResponse response) {
        if (response.getStatus() != EStockpileStatusCode.OK) {
            throw new IllegalStateException(response.getStatus() + ": " + response.getStatusMessage());
        }

        StockpileFormat format = StockpileFormat.byNumber(response.getBinaryVersion());
        // NOTE(review): the reduce identity is mutated via addAll (and later sortAndMerge); this is only
        // safe if AggrGraphDataArrayList.empty() returns a fresh mutable instance — confirm.
        AggrGraphDataArrayList list = response.getChunksList()
                .stream()
                .map(chunk -> decodeChunk(chunk, format))
                .reduce(AggrGraphDataArrayList.empty(), (left, right) -> {
                    left.addAll(right);
                    return left;
                });

        list.sortAndMerge();
        return list;
    }

    /**
     * Deserializes a single chunk with the serializer for {@code format}.
     *
     * @return the chunk's non-NaN points, or {@code AggrGraphDataArrayList.empty()} for an empty chunk/archive
     */
    private AggrGraphDataArrayList decodeChunk(TimeSeries.Chunk chunk, StockpileFormat format) {
        if (chunk.getPointCount() == 0) {
            return AggrGraphDataArrayList.empty();
        }

        StockpileDeserializer deserializer = new StockpileDeserializer(chunk.getContent());
        MetricArchiveImmutable archive = MetricArchiveNakedSerializer.serializerForFormatSealed(format)
                .deserializeToEof(deserializer);

        if (archive.isEmpty()) {
            return AggrGraphDataArrayList.empty();
        }

        return roundAndDropNans(archive.iterator());
    }

    /**
     * Copies timestamp and value of each point into a new TS+VALUE list, skipping points whose
     * divided value is NaN. Each kept value is re-stored with {@link ValueColumn#DEFAULT_DENOM}.
     */
    private AggrGraphDataArrayList roundAndDropNans(AggrGraphDataListIterator it) {
        int mask = StockpileColumn.TS.mask() | StockpileColumn.VALUE.mask();
        AggrGraphDataArrayList result = new AggrGraphDataArrayList(mask, it.estimatePointsCount());
        AggrPoint point = new AggrPoint();
        while (it.next(point)) {
            double value = point.getValueDivided();
            if (Double.isNaN(value)) {
                continue;
            }

            point.setValue(value, ValueColumn.DEFAULT_DENOM);
            result.addRecordData(mask, point);
        }

        return result;
    }

    @Override
    public void close() {
        metabase.close();
        stockpile.close();
    }

    /**
     * One downloaded histogram bin: Metabase identity and labels, the bin's upper bound and its points.
     */
    private static class Metric {
        // Captures the first run of decimal digits found in a 'bin' label value.
        private static final Pattern BIN_PATTERN = Pattern.compile("(\\d+).*");

        private final int shardId;
        private final long localId;
        private final Labels labels;
        private final double upperBound;
        private final AggrGraphDataArrayList points;

        Metric(ru.yandex.solomon.metabase.api.protobuf.Metric metric, AggrGraphDataArrayList points) {
            this.shardId = metric.getMetricId().getShardId();
            this.localId = metric.getMetricId().getLocalId();
            this.labels = LabelConverter.protoToLabels(metric.getLabelsList());
            this.upperBound = parseBin();
            this.points = points;
        }

        /**
         * Parses the upper bound from the {@code bin} label.
         *
         * <p>{@code "inf"} maps to {@link Histograms#INF_BOUND}. Otherwise the first run of digits in the
         * value is used. NOTE(review): fractional bounds such as {@code "0.5"} would parse as {@code 0}.
         *
         * @throws IllegalStateException if the label is missing or contains no digits
         */
        private double parseBin() {
            Label label = this.labels.findByKey("bin");
            if (label == null) {
                throw new IllegalStateException("Not exists bin label into: " + labels);
            }

            String labelValue = label.getValue();

            if ("inf".equals(labelValue)) {
                return Histograms.INF_BOUND;
            }

            Matcher matcher = BIN_PATTERN.matcher(labelValue);
            if (!matcher.find()) {
                throw new IllegalStateException("invalid label value: " + labelValue);
            }

            String bin = matcher.group(1);
            return Long.parseLong(bin);
        }

        Labels getLabels() {
            return labels;
        }

        double getUpperBound() {
            return upperBound;
        }

        AggrGraphDataArrayList getPoints() {
            return points;
        }
    }

    /**
     * One summary row: labels (without {@code bin}), encoded sizes of both layouts and total point count.
     */
    private static class MeasureResult {
        private final Labels labels;
        private final int histogramBytes;
        private final int doubleBytes;
        private final int countPoints;

        MeasureResult(Labels labels, int histogramBytes, int doubleBytes, int countPoints) {
            this.labels = labels.removeByKey("bin");
            this.histogramBytes = histogramBytes;
            this.doubleBytes = doubleBytes;
            this.countPoints = countPoints;
        }

        Labels getLabels() {
            return labels;
        }

        int getHistogramBytes() {
            return histogramBytes;
        }

        int getDoubleBytes() {
            return doubleBytes;
        }

        int getCountPoints() {
            return countPoints;
        }
    }
}
