package ru.yandex.travel.actuator;

import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.IntStream;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.FunctionCounter;
import io.micrometer.core.instrument.FunctionTimer;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.LongTaskTimer;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.TimeGauge;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.distribution.CountAtBucket;
import io.micrometer.core.instrument.distribution.HistogramSnapshot;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.monlib.metrics.MetricType;
import ru.yandex.monlib.metrics.encode.MetricEncoder;
import ru.yandex.monlib.metrics.encode.json.MetricJsonEncoder;
import ru.yandex.monlib.metrics.encode.spack.MetricSpackEncoder;
import ru.yandex.monlib.metrics.encode.spack.format.CompressionAlg;
import ru.yandex.monlib.metrics.encode.spack.format.TimePrecision;
import ru.yandex.monlib.metrics.histogram.ExplicitHistogramSnapshot;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.validate.StrictValidator;
import ru.yandex.travel.commons.metrics.TravelTag;

/**
 * Converts the meters of a Micrometer {@link MeterRegistry} into Solomon metrics and encodes them
 * with a monlib {@link MetricEncoder}.
 *
 * <p>Counters become {@code RATE} metrics (with a {@code .rate} name suffix), gauges become
 * {@code DGAUGE} and timer/summary histograms become {@code HIST}. Meters carrying transparent
 * {@link TravelTag}s are additionally emitted for every combination with those tags dropped, and
 * values landing on the same name+tags are merged.
 */
