package ru.yandex.solomon.metrics.parser.prometheus;

import java.io.IOException;

import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.codec.compression.Snappy;

import ru.yandex.monitoring.prometheus.Label;
import ru.yandex.monitoring.prometheus.MetricMetadata;
import ru.yandex.monitoring.prometheus.Sample;
import ru.yandex.monitoring.prometheus.TimeSeries;
import ru.yandex.monitoring.prometheus.WriteRequest;
import ru.yandex.monlib.metrics.MetricType;
import ru.yandex.monlib.metrics.encode.ParseException;
import ru.yandex.monlib.metrics.labels.LabelAllocator;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.labels.LabelsBuilder;
import ru.yandex.solomon.metrics.parser.MetricConsumer;
import ru.yandex.solomon.metrics.parser.TreeParser;

/**
 * Parses Prometheus remote-write payloads.
 *
 * <p>A payload is a snappy-compressed (decoded with Netty's {@link Snappy}) protobuf
 * {@link WriteRequest}. Its time series are reported to a {@link MetricConsumer}.
 *
 * @author Sergey Polovko
 */
public class RemoteWriteParser implements TreeParser {

    /** Label in which Prometheus transmits the metric name. */
    private static final String NAME_LABEL = "__name__";

    private final LabelAllocator labelAllocator;
    private final String metricNameLabel;

    /**
     * @param labelAllocator  allocator handed to the {@link LabelsBuilder} that builds the labels
     *                        of each time series
     * @param metricNameLabel label name under which the value of {@code __name__} is stored; when
     *                        empty or equal to {@code __name__}, labels are copied without renaming
     */
    public RemoteWriteParser(LabelAllocator labelAllocator, String metricNameLabel) {
        this.labelAllocator = labelAllocator;
        this.metricNameLabel = metricNameLabel;
    }

    /**
     * Uncompresses and decodes {@code compressed}, then reports every time series of the
     * resulting request to {@code metricConsumer}.
     *
     * <p>{@code compressed} is read but never released here. {@code commonLabels},
     * {@code errorListener}, {@code formatListener} and {@code onlyNewFormatWrites} are not used
     * by this implementation.
     *
     * @throws ParseException if the payload cannot be uncompressed, is not a valid
     *                        {@link WriteRequest}, or converting its time series fails
     */
    @Override
    public void parseAndProcess(
        Labels commonLabels,
        ByteBuf compressed,
        MetricConsumer metricConsumer,
        ErrorListener errorListener,
        FormatListener formatListener,
        boolean onlyNewFormatWrites)
    {
        WriteRequest writeRequest = parseWriteRequest(uncompress(compressed));

        try {
            convert(writeRequest, metricConsumer);
        } catch (ParseException e) {
            // already describes the problem precisely; do not bury it under a generic message
            throw e;
        } catch (Throwable t) {
            throw new ParseException("unable to parse prometheus write request", t);
        }
    }

    /**
     * Decodes snappy-compressed {@code compressed} into a newly allocated pooled heap buffer.
     * The returned buffer is owned by the caller; this method releases it only on failure.
     */
    private static ByteBuf uncompress(ByteBuf compressed) {
        ByteBuf uncompressed = PooledByteBufAllocator.DEFAULT.heapBuffer();
        try {
            new Snappy().decode(compressed, uncompressed);
            return uncompressed;
        } catch (Throwable t) {
            uncompressed.release();
            throw new ParseException("cannot uncompress prometheus data", t);
        }
    }

    /**
     * Decodes a {@link WriteRequest} from {@code uncompressed}. The buffer is released whether
     * or not decoding succeeds, so the temporary data is not kept alive during conversion.
     */
    private static WriteRequest parseWriteRequest(ByteBuf uncompressed) {
        // releaseOnClose = true: closing the stream frees the pooled buffer
        try (var input = new ByteBufInputStream(uncompressed, true)) {
            return WriteRequest.parseFrom(input);
        } catch (InvalidProtocolBufferException e) {
            throw new ParseException("corrupted prometheus data", e);
        } catch (IOException e) {
            throw new ParseException("cannot read prometheus data", e);
        }
    }

