package ru.yandex.intranet.imscore.infrastructure.presentation.grpc.interceptors;

import java.util.UUID;

import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import net.devh.boot.grpc.common.util.InterceptorOrder;
import net.devh.boot.grpc.server.interceptor.GrpcGlobalServerInterceptor;
import org.apache.logging.log4j.core.config.Order;
import org.slf4j.MDC;

import ru.yandex.intranet.imscore.metrics.GrpcServerMetrics;
import ru.yandex.intranet.imscore.util.MdcTaskDecorator;

@GrpcGlobalServerInterceptor
@Order(InterceptorOrder.ORDER_FIRST + 1)
public class AccessLogInterceptor implements ServerInterceptor {

    public static final Context.Key<String> LOG_ID_KEY = Context.key("logId");

    private final GrpcServerMetrics grpcServerMetrics;

    public AccessLogInterceptor(GrpcServerMetrics grpcServerMetrics) {
        this.grpcServerMetrics = grpcServerMetrics;
    }

    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
                                                                 ServerCallHandler<ReqT, RespT> next) {
        String logId = UUID.randomUUID().toString();
        Context context = Context.current().withValue(LOG_ID_KEY, logId);
        AccessLogServerCall<ReqT, RespT> wrappedCall = new AccessLogServerCall<>(call, logId, headers,
                grpcServerMetrics);
        return interceptCall(context, wrappedCall, headers, next, logId);
    }

    private static <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(Context context,
                                                                         ServerCall<ReqT, RespT> call,
                                                                         Metadata headers,
                                                                         ServerCallHandler<ReqT, RespT> next,
                                                                         String logId) {
        Context previous = context.attach();
        MDC.put(MdcTaskDecorator.LOG_ID_MDC_KEY, logId);
        try {
            return new LogIdCallListener<>(next.startCall(call, headers), context, logId);
        } finally {
            context.detach(previous);
            MDC.remove(MdcTaskDecorator.LOG_ID_MDC_KEY);
        }
    }

}
