package ru.yandex.chemodan.grpc.client;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import javax.annotation.Nullable;

import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;

import ru.yandex.bolts.collection.Option;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.random.Random2;

/**
 * Client-side gRPC interceptor that logs outgoing messages, incoming responses and failed calls.
 *
 * <p>Every intercepted call gets a generated id of the form
 * {@code <requestIdPrefix>_<10 random alphanumerics>}. The id is stored in the call options under
 * {@code grpc_request_id} and sent to the server in the {@code requestTracingMetadataKey} header,
 * which allows client and server log lines to be correlated. Message bodies are written to the log
 * only while {@code enableVerboseMode} yields {@code true}; otherwise {@code <disabled>} is logged.
 */
public class LoggingGrpcInterceptor implements ClientInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(LoggingGrpcInterceptor.class);

    private static final CallOptions.Key<String> GRPC_REQUEST_ID = CallOptions.Key.createWithDefault("grpc_request_id", "-");

    private final String requestIdPrefix;

    private final Metadata.Key<String> requestTracingMetadataKey;

    private final Supplier<Boolean> enableVerboseMode;

    /**
     * @param requestIdPrefix prefix of the generated per-call request id
     * @param requestTracingMetadataKey header under which the request id is sent to the server
     * @param enableVerboseMode queried on every log line; when it yields {@code true},
     *         request/response bodies are logged in full
     */
    public LoggingGrpcInterceptor(String requestIdPrefix,
            Metadata.Key<String> requestTracingMetadataKey,
            Supplier<Boolean> enableVerboseMode)
    {
        this.requestIdPrefix = requestIdPrefix;
        this.requestTracingMetadataKey = requestTracingMetadataKey;
        this.enableVerboseMode = enableVerboseMode;
    }

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
            CallOptions callOptions, Channel next)
    {
        String grpcRequestId = requestIdPrefix + "_" + Random2.R.nextAlnum(10);
        CallOptions finalCallOptions = callOptions.withOption(GRPC_REQUEST_ID, grpcRequestId);
        return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
                next.newCall(method, finalCallOptions)) {

            // Measures the time from the first sent message to the first response or failure.
            private final Stopwatch stopwatch = Stopwatch.createUnstarted();

            // Guards stopwatch.start(): Stopwatch.start() throws IllegalStateException while running,
            // so without this flag a second sendMessage of a client-streaming call would crash, and a
            // send after the first response would resume an already finished measurement.
            private boolean timingStarted;

            // Set by cancel() so that the CANCELLED status subsequently delivered to onClose is not
            // reported as a second, duplicate failure.
            private volatile boolean cancelLogged;

            @Override
            public void sendMessage(ReqT message) {
                logger.info("send message via gRPC method={} authority={} grpc_request_id={} message={}",
                        method.getFullMethodName(), next.authority(), grpcRequestId,
                        payloadForLog(message));
                if (!timingStarted) {
                    timingStarted = true;
                    stopwatch.start();
                }
                super.sendMessage(message);
            }

            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                ClientCall.Listener<RespT> listener = new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
                    @Override
                    public void onMessage(RespT message) {
                        stopTiming();
                        logger.info("received message via gRPC method={} status={} request_time_ms={} authority={} grpc_request_id={} response={}",
                                method.getFullMethodName(), Status.Code.OK, stopwatch.elapsed(TimeUnit.MILLISECONDS),
                                next.authority(), grpcRequestId, payloadForLog(message));
                        super.onMessage(message);
                    }

                    @Override
                    public void onClose(Status status, Metadata trailers) {
                        stopTiming();
                        // Server-side errors and deadlines only surface here; a client-side cancel
                        // has already been logged by cancel() and is skipped.
                        if (!status.isOk() && !cancelLogged) {
                            logger.info("failed gRPC request method={} status={} request_time_ms={} authority={} grpc_request_id={} message={} exception={}",
                                    method.getFullMethodName(),
                                    status.getCode(),
                                    stopwatch.elapsed(TimeUnit.MILLISECONDS),
                                    next.authority(),
                                    grpcRequestId,
                                    Option.ofNullable(status.getDescription()).getOrElse("-"),
                                    status.getCause());
                        }
                        super.onClose(status, trailers);
                    }
                };
                // Propagate the generated id to the server for cross-service log correlation.
                headers.put(requestTracingMetadataKey, grpcRequestId);
                super.start(listener, headers);
            }

            @Override
            public void cancel(@Nullable String message, @Nullable Throwable cause) {
                stopTiming();
                // Must be set before super.cancel(), which leads to onClose(CANCELLED).
                cancelLogged = true;
                logger.info("failed gRPC request method={} status={} request_time_ms={} authority={} grpc_request_id={} message={} exception={}",
                        method.getFullMethodName(),
                        statusCodeOf(cause),
                        stopwatch.elapsed(TimeUnit.MILLISECONDS),
                        next.authority(),
                        grpcRequestId,
                        Option.ofNullable(message).getOrElse("-"),
                        cause);
                super.cancel(message, cause);
            }

            /** Stops the stopwatch if it is currently running; does nothing otherwise. */
            private void stopTiming() {
                if (stopwatch.isRunning()) {
                    stopwatch.stop();
                }
            }
        };
    }

    /**
     * Returns {@code message} itself when verbose mode is on, otherwise the {@code <disabled>}
     * placeholder so that payloads do not reach the log.
     */
    private Object payloadForLog(Object message) {
        return enableVerboseMode.get() ? message : "<disabled>";
    }

    /**
     * Returns the name of the gRPC status code carried by {@code cause} when it is a
     * {@link StatusRuntimeException}, or {@code "-"} otherwise (including when it is {@code null}).
     */
    private static String statusCodeOf(@Nullable Throwable cause) {
        return Option.ofNullable(cause).filter(StatusRuntimeException.class::isInstance)
                .map(StatusRuntimeException.class::cast).map(StatusRuntimeException::getStatus)
                .map(Status::getCode).map(Objects::toString).getOrElse("-");
    }
}
