package ru.yandex.travel.grpc.interceptors;

import java.net.SocketAddress;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.ImmutableSet;
import io.grpc.ForwardingServerCall;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.opentracing.Span;
import io.opentracing.Tracer;
import io.sentry.Sentry;
import io.sentry.event.Event;
import io.sentry.event.EventBuilder;
import io.sentry.event.interfaces.ExceptionInterface;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.ObjectProvider;

import ru.yandex.travel.commons.metrics.MetricsUtils;
import ru.yandex.travel.grpc.GrpcServerRunner;

import static ru.yandex.travel.grpc.interceptors.LoggingInterceptorCommons.METADATA_CALL_ID;
import static ru.yandex.travel.grpc.interceptors.LoggingInterceptorCommons.METADATA_FORWARDED_FOR;
import static ru.yandex.travel.grpc.interceptors.LoggingInterceptorCommons.METADATA_FQDN;
import static ru.yandex.travel.grpc.interceptors.LoggingInterceptorCommons.METADATA_STARTED_AT;

/**
 * gRPC {@link ServerInterceptor} that:
 * <ul>
 *   <li>logs one line when a call starts ("->") and one when it is closed ("<-");</li>
 *   <li>records per-method Micrometer meters (call count, receive time, execute time) and a
 *       per-(method, status code) call counter in {@link Metrics#globalRegistry};</li>
 *   <li>adds this server's FQDN to the response headers;</li>
 *   <li>optionally reports calls that close with a {@link Status#getCause() cause} to Sentry.</li>
 * </ul>
 *
 * <p>Methods whose full names are in the logging blacklist are still counted and timed, but produce no
 * log lines and no Sentry events.
 */
public class LoggingServerInterceptor implements ServerInterceptor {
    // Logs under the GrpcServerRunner category rather than this class; presumably deliberate so the access
    // log shares the runner's logger configuration. NOTE(review): confirm before changing.
    private static final Logger logger = LoggerFactory.getLogger(GrpcServerRunner.class);

    /** Cache key for per-(method, status code) meters; Lombok {@code @Data} supplies equals/hashCode. */
    @Data
    private static class MetersPerStatusKey {
        private final String method;
        private final String status;
    }

    /** Meters registered once per full gRPC method name. */
    private static class Meters {
        final Counter counter;
        final Timer receiveTimer;
        final Timer executeTimer;

        Meters(String methodName) {
            counter = Counter.builder("grpc.call.count")
                    .tag("method", methodName)
                    .register(Metrics.globalRegistry);

            // Time between the client-reported start (METADATA_STARTED_AT header) and local receipt.
            receiveTimer = Timer.builder("grpc.call.receiveTime")
                    .tag("method", methodName)
                    .serviceLevelObjectives(MetricsUtils.smallDurationSla())
                    .publishPercentiles(MetricsUtils.higherPercentiles())
                    .register(Metrics.globalRegistry);

            // Time between interceptCall() and ServerCall.close().
            executeTimer = Timer.builder("grpc.call.executeTime")
                    .tag("method", methodName)
                    .publishPercentileHistogram(true)
                    .serviceLevelObjectives(MetricsUtils.mediumDurationSla())
                    .publishPercentiles(MetricsUtils.higherPercentiles())
                    .register(Metrics.globalRegistry);
        }
    }

    /**
     * Call counter per (method, status code).
     *
     * <p>NOTE(review): this reuses the meter name "grpc.call.count" with the tag keys {method, status},
     * while {@link Meters} registers the same name with only {method}. Some registries (e.g. Prometheus)
     * reject meters that share a name but differ in tag keys; confirm against the registries in use.
     */
    private static class MetersPerStatus {
        final Counter counter;

        MetersPerStatus(String methodName, String status) {
            counter = Counter.builder("grpc.call.count")
                    .tag("method", methodName)
                    .tag("status", status)
                    .register(Metrics.globalRegistry);
        }
    }

