package ru.yandex.grpc.utils.client.interceptors;

import java.util.concurrent.TimeUnit;

import com.google.common.base.Strings;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;

import ru.yandex.grpc.utils.Headers;
import ru.yandex.grpc.utils.client.EndpointMetrics;
import ru.yandex.grpc.utils.client.EndpointMetrics.EndpointMetric;
import ru.yandex.grpc.utils.client.interceptors.MetricClientStreamTracer.Factory;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.util.host.HostUtils;

/**
 * Client interceptor that reports per-endpoint call metrics.
 *
 * <p>For every intercepted call it installs a {@link Factory} stream tracer factory bound to the
 * endpoint's metrics, adds the {@link Headers#CLIENT_ID} and {@link Headers#CREATED_AT_MS} entries
 * to the outgoing metadata, and notifies the endpoint metric when the call starts, when the first
 * response message arrives and when the call is closed.
 *
 * @author Vladimir Gordiychuk
 */
public class MetricClientInterceptor implements ClientInterceptor {
    private final String clientId;
    private final EndpointMetrics metrics;

    /**
     * @param clientId identifier sent with every call; when {@code null} or empty the value of
     *                 {@link HostUtils#getShortName()} is used instead
     * @param registry registry handed to the {@link EndpointMetrics} created by this interceptor
     */
    public MetricClientInterceptor(String clientId, MetricRegistry registry) {
        this.clientId = Strings.isNullOrEmpty(clientId) ? HostUtils.getShortName() : clientId;
        this.metrics = new EndpointMetrics(registry);
    }

    /**
     * Wraps the call created by {@code next} so that it is measured with the metrics looked up by
     * the method's full name. The same metrics object is given to the stream tracer factory
     * attached to the call options.
     */
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        EndpointMetric endpointMetrics = metrics.byEndpoint(method.getFullMethodName());
        CallOptions tracedOptions = callOptions.withStreamTracerFactory(new Factory(endpointMetrics));
        return new MetricClientCall<>(next.newCall(method, tracedOptions), clientId, endpointMetrics);
    }

    /**
     * Call wrapper that decorates outgoing headers and reports the call start.
     */
    private static final class MetricClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
        private final String clientId;
        private final EndpointMetric metrics;

        private MetricClientCall(ClientCall<ReqT, RespT> delegate, String clientId, EndpointMetric metrics) {
            super(delegate);
            this.clientId = clientId;
            this.metrics = metrics;
        }

        /**
         * Adds the client id and the current wall-clock time (milliseconds, as a decimal string)
         * to {@code headers}, records the call start and then starts the delegate with a listener
         * that reports the remaining call events.
         */
        @Override
        public void start(Listener<RespT> listener, Metadata headers) {
            headers.put(Headers.CLIENT_ID, clientId);
            headers.put(Headers.CREATED_AT_MS, Long.toString(System.currentTimeMillis()));

            // Record the start BEFORE handing control to the delegate: once the delegate is
            // started the listener may be closed at any moment (on another thread, or even
            // inside start() when callbacks run on a direct executor), and callCompleted()
            // must never be reported ahead of callStarted().
            metrics.callStarted();
            super.start(new MetricClientCallListener<>(listener, metrics), headers);
        }
    }

    /**
     * Listener wrapper that reports the first-message delay and the call completion.
     */
    private static final class MetricClientCallListener<RespT> extends SimpleForwardingClientCallListener<RespT> {
        private final EndpointMetric metrics;
        // Value taken from the response headers; 0 means "nothing to report" (either no headers
        // were seen yet or the first message has already been reported).
        private long createdAtMs;
        // Monotonic timestamp captured when the listener is constructed; base for the call duration.
        private final long startTimeNanos;

        private MetricClientCallListener(Listener<RespT> delegate, EndpointMetric metrics) {
            super(delegate);
            this.metrics = metrics;
            this.startTimeNanos = System.nanoTime();
        }

        /**
         * Remembers the creation time extracted from the response headers by
         * {@link Headers#getCreatedAt(Metadata)} and forwards the headers.
         */
        @Override
        public void onHeaders(Metadata headers) {
            createdAtMs = Headers.getCreatedAt(headers);
            super.onHeaders(headers);
        }

        /**
         * Reports the delay between the remembered creation time and now for the first message
         * only, then forwards the message.
         */
        @Override
        public void onMessage(RespT message) {
            if (createdAtMs != 0) {
                metrics.firstMessageReceived(System.currentTimeMillis() - createdAtMs);
                // Reset so that subsequent messages of the same call are not reported again.
                createdAtMs = 0;
            }
            super.onMessage(message);
        }

        /**
         * Reports the final status together with the milliseconds elapsed since this listener was
         * created, then forwards the close event.
         */
        @Override
        public void onClose(Status status, Metadata trailers) {
            long elapsedTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNanos);
            metrics.callCompleted(status, elapsedTime);
            super.onClose(status, trailers);
        }
    }
}
