package ru.yandex.grpc.utils.server;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import io.grpc.Status;

import ru.yandex.monlib.metrics.MetricConsumer;
import ru.yandex.monlib.metrics.MetricSupplier;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.monlib.metrics.primitives.LazyGaugeInt64;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;

/**
 * gRPC server metrics of a single endpoint.
 *
 * <p>All metrics are registered in a private {@link MetricRegistry} created with the common
 * label {@code endpoint=<name>} and are exported through {@link MetricSupplier}. Tracked metrics:
 * inbound/outbound message and byte rates, call started/completed rates, an in-flight gauge
 * computed from those two rates, a completion rate per {@link Status.Code}, and histograms of
 * call elapsed time and first inbound message delivery time.
 *
 * @author Vladimir Gordiychuk
 */
public class EndpointMetrics implements MetricSupplier {
    private final MetricRegistry registry;
    private final Rate inboundMessage;
    private final Rate outboundMessage;
    private final Rate inboundBytes;
    private final Rate outboundBytes;
    private final Rate callStarted;
    private final Rate callCompleted;
    /** Computed on read from {@link #callStarted} and {@link #callCompleted}; never updated directly. */
    private final LazyGaugeInt64 callInFlight;
    private final Histogram responseTime;
    private final Histogram inboundDeliveryTime;

    /** Completion rate per status code, registered lazily on first use of each code. */
    private final ConcurrentMap<Status.Code, Rate> statusCounts = new ConcurrentHashMap<>();

    /**
     * Creates and registers all metrics of an endpoint.
     *
     * @param name value of the {@code endpoint} label attached to every metric of this instance
     */
    public EndpointMetrics(String name) {
        registry = new MetricRegistry(Labels.of("endpoint", name));
        inboundMessage = registry.rate("grpc.server.call.inBoundMessages");
        outboundMessage = registry.rate("grpc.server.call.outBoundMessages");
        inboundBytes = registry.rate("grpc.server.call.inBoundBytes");
        outboundBytes = registry.rate("grpc.server.call.outBoundBytes");
        callStarted = registry.rate("grpc.server.call.started");
        callCompleted = registry.rate("grpc.server.call.completed");
        callInFlight = registry.lazyGaugeInt64("grpc.server.call.inFlight", () -> {
            // Read "completed" BEFORE "started": assuming both counters only grow and every
            // completion is counted after its start, a start read later is always >= the
            // completion count just read, so the gauge cannot go negative under concurrent
            // updates (reading in the opposite order could observe completed > started).
            long completed = callCompleted.get();
            return callStarted.get() - completed;
        });
        responseTime = registry.histogramRate("grpc.server.call.elapsedTimeMs",
                Histograms.exponential(20, 2, 1));
        inboundDeliveryTime = registry.histogramRate("grpc.server.call.delivery.elapsedTimeMs",
                Histograms.exponential(16, 2, 1));
    }

    /** Counts the start of a call. */
    public void callStarted() {
        callStarted.inc();
    }

    /**
     * Returns the completion rate for {@code status}, registering it with the label
     * {@code code=<status name>} on first use.
     */
    private Rate getStatusCount(Status.Code status) {
        return statusCounts.computeIfAbsent(status,
                code -> registry.rate("grpc.server.call.status", Labels.of("code", code.name())));
    }

    /**
     * Counts the completion of a call, its final status code, and records its duration.
     *
     * @param status      final status of the call; only its code is recorded
     * @param elapsedTime call duration recorded into the {@code grpc.server.call.elapsedTimeMs}
     *                    histogram (presumably in milliseconds, per the metric name)
     */
    public void callCompleted(Status status, long elapsedTime) {
        callCompleted.inc();
        getStatusCount(status.getCode()).inc();
        responseTime.record(elapsedTime);
    }

    /** Counts one outbound message. */
    public void addOutboundMessage() {
        outboundMessage.inc();
    }

    /** Counts one inbound message. */
    public void addInboundMessage() {
        inboundMessage.inc();
    }

    /**
     * Adds to the outbound byte rate.
     *
     * @param bytes number of bytes sent
     */
    public void addOutboundBytes(long bytes) {
        outboundBytes.add(bytes);
    }

    /**
     * Adds to the inbound byte rate.
     *
     * @param bytes number of bytes received
     */
    public void addInboundBytes(long bytes) {
        inboundBytes.add(bytes);
    }

    /**
     * Records the delivery time of the first inbound message of a call.
     *
     * @param elapsedTimeMs elapsed time in milliseconds
     */
    public void firstMessageReceived(long elapsedTimeMs) {
        inboundDeliveryTime.record(elapsedTimeMs);
    }

    /**
     * Merges the values of every metric of {@code endpoint} into the corresponding metric of
     * this instance, registering status-code rates that this instance does not have yet.
     *
     * <p>The in-flight gauge is not merged: it is recomputed from the merged started/completed
     * rates.
     *
     * @param endpoint metrics to merge into this instance; left unchanged by this call
     */
    public void combine(EndpointMetrics endpoint) {
        inboundMessage.combine(endpoint.inboundMessage);
        outboundMessage.combine(endpoint.outboundMessage);
        inboundBytes.combine(endpoint.inboundBytes);
        outboundBytes.combine(endpoint.outboundBytes);
        callStarted.combine(endpoint.callStarted);
        callCompleted.combine(endpoint.callCompleted);
        responseTime.combine(endpoint.responseTime);
        inboundDeliveryTime.combine(endpoint.inboundDeliveryTime);
        for (var entry : endpoint.statusCounts.entrySet()) {
            getStatusCount(entry.getKey()).combine(entry.getValue());
        }
    }

    /** Delegates to the underlying registry. */
    @Override
    public int estimateCount() {
        return registry.estimateCount();
    }

    /** Appends all metrics of the underlying registry to {@code consumer}. */
    @Override
    public void append(long tsMillis, Labels commonLabels, MetricConsumer consumer) {
        registry.append(tsMillis, commonLabels, consumer);
    }
}