    /** Reports every time series of {@code writeRequest}, in order, to {@code metricConsumer}. */
    private void convert(WriteRequest writeRequest, MetricConsumer metricConsumer) {
        final int timeSeriesCount = writeRequest.getTimeseriesCount();

        // a single builder is reused for all series; convertLabels() clears it each time
        LabelsBuilder labelsBuilder = new LabelsBuilder(Labels.MAX_LABELS_COUNT, labelAllocator);
        metricConsumer.ensureCapacity(timeSeriesCount);

        for (int i = 0; i < timeSeriesCount; i++) {
            TimeSeries timeSeries = writeRequest.getTimeseries(i);

            MetricType type = metricType(writeRequest, i);
            Labels labels = convertLabels(labelsBuilder, timeSeries);

            metricConsumer.onMetricBegin(type, labels, false);
            if (timeSeries.getSamplesCount() == 1) {
                consumeOneSample(metricConsumer, timeSeries);
            } else {
                consumeManySamples(metricConsumer, timeSeries);
            }
        }
    }

    /**
     * Returns the metric type for the time series at {@code index}. The type comes from the
     * metadata entry at the same index when one exists; otherwise it is {@link MetricType#DGAUGE}.
     */
    private static MetricType metricType(WriteRequest writeRequest, int index) {
        // NOTE(review): metadata is paired with time series purely by position. It is not clear
        // from here that senders keep both lists aligned; confirm before relying on it.
        if (index < writeRequest.getMetadataCount()) {
            MetricMetadata metadata = writeRequest.getMetadata(index);
            return Types.fromProto(metadata.getType());
        }
        return MetricType.DGAUGE;
    }

    /**
     * Builds the labels of {@code timeSeries} in its original label order, using the reusable
     * {@code labelsBuilder}.
     *
     * <p>When a metric name label is configured and differs from {@code __name__}, the first
     * {@code __name__} label is renamed to it. Every other label is copied unchanged.
     */
    private Labels convertLabels(LabelsBuilder labelsBuilder, TimeSeries timeSeries) {
        labelsBuilder.clear();

        final int labelsCount = timeSeries.getLabelsCount();
        int idx = 0;

        // only if shard has metric name label, and it is not '__name__' then
        // try to replace '__name__' with configured one
        if (!metricNameLabel.isEmpty() && !NAME_LABEL.equals(metricNameLabel)) {
            while (idx < labelsCount) {
                Label label = timeSeries.getLabels(idx++);
                if (NAME_LABEL.equals(label.getName())) {
                    labelsBuilder.add(metricNameLabel, label.getValue());
                    // expect only one '__name__' label per time series
                    break;
                }
                labelsBuilder.add(label.getName(), label.getValue());
            }
        }

        // remaining labels: plain copy, without the per-label name comparison
        while (idx < labelsCount) {
            Label label = timeSeries.getLabels(idx++);
            labelsBuilder.add(label.getName(), label.getValue());
        }

        return labelsBuilder.build();
    }

    /** Reports the only sample of {@code timeSeries} as a single point. */
    private void consumeOneSample(MetricConsumer metricConsumer, TimeSeries timeSeries) {
        Sample sample = timeSeries.getSamples(0);
        metricConsumer.onPoint(sample.getTimestamp(), sample.getValue());
    }

    /**
     * Reports all samples of {@code timeSeries} (possibly none) as one double time series, in
     * their original order.
     */
    private void consumeManySamples(MetricConsumer metricConsumer, TimeSeries timeSeries) {
        var ts = ru.yandex.monlib.metrics.series.TimeSeries.newDouble(timeSeries.getSamplesCount());
        for (Sample sample : timeSeries.getSamplesList()) {
            // addDouble returns the series to continue with
            ts = ts.addDouble(sample.getTimestamp(), sample.getValue());
        }
        metricConsumer.onTimeSeries(ts);
    }
}