    private final ConcurrentHashMap<String, Meters> metersForMethod = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<MetersPerStatusKey, MetersPerStatus> metersForMethodPerStatus = new ConcurrentHashMap<>();
    private final String serverFqdn;
    private final boolean reportErrorsToSentry;

    /** Full gRPC method names for which no log lines (and no Sentry events) are produced. */
    private final Set<String> loggingBlacklistMethodNames;

    /** May be null; span/trace ids are then logged as "none". */
    private final Tracer tracer;

    /**
     * Creates an interceptor with no logging blacklist, no tracer and Sentry reporting disabled.
     *
     * @param serverFqdn value put into the {@code METADATA_FQDN} response header and used as the Sentry
     *                   server name
     */
    public LoggingServerInterceptor(String serverFqdn) {
        this(serverFqdn, Collections.emptySet());
    }

    /**
     * Creates an interceptor with no tracer and Sentry reporting disabled.
     *
     * @param serverFqdn                  value put into the {@code METADATA_FQDN} response header
     * @param loggingBlacklistMethodNames full method names to exclude from logging; copied defensively
     */
    public LoggingServerInterceptor(String serverFqdn, Set<String> loggingBlacklistMethodNames) {
        this(serverFqdn, loggingBlacklistMethodNames, false);
    }

    /**
     * Creates an interceptor without a tracer.
     *
     * @param serverFqdn                  value put into the {@code METADATA_FQDN} response header
     * @param loggingBlacklistMethodNames full method names to exclude from logging; copied defensively
     * @param reportErrorsToSentry        whether calls closed with a status cause are sent to Sentry
     */
    public LoggingServerInterceptor(String serverFqdn, Set<String> loggingBlacklistMethodNames,
                                    boolean reportErrorsToSentry) {
        this(serverFqdn, loggingBlacklistMethodNames, null, reportErrorsToSentry);
    }

    /**
     * Creates a fully configured interceptor.
     *
     * @param serverFqdn                  value put into the {@code METADATA_FQDN} response header and used
     *                                    as the Sentry server name
     * @param loggingBlacklistMethodNames full method names to exclude from logging; copied into an
     *                                    {@link ImmutableSet}, so neither the set nor its elements may be null
     * @param tracer                      tracer whose active span ids are logged; may be null
     * @param reportErrorsToSentry        whether calls closed with a status cause are sent to Sentry
     */
    public LoggingServerInterceptor(String serverFqdn, Set<String> loggingBlacklistMethodNames,
                                    Tracer tracer,
                                    boolean reportErrorsToSentry) {
        this.serverFqdn = serverFqdn;
        this.loggingBlacklistMethodNames = ImmutableSet.copyOf(loggingBlacklistMethodNames);
        this.reportErrorsToSentry = reportErrorsToSentry;
        this.tracer = tracer;
    }

    /** Returns the caller-supplied call id header, or a fresh random UUID when the header is absent. */
    private String getOrGenerateCallId(Metadata headers) {
        String callId = headers.get(METADATA_CALL_ID);
        if (callId == null) {
            callId = UUID.randomUUID().toString();
        }
        return callId;
    }

