package ru.yandex.intranet.imscore.infrastructure.presentation.grpc.interceptors;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Splitter;
import io.grpc.ForwardingServerCall;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.inprocess.InProcessSocketAddress;
import net.devh.boot.grpc.common.util.GrpcUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import ru.yandex.intranet.imscore.metrics.GrpcServerMetrics;
import ru.yandex.intranet.imscore.util.MdcTaskDecorator;
import ru.yandex.intranet.imscore.util.OneShotStopWatch;

/**
 * Server call wrapper to write GRPC access log.
 *
 * @author Ruslan Kadriev <aqru@yandex-team.ru>
 */
public class AccessLogServerCall<Q, A> extends ForwardingServerCall.SimpleForwardingServerCall<Q, A> {

    private static final Logger ACCESS_LOG = LoggerFactory.getLogger("ACCESS_LOG");

    private final OneShotStopWatch stopWatch;
    private final String logId;
    private final Metadata headers;
    private final GrpcServerMetrics grpcServerMetrics;

    public AccessLogServerCall(ServerCall<Q, A> delegate, String logId, Metadata headers,
                               GrpcServerMetrics grpcServerMetrics) {
        super(delegate);
        this.logId = logId;
        this.headers = headers;
        this.stopWatch = new OneShotStopWatch();
        this.grpcServerMetrics = grpcServerMetrics;
    }

    @Override
    public void close(final Status status, final Metadata responseHeaders) {
        super.close(status, responseHeaders);
        log(status);
    }

    private void log(final Status status) {
        long elapsedMillis = stopWatch.elapsed(TimeUnit.MILLISECONDS);
        String methodName = Optional.ofNullable(GrpcUtils.extractMethodName(delegate().getMethodDescriptor()))
                .orElse("-");
        String serviceName = Optional.ofNullable(GrpcUtils.extractServiceName(delegate().getMethodDescriptor()))
                .orElse("-");
        //noinspection UnstableApiUsage
        String shortServiceName = Splitter.on(".").splitToStream(serviceName).reduce((l, r) -> r).orElse(serviceName);
        String path = shortServiceName + "/" + methodName;
        String methodType = delegate().getMethodDescriptor().getType().name();
        String statusCode = status.getCode().name();

        Optional<String> remoteIp = getRemoteIp().map(this::prepareIp);
        Optional<String> requestId = Optional.ofNullable(headers
                .get(Metadata.Key.of("X-Request-ID", Metadata.ASCII_STRING_MARSHALLER)));
        Map<String, String> mdcMap = new HashMap<>();
        mdcMap.put("access_protocol", "GRPC");
        remoteIp.ifPresent(v -> mdcMap.put("access_remote_ip", v));
        mdcMap.put("access_grpc_service", serviceName);
        mdcMap.put("access_grpc_method", methodName);
        mdcMap.put("access_grpc_method_type", methodType);
        mdcMap.put("access_grpc_status", statusCode);
        mdcMap.put("access_response_time_ms", String.valueOf(elapsedMillis));
        requestId.ifPresent(v -> mdcMap.put("access_http_request_id", v));
        mdcMap.put(MdcTaskDecorator.LOG_ID_MDC_KEY, logId);
        Map<String, String> previousMdc = MDC.getCopyOfContextMap();
        try {
            MDC.setContextMap(mdcMap);
            ACCESS_LOG.info("{} GRPC - {} {} - {}", remoteIp.orElse("-"), path, statusCode, elapsedMillis);
        } finally {
            MDC.clear();
            MDC.setContextMap(previousMdc);
        }
        grpcServerMetrics.onRequestCompletion(status.getCode(), elapsedMillis);
    }

    private Optional<String> getRemoteIp() {
        SocketAddress socketAddress = delegate().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
        if (socketAddress == null) {
            return Optional.empty();
        }
        if (socketAddress instanceof InetSocketAddress inetSocketAddress) {
            return Optional.of(inetSocketAddress.getHostString());
        }
        if (socketAddress instanceof InProcessSocketAddress inProcessSocketAddress) {
            return Optional.of(inProcessSocketAddress.getName());
        }
        return Optional.of(socketAddress.toString());
    }

    private String prepareIp(String value) {
        // Cut %0 from addresses like 2a02:6b8:0:e00:0:0:0:1a%0
        if (value == null || value.isEmpty()) {
            return value;
        }
        if (value.contains("%")) {
            int lastSeparatorIndex = value.lastIndexOf("%");
            return value.substring(0, lastSeparatorIndex);
        }
        return value;
    }
}
