package ru.yandex.travel.grpc.interceptors;

import java.time.Duration;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.ImmutableSet;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.opentracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import ru.yandex.misc.ExceptionUtils;
import ru.yandex.travel.commons.logging.NestedMdc;
import ru.yandex.travel.commons.metrics.MetricsUtils;

/**
 * gRPC {@link ClientInterceptor} that logs outgoing client calls and records per-method metrics.
 *
 * <p>For every intercepted call it:
 * <ul>
 *     <li>generates a random call id and puts it, together with the call start time, into the request headers;</li>
 *     <li>increments the {@code grpc.clientCall.count} counter and, when the call closes, records its duration in
 *     the {@code grpc.clientCall.executeTime} timer (both tagged with the target and the full method name);</li>
 *     <li>logs every sent request message ({@code <-}) and the call completion ({@code ->}), unless the full method
 *     name is listed in the blacklist;</li>
 *     <li>optionally appends a JSON rendering of the request / last received response when message logging is
 *     enabled.</li>
 * </ul>
 */
public class LoggingClientInterceptor implements ClientInterceptor {

    private static final Logger logger = LoggerFactory.getLogger("ru.yandex.travel.grpc.ClientCalls");
    private static final JsonFormat.Printer protoPrinter = JsonFormat.printer().preservingProtoFieldNames();

    private final String target;
    private final String fqdn;
    /** Lazily created meters, keyed by full gRPC method name. */
    private final ConcurrentHashMap<String, Meters> metersForMethod = new ConcurrentHashMap<>();
    /** Full method names whose calls are not logged (metrics are still recorded). */
    private final Set<String> blacklistMethodNames;
    /** Optional tracer used to attach span / trace ids to log lines; may be {@code null}. */
    private final Tracer tracer;
    private final boolean messageLogging;

    /**
     * Creates an interceptor with no blacklisted methods, no tracer and message logging disabled.
     *
     * @param fqdn       host name of the calling machine, included in request log lines
     * @param targetHost host of the called service
     * @param targetPort port of the called service
     */
    public LoggingClientInterceptor(String fqdn, String targetHost, int targetPort) {
        this(fqdn, targetHost + ":" + targetPort, Collections.emptySet());
    }

    /**
     * Creates an interceptor with no tracer and message logging disabled.
     *
     * @param fqdn                 host name of the calling machine, included in request log lines
     * @param target               target description used in log lines and as the metrics {@code target} tag
     * @param blacklistMethodNames full method names whose calls must not be logged
     */
    public LoggingClientInterceptor(String fqdn, String target, Set<String> blacklistMethodNames) {
        this(fqdn, target, blacklistMethodNames, null);
    }

    /**
     * Creates an interceptor with message logging disabled.
     *
     * @param fqdn                 host name of the calling machine, included in request log lines
     * @param target               target description used in log lines and as the metrics {@code target} tag
     * @param blacklistMethodNames full method names whose calls must not be logged
     * @param tracer               tracer providing span / trace ids for log lines, or {@code null}
     */
    public LoggingClientInterceptor(String fqdn, String target, Set<String> blacklistMethodNames, Tracer tracer) {
        this(fqdn, target, blacklistMethodNames, tracer, false);
    }

    /**
     * Creates a fully configured interceptor.
     *
     * @param fqdn                 host name of the calling machine, included in request log lines
     * @param target               target description used in log lines and as the metrics {@code target} tag
     * @param blacklistMethodNames full method names whose calls must not be logged (defensively copied)
     * @param tracer               tracer providing span / trace ids for log lines, or {@code null}
     * @param messageLogging       whether request / response messages are rendered into the log lines
     */
    public LoggingClientInterceptor(String fqdn, String target, Set<String> blacklistMethodNames, Tracer tracer,
                                    boolean messageLogging) {
        this.fqdn = fqdn;
        this.target = target;
        this.blacklistMethodNames = ImmutableSet.copyOf(blacklistMethodNames);
        this.tracer = tracer;
        this.messageLogging = messageLogging;
    }