    /**
     * Logs and counts the incoming call, records its receive time, and wraps it so that headers carry the
     * server FQDN and closing the call records execute time, logs the outcome and optionally reports the
     * failure cause to Sentry.
     */
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
            ServerCall<ReqT, RespT> call,
            Metadata headers,
            ServerCallHandler<ReqT, RespT> next) {
        long startedAt = System.nanoTime();
        String methodName = call.getMethodDescriptor().getFullMethodName();
        boolean loggingEnabled = !loggingBlacklistMethodNames.contains(methodName);
        Meters meters = metersForMethod.computeIfAbsent(methodName, Meters::new);
        String callId = getOrGenerateCallId(headers);
        Instant localStartedAt = Instant.now();
        Instant remoteStartedAt = headers.get(METADATA_STARTED_AT);
        Optional<Duration> maybeTimeoutDuration = LoggingInterceptorCommons.getTimeoutDuration(headers);

        String fqdn = headers.get(METADATA_FQDN);
        String forwardedFor = headers.get(METADATA_FORWARDED_FOR);
        SocketAddress address = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);

        // Read the active span once so that the span id and trace id come from the same span.
        Span activeSpan = tracer != null ? tracer.activeSpan() : null;
        String spanId = activeSpan != null ? activeSpan.context().toSpanId() : "none";
        String traceId = activeSpan != null ? activeSpan.context().toTraceId() : "none";

        if (loggingEnabled) {
            logger.info("-> {} (CallId: {}, StartedAt: {}, Timeout: {}, Fqdn: {}, ForwardedFor: {}, Address: {}, " +
                            "SpanId: {}, TraceId: {})",
                    methodName, callId, remoteStartedAt,
                    maybeTimeoutDuration.map(Duration::toMillis).orElse(null),
                    fqdn, forwardedFor, address, spanId, traceId);
        }
        meters.counter.increment();

        ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT> interceptedCall =
                new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
                    @Override
                    public void sendHeaders(Metadata headers) {
                        headers.put(METADATA_FQDN, serverFqdn);
                        super.sendHeaders(headers);
                    }

                    @Override
                    public void close(Status status, Metadata trailers) {
                        // The instrumentation below may throw (meter registration, logging, Sentry). The
                        // finally block still runs super.close in that case, so the call is always closed.
                        try {
                            long elapsedNanos = System.nanoTime() - startedAt;

                            meters.executeTimer.record(elapsedNanos, TimeUnit.NANOSECONDS);
                            metersForMethodPerStatus.computeIfAbsent(
                                    new MetersPerStatusKey(methodName, status.getCode().name()),
                                    key -> new MetersPerStatus(key.method, key.status)
                            ).counter.increment();

                            // NOTE(review): blacklisted methods skip Sentry reporting as well as logging.
                            if (loggingEnabled) {
                                logger.info("<- {} (CallId: {}, Time: {}, StatusCode: {}, StatusDescription: {}, " +
                                                "Trailers: {}, SpanId: {}, TraceId: {})",
                                        methodName, callId, Duration.ofNanos(elapsedNanos).toMillis(),
                                        status.getCode(), status.getDescription(), trailers, spanId, traceId);
                                Throwable cause = status.getCause();
                                if (cause != null) {
                                    // A trailing Throwable argument makes SLF4J log its stack trace.
                                    logger.warn("!! {} (CallId: {})", methodName, callId, cause);
                                    if (reportErrorsToSentry) {
                                        reportToSentry(methodName, callId, cause);
                                    }
                                }
                            }
                        } finally {
                            super.close(status, trailers);
                        }
                    }
                };

        if (remoteStartedAt != null) {
            // Clamp negative values (for example, from client/server clock skew) to zero.
            if (localStartedAt.isAfter(remoteStartedAt)) {
                meters.receiveTimer.record(Duration.between(remoteStartedAt, localStartedAt));
            } else {
                meters.receiveTimer.record(0, TimeUnit.NANOSECONDS);
            }
        }

        return next.startCall(interceptedCall, headers);
    }

    /**
     * Sends {@code cause} to the stored Sentry client as an ERROR event. The event carries this server's
     * FQDN as the server name, the gRPC method as a tag and the call id as extra data.
     */
    private void reportToSentry(String methodName, String callId, Throwable cause) {
        EventBuilder eventBuilder = new EventBuilder()
                .withMessage(cause.getMessage())
                .withLevel(Event.Level.ERROR)
                .withServerName(serverFqdn)
                .withTag("MethodName", methodName)
                .withExtra("CallId", callId)
                .withSentryInterface(new ExceptionInterface(cause));
        Sentry.getStoredClient().sendEvent(eventBuilder);
    }
}
