package ru.yandex.grpc.utils.server.interceptors;

import java.util.concurrent.TimeUnit;

import javax.annotation.ParametersAreNonnullByDefault;

import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;

import ru.yandex.grpc.utils.Headers;
import ru.yandex.grpc.utils.server.EndpointMetrics;
import ru.yandex.grpc.utils.server.ServerMetrics;

/**
 * gRPC {@link ServerStreamTracer} that forwards per-stream events (call start, message counts,
 * wire sizes, completion status and latency) to an {@link EndpointMetrics} instance.
 *
 * <p>Instances are created by {@link Factory}, which resolves the endpoint metrics from the
 * client reported in the request headers and the full gRPC method name.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class MetricServerStreamTracer extends ServerStreamTracer {
    private final EndpointMetrics endpointMetrics;
    /** {@link System#nanoTime()} captured when the tracer was created; base for call latency. */
    private final long startNanos;
    /** Client-side creation timestamp from the request headers; {@code 0} means "not provided". */
    private final long createdAtMillis;

    /**
     * @param metrics        metrics sink for the endpoint this stream belongs to
     * @param startTimeNanos {@link System#nanoTime()} value taken when the stream was opened
     * @param createdAtMs    wall-clock creation time reported by the client, or {@code 0} if absent
     */
    MetricServerStreamTracer(EndpointMetrics metrics, long startTimeNanos, long createdAtMs) {
        this.endpointMetrics = metrics;
        this.startNanos = startTimeNanos;
        this.createdAtMillis = createdAtMs;
    }

    /** Records that the server call has started. */
    @Override
    public void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
        endpointMetrics.callStarted();
    }

    /**
     * Records call completion with its final status and the elapsed time in whole milliseconds,
     * measured from tracer creation using the monotonic clock.
     */
    @Override
    public void streamClosed(Status status) {
        long finishedNanos = System.nanoTime();
        long durationMillis = TimeUnit.NANOSECONDS.toMillis(finishedNanos - startNanos);
        endpointMetrics.callCompleted(status, durationMillis);
    }

    /** Counts one message sent to the client. */
    @Override
    public void outboundMessage(int seqNo) {
        endpointMetrics.addOutboundMessage();
    }

    /**
     * Counts one message received from the client. For the first message ({@code seqNo == 0}),
     * also reports the time since the client-supplied creation timestamp, if one was provided.
     */
    @Override
    public void inboundMessage(int seqNo) {
        endpointMetrics.addInboundMessage();

        boolean firstMessage = seqNo == 0;
        boolean hasCreatedAt = createdAtMillis != 0;
        if (!firstMessage || !hasCreatedAt) {
            return;
        }

        // NOTE(review): compares the server's wall clock with a client-supplied timestamp, so the
        // value may be skewed (even negative) if the clocks differ; confirm the sink tolerates it.
        long nowMillis = System.currentTimeMillis();
        endpointMetrics.firstMessageReceived(nowMillis - createdAtMillis);
    }

    /** Adds the number of bytes written to the wire for this stream. */
    @Override
    public void outboundWireSize(long bytes) {
        endpointMetrics.addOutboundBytes(bytes);
    }

    /** Adds the number of bytes read from the wire for this stream. */
    @Override
    public void inboundWireSize(long bytes) {
        endpointMetrics.addInboundBytes(bytes);
    }

    /**
     * Creates a {@link MetricServerStreamTracer} per incoming stream, bound to the
     * {@link EndpointMetrics} for the calling client and method.
     */
    public static class Factory extends ServerStreamTracer.Factory {
        private final ServerMetrics serverMetrics;

        /**
         * @param metrics server-wide metrics registry used to look up per-endpoint metrics
         */
        public Factory(ServerMetrics metrics) {
            this.serverMetrics = metrics;
        }

        /**
         * Captures the stream start time, resolves endpoint metrics from the client header and
         * {@code fullMethodName}, and reads the client creation timestamp from the headers.
         */
        @Override
        public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
            long openedAtNanos = System.nanoTime();
            EndpointMetrics endpointMetrics =
                    serverMetrics.getEndpoint(Headers.getClient(headers), fullMethodName);
            long clientCreatedAtMs = Headers.getCreatedAt(headers);
            return new MetricServerStreamTracer(endpointMetrics, openedAtNanos, clientCreatedAtMs);
        }
    }
}