    /**
     * Wraps the call so that request messages and call completion are logged and metered.
     *
     * <p>The MDC of the thread creating the call is captured here and re-applied when the completion line is
     * logged, because {@code onClose} may run on a different thread.
     */
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                               CallOptions callOptions, Channel next) {
        String callId = UUID.randomUUID().toString();
        ZonedDateTime startedAt = ZonedDateTime.now(ZoneOffset.UTC);
        boolean isLogged = !blacklistMethodNames.contains(method.getFullMethodName());
        // Monotonic start mark: used both for the timer and for the logged duration so they agree.
        long startedAtNanos = System.nanoTime();
        Meters meters = metersForMethod.computeIfAbsent(method.getFullMethodName(), (key) -> new Meters(target, key));
        meters.counter.increment();
        Map<String, String> mdc = MDC.getCopyOfContextMap();

        return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) {
            private Metadata headers;
            private String spanId = "none";
            private String traceId = "none";

            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                this.headers = headers;

                headers.put(LoggingInterceptorCommons.METADATA_CALL_ID, callId);
                headers.put(LoggingInterceptorCommons.METADATA_STARTED_AT, startedAt.toInstant());

                if (tracer != null && tracer.activeSpan() != null) {
                    this.spanId = tracer.activeSpan().context().toSpanId();
                    this.traceId = tracer.activeSpan().context().toTraceId();
                }

                super.start(new ForwardingClientCallListener.SimpleForwardingClientCallListener<>(responseListener) {
                    // Only the most recent response is kept; it is rendered in the completion log line.
                    private RespT response;

                    @Override
                    public void onMessage(RespT message) {
                        this.response = message;
                        super.onMessage(message);
                    }

                    @Override
                    public void onClose(Status status, Metadata trailers) {
                        long completedAtNanos = System.nanoTime();
                        long elapsedNanos = completedAtNanos - startedAtNanos;
                        try {
                            if (isLogged) {
                                String formattedMessage = messageLogging ? formatMessageLogEntry(this.response) : "";
                                long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(elapsedNanos);
                                try (var ignored = NestedMdc.nestedMdc(mdc)) {
                                    logger.info("-> {} (CallId: {}, Time: {}, StatusCode: {}, StatusDescription: {}, " +
                                                    "Target: {}, Trailers: {}, SpanId: {}, TraceId: {}){}",
                                            method.getFullMethodName(), callId, elapsedMillis,
                                            status.getCode(), status.getDescription(), target, trailers.toString(),
                                            spanId, traceId, formattedMessage);
                                }
                            }
                        } finally {
                            // Must always run: if the delegate listener is not notified, the call never completes.
                            meters.executeTimer.record(elapsedNanos, TimeUnit.NANOSECONDS);
                            super.onClose(status, trailers);
                        }
                    }
                }, headers);
            }

            @Override
            public void sendMessage(ReqT message) {
                if (isLogged) {
                    Optional<Long> timeoutMillis =
                            LoggingInterceptorCommons.getTimeoutDuration(this.headers).map(Duration::toMillis);
                    String forwardedForHeader = this.headers.get(LoggingInterceptorCommons.METADATA_FORWARDED_FOR);
                    String formattedMessage = messageLogging ? formatMessageLogEntry(message) : "";
                    logger.info("<- {} (CallId: {}, StartedAt: {}, Timeout: {}, Fqdn: {}, ForwardedFor: {}, " +
                                    "Target: {}, SpanId: {}, TraceId: {}){}",
                            method.getFullMethodName(), callId, startedAt,
                            timeoutMillis.orElse(null), fqdn, forwardedForHeader, target,
                            spanId, traceId, formattedMessage);
                }
                super.sendMessage(message);
            }
        };
    }

    /**
     * Renders a request / response message as a log suffix that starts with a newline.
     *
     * <p>Protobuf messages are printed as JSON with original proto field names. {@code null} is printed as
     * {@code "null"}. Any other object (e.g. from a non-protobuf marshaller) falls back to {@code toString()}, so
     * enabling message logging cannot fail the call with a {@link ClassCastException}.
     *
     * @param message the message to render, possibly {@code null}
     * @return newline-prefixed rendering of the message, or of the printing error's stack trace
     */
    private static String formatMessageLogEntry(Object message) {
        if (Objects.isNull(message)) {
            return "\nnull";
        }
        if (!(message instanceof MessageOrBuilder)) {
            return "\n" + message;
        }
        try {
            return "\n" + protoPrinter.print((MessageOrBuilder) message);
        } catch (InvalidProtocolBufferException e) {
            return "\n" + ExceptionUtils.getStackTrace(e);
        }
    }

    /** Micrometer meters registered in the global registry for a single (target, method) pair. */
    private static class Meters {
        /** Number of started calls. */
        final Counter counter;
        /** Duration from call creation to close. */
        final Timer executeTimer;

        Meters(String target, String methodName) {
            counter = Counter.builder("grpc.clientCall.count")
                    .tag("target", target)
                    .tag("method", methodName)
                    .register(Metrics.globalRegistry);

            executeTimer = Timer.builder("grpc.clientCall.executeTime")
                    .tag("target", target)
                    .tag("method", methodName)
                    .publishPercentileHistogram(true)
                    .serviceLevelObjectives(MetricsUtils.mediumDurationSla())
                    .publishPercentiles(MetricsUtils.higherPercentiles())
                    .register(Metrics.globalRegistry);
        }
    }
}