@Slf4j
public class MicrometerToSolomonEncoder {
    // Patterns used to repair label keys/values that StrictValidator rejects.
    private static final Pattern KEY_FIRST_CHAR_PATTERN = Pattern.compile("^[^a-zA-Z]");
    private static final Pattern KEY_ALL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9]");
    private static final Pattern VALUE_FIRST_CHAR_PATTERN = Pattern.compile("^[^a-zA-Z0-9./@_]");
    private static final Pattern VALUE_ALL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9./@_,:;()\\[\\]<>\\- ]");

    /** Length an invalid label key is cropped to before character replacement. */
    private static final int MAX_ESCAPED_KEY_LENGTH = 30;
    /** Length an invalid label value is cropped to before character replacement. */
    private static final int MAX_ESCAPED_VALUE_LENGTH = 200;
    /** Above this many transparent tags an error is logged: combinations grow as 2^n. */
    private static final int TRANSPARENT_TAGS_ERROR_THRESHOLD = 10;
    /**
     * Divisor applied to timer histogram bucket bounds to express them in milliseconds
     * (assumes Micrometer reports timer buckets in nanoseconds).
     */
    private static final double TIMER_BUCKET_DIVISOR_TO_MS = 1_000_000.0;

    /**
     * Encodes all meters of {@code meterRegistry} in the spack format
     * (second timestamp precision, LZ4 compression).
     *
     * @param meterRegistry registry whose meters are exported
     * @return encoded payload
     */
    public static byte[] encodeSpack(MeterRegistry meterRegistry) {
        var stream = new ByteArrayOutputStream();
        var encoder = new MetricSpackEncoder(TimePrecision.SECONDS, CompressionAlg.LZ4, stream);
        try {
            encode(meterRegistry, encoder);
        } finally {
            // Close the encoder on the failure path as well, not only after a successful encode.
            encoder.close();
        }
        return stream.toByteArray();
    }

    /**
     * Encodes all meters of {@code meterRegistry} as JSON.
     *
     * @param meterRegistry registry whose meters are exported
     * @return encoded payload
     */
    public static byte[] encodeJson(MeterRegistry meterRegistry) {
        var stream = new ByteArrayOutputStream();
        var encoder = new MetricJsonEncoder(stream);
        try {
            encode(meterRegistry, encoder);
        } finally {
            // Close the encoder on the failure path as well, not only after a successful encode.
            encoder.close();
        }
        return stream.toByteArray();
    }

    /** Aggregation key: the Solomon sensor name together with the exact tag list it is emitted with. */
    @Data
    @EqualsAndHashCode
    static class Metric {
        private final String name;
        private final List<Tag> tags;
    }

    /** Returns whether {@code tag} is a {@link TravelTag} flagged as transparent. */
    private static boolean isTransparent(Tag tag) {
        return (tag instanceof TravelTag) && ((TravelTag) tag).isTransparent();
    }

    /**
     * Recursively enumerates tag lists derived from {@code tags[position..]}: every transparent tag
     * is independently dropped or kept, every other tag is always kept. Each complete combination is
     * passed to {@code action} as an immutable copy.
     *
     * @param tags        full tag list of the meter
     * @param currentTags tags chosen so far (used as a stack, restored before returning)
     * @param position    index of the next tag to decide on
     * @param action      receives each finished combination
     */
    private static void doForAllTagCombinationsInner(List<Tag> tags, ArrayList<Tag> currentTags, int position, Consumer<List<Tag>> action) {
        if (position >= tags.size()) {
            action.accept(List.copyOf(currentTags));
            return;
        }

        var tag = tags.get(position);
        if (isTransparent(tag)) {
            // Variant without this transparent tag.
            doForAllTagCombinationsInner(tags, currentTags, position + 1, action);
        }
        // Variant with the tag (the only variant for non-transparent tags).
        currentTags.add(tag);
        doForAllTagCombinationsInner(tags, currentTags, position + 1, action);
        currentTags.remove(currentTags.size() - 1);
    }

    /**
     * Invokes {@code action} for every tag combination of {@code tags} (see
     * {@link #doForAllTagCombinationsInner}). Logs an error, but still proceeds, when the number of
     * transparent tags exceeds {@link #TRANSPARENT_TAGS_ERROR_THRESHOLD}.
     *
     * @param tags   meter tags
     * @param name   metric name, used only for logging
     * @param action receives each combination
     */
    private static void doForAllTagCombinations(List<Tag> tags, String name, Consumer<List<Tag>> action) {
        long transparentTags = tags.stream().filter(MicrometerToSolomonEncoder::isTransparent).count();
        if (transparentTags > TRANSPARENT_TAGS_ERROR_THRESHOLD) {
            log.error("Too many transparent tags for metric {}", name);
        }

        doForAllTagCombinationsInner(tags, new ArrayList<>(), 0, action);
    }

    /**
     * Merges two gauge values aggregated under the same key: identical values are kept, differing
     * (or already-unmergeable) values yield {@link Optional#empty()}.
     */
    private static Optional<Double> mergeGaugeValues(Optional<Double> lhs, Optional<Double> rhs) {
        if (lhs.isPresent() && rhs.isPresent() && lhs.get().equals(rhs.get())) {
            return lhs;
        }
        return Optional.empty();
    }

    /**
     * Merges two histograms aggregated under the same key by summing bucket counts. Histograms with
     * different bounds cannot be merged and yield {@link Optional#empty()}.
     */
    private static Optional<ExplicitHistogramSnapshot> mergeHistograms(Optional<ExplicitHistogramSnapshot> lhs,
                                                                       Optional<ExplicitHistogramSnapshot> rhs) {
        if (!lhs.isPresent() || !rhs.isPresent() || !lhs.get().boundsEquals(rhs.get())) {
            return Optional.empty();
        }
        var left = lhs.get();
        var right = rhs.get();
        int bucketCount = left.count();
        var bounds = IntStream.range(0, bucketCount).mapToDouble(left::upperBound).toArray();
        var buckets = IntStream.range(0, bucketCount).mapToLong(i -> left.value(i) + right.value(i)).toArray();
        return Optional.of(new ExplicitHistogramSnapshot(bounds, buckets));
    }

    /**
     * Collects all meters of {@code meterRegistry}, aggregates them per name+tag combination and
     * writes the result to {@code metricEncoder} as a single stream. Values that could not be merged
     * are skipped with an error log.
     */
    private static void encode(MeterRegistry meterRegistry, MetricEncoder metricEncoder) {
        var tsMillis = 0; // timestamp is not specified and will be generated by solomon

        var longs = new HashMap<Metric, Long>();
        var doubles = new HashMap<Metric, Optional<Double>>();
        var histograms = new HashMap<Metric, Optional<ExplicitHistogramSnapshot>>();

        meterRegistry.forEachMeter(m -> {
            String name = m.getId().getName();
            if (name.endsWith(".histogram") || name.endsWith(".percentile")) {
                // See io.micrometer.core.instrument.distribution.HistogramGauges.registerWithCommonFormat
                return;
            }
            List<Tag> tags = m.getId().getTags();

            BiConsumer<String, Long> onCounter = (metricName, value) ->
                    doForAllTagCombinations(tags, metricName, currTags ->
                            longs.merge(new Metric(metricName + ".rate", currTags), value, Long::sum));
            BiConsumer<String, Double> onGauge = (metricName, value) ->
                    doForAllTagCombinations(tags, metricName, currTags ->
                            doubles.merge(new Metric(metricName, currTags), Optional.of(value),
                                    MicrometerToSolomonEncoder::mergeGaugeValues));
            BiConsumer<String, ExplicitHistogramSnapshot> onHistogram = (metricName, value) ->
                    doForAllTagCombinations(tags, metricName, currTags ->
                            histograms.merge(new Metric(metricName, currTags), Optional.of(value),
                                    MicrometerToSolomonEncoder::mergeHistograms));

            // TimeGauge must be tested before Gauge because it is a Gauge subtype.
            if (m instanceof TimeGauge) {
                TimeGauge gauge = (TimeGauge) m;
                onGauge.accept(name + "Ms", gauge.value(TimeUnit.MILLISECONDS));
            } else if (m instanceof Gauge) {
                Gauge gauge = (Gauge) m;
                onGauge.accept(name, gauge.value());
            } else if (m instanceof Counter) {
                onCounter.accept(name, (long) ((Counter) m).count());
            } else if (m instanceof FunctionCounter) {
                onCounter.accept(name, (long) ((FunctionCounter) m).count());
            } else if (m instanceof Timer) {
                Timer timer = (Timer) m;
                HistogramSnapshot snapshot = timer.takeSnapshot();
                onCounter.accept(name + "Ms.count", timer.count());
                onCounter.accept(name + "Ms.total", (long) timer.totalTime(TimeUnit.MILLISECONDS));
                onHistogram.accept(name + "Ms.histogram", convertHistogramSnapshot(snapshot, TIMER_BUCKET_DIVISOR_TO_MS));
            } else if (m instanceof FunctionTimer) {
                FunctionTimer functionTimer = (FunctionTimer) m;
                onCounter.accept(name + "Ms.count", (long) functionTimer.count());
                onCounter.accept(name + "Ms.total", (long) functionTimer.totalTime(TimeUnit.MILLISECONDS));
            } else if (m instanceof DistributionSummary) {
                DistributionSummary distributionSummary = (DistributionSummary) m;
                HistogramSnapshot snapshot = distributionSummary.takeSnapshot();
                onCounter.accept(name + ".count", distributionSummary.count());
                onCounter.accept(name + ".total", (long) distributionSummary.totalAmount());
                onHistogram.accept(name + ".histogram", convertHistogramSnapshot(snapshot, 1.0));
            } else if (m instanceof LongTaskTimer) {
                // NOTE(review): activeTasks/duration are point-in-time readings but go through the
                // counter path (RATE type, ".rate" suffix) — confirm this is intended.
                LongTaskTimer longTaskTimer = (LongTaskTimer) m;
                onCounter.accept(name + ".activeTasks", (long) longTaskTimer.activeTasks());
                onCounter.accept(name + ".durationMs", (long) longTaskTimer.duration(TimeUnit.MILLISECONDS));
            }
        });

        metricEncoder.onStreamBegin(-1);
        metricEncoder.onCommonTime(tsMillis);

        for (var entry : longs.entrySet()) {
            var metric = entry.getKey();
            encodeLong(metricEncoder, MetricType.RATE, metric.getName(), metric.getTags(), tsMillis, entry.getValue());
        }
        for (var entry : doubles.entrySet()) {
            var metric = entry.getKey();
            if (entry.getValue().isPresent()) {
                encodeDouble(metricEncoder, MetricType.DGAUGE, metric.getName(), metric.getTags(), tsMillis, entry.getValue().get());
            } else {
                log.error("Skipping metric {} with tags {} because of unmergable values", metric.getName(), metric.getTags());
            }
        }
        for (var entry : histograms.entrySet()) {
            var metric = entry.getKey();
            if (entry.getValue().isPresent()) {
                encodeHistogram(metricEncoder, MetricType.HIST, metric.getName(), metric.getTags(), tsMillis, entry.getValue().get());
            } else {
                log.error("Skipping metric {} with tags {} because of unmergable values", metric.getName(), metric.getTags());
            }
        }

        metricEncoder.onStreamEnd();
    }

    /**
     * Converts a Micrometer histogram snapshot into a Solomon explicit histogram.
     *
     * <p>Micrometer bucket counts are treated as cumulative and turned into per-bucket counts. If the
     * last bucket is not +Inf, an extra +Inf bucket with count 0 is appended.
     *
     * @param histogramSnapshot Micrometer snapshot
     * @param factor            divisor applied to every finite bucket bound
     * @return Solomon histogram with the converted bounds and per-bucket counts
     */
    private static ExplicitHistogramSnapshot convertHistogramSnapshot(HistogramSnapshot histogramSnapshot,
                                                                      double factor) {
        CountAtBucket[] sourceBuckets = histogramSnapshot.histogramCounts();
        int nBuckets = sourceBuckets.length;
        boolean addFakeInf = nBuckets == 0 || !isPositiveInf(sourceBuckets[nBuckets - 1]);
        int size = addFakeInf ? nBuckets + 1 : nBuckets;
        double[] bounds = new double[size];
        long[] counts = new long[size];
        long previousCount = 0;
        for (int i = 0; i < nBuckets; i++) {
            CountAtBucket bucket = sourceBuckets[i];
            bounds[i] = isPositiveInf(bucket) ? Histograms.INF_BOUND : bucket.bucket() / factor;
            // seems like it's long deep inside micrometer, so trying to recover it...
            long cumulativeCount = Math.round(bucket.count());
            counts[i] = cumulativeCount - previousCount;
            previousCount = cumulativeCount;
        }
        if (addFakeInf) {
            bounds[nBuckets] = Histograms.INF_BOUND;
            counts[nBuckets] = 0;
        }
        return new ExplicitHistogramSnapshot(bounds, counts);
    }

    // copy-pasted from io.micrometer.core.instrument.distribution.CountAtBucket#isPositiveInf
    private static boolean isPositiveInf(CountAtBucket counter) {
        // check for Long.MAX_VALUE to maintain backwards compatibility
        return counter.bucket() == Double.POSITIVE_INFINITY ||
                counter.bucket() == Double.MAX_VALUE ||
                (long) counter.bucket() == Long.MAX_VALUE;
    }

    /** Emits one metric with a long value. */
    private static void encodeLong(MetricEncoder metricEncoder, MetricType metricType, String name, List<Tag> tags,
                                   long tsMillis, long value) {
        encodeValue(metricEncoder, metricType, name, tags, () -> metricEncoder.onLong(tsMillis, value));
    }

    /** Emits one metric with a double value. */
    private static void encodeDouble(MetricEncoder metricEncoder, MetricType metricType, String name, List<Tag> tags,
                                     long tsMillis, double value) {
        encodeValue(metricEncoder, metricType, name, tags, () -> metricEncoder.onDouble(tsMillis, value));
    }

    /** Emits one metric with a histogram value. */
    private static void encodeHistogram(MetricEncoder metricEncoder, MetricType metricType, String name,
                                        List<Tag> tags, long tsMillis,
                                        ru.yandex.monlib.metrics.histogram.HistogramSnapshot value) {
        encodeValue(metricEncoder, metricType, name, tags, () -> metricEncoder.onHistogram(tsMillis, value));
    }

    /**
     * Writes a complete metric record: the (escaped) tags plus a {@code sensor} label carrying
     * {@code name}, followed by whatever {@code onValue} emits.
     */
    private static void encodeValue(MetricEncoder metricEncoder, MetricType metricType, String name,
                                    List<Tag> tags, Runnable onValue) {
        metricEncoder.onMetricBegin(metricType);

        metricEncoder.onLabelsBegin(tags.size() + 1);
        tags.forEach(tag -> metricEncoder.onLabel(
                escapeKey(name, tags, tag.getKey()),
                escapeValue(name, tags, tag.getKey(), tag.getValue())
        ));
        metricEncoder.onLabel("sensor", name);
        metricEncoder.onLabelsEnd();

        onValue.run();

        metricEncoder.onMetricEnd();
    }

    /**
     * Returns {@code key} unchanged if StrictValidator accepts it; otherwise crops it and replaces
     * disallowed characters with {@code 'Z'}, logging the problem.
     */
    private static String escapeKey(String name, List<Tag> tags, String key) {
        var error = StrictValidator.validateKey(key);
        if (error == null) {
            return key;
        }
        log.warn("Label key for metric: {} is invalid: '{}' (all tags: {})", name, error, tags);
        var escaped = crop(key, MAX_ESCAPED_KEY_LENGTH);
        escaped = KEY_FIRST_CHAR_PATTERN.matcher(escaped).replaceAll("Z");
        escaped = KEY_ALL_CHARS_PATTERN.matcher(escaped).replaceAll("Z");

        if (!key.equals(escaped)) {
            log.warn("Label key for metric: {} changed from '{}' to '{}' (all tags: {})", name, key, escaped, tags);
        }
        return escaped;
    }

    /**
     * Returns {@code value} unchanged if StrictValidator accepts it; otherwise crops it and replaces
     * disallowed characters with {@code '_'}. Warnings are suppressed for metrics whose name starts
     * with {@code jvm}.
     */
    private static String escapeValue(String name, List<Tag> tags, String key, String value) {
        var error = StrictValidator.validateValue(value);
        if (error == null) {
            return value;
        }
        boolean logWarns = !name.startsWith("jvm");
        if (logWarns) {
            log.warn("Label value for metric: {} is invalid: '{}' (all tags: {})", name, error, tags);
        }
        var escaped = crop(value, MAX_ESCAPED_VALUE_LENGTH);
        escaped = VALUE_FIRST_CHAR_PATTERN.matcher(escaped).replaceAll("_");
        escaped = VALUE_ALL_CHARS_PATTERN.matcher(escaped).replaceAll("_");

        if (logWarns && !value.equals(escaped)) {
            log.warn("Label '{}' for metric: {} changed from '{}' to '{}' (all tags: {})", key,
                    name, value, escaped, tags);
        }
        return escaped;
    }

    /** Truncates {@code value} to at most {@code len} characters. */
    private static String crop(String value, int len) {
        if (value.length() > len) {
            return value.substring(0, len);
        }
        return value;
    }
}
