package ru.yandex.intranet.d.services.integration.providers.grpc;

import java.util.Objects;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client interceptor that logs failed gRPC calls made through a provider stub.
 * <p>
 * Two kinds of failures are logged, both at ERROR level and both tagged with the full method name,
 * the provider id and the tenant id:
 * <ul>
 *     <li>an exception thrown while the call object is being created (it is logged and rethrown);</li>
 *     <li>a call that is closed with a non-OK {@link Status}.</li>
 * </ul>
 * Successful calls are passed through without logging.
 *
 * @author Ruslan Kadriev <aqru@yandex-team.ru>
 * @since 06.11.2020
 */
public class MonitoringGrpcProviderStubInterceptor implements ClientInterceptor {
    private static final Logger LOGGER = LoggerFactory.getLogger(MonitoringGrpcProviderStubInterceptor.class);
    private final String providerId;
    private final String tenantId;

    /**
     * Creates an interceptor bound to a particular provider and tenant.
     *
     * @param providerId provider id written to every log record, must not be {@code null}
     * @param tenantId tenant id written to every log record, must not be {@code null}
     * @throws NullPointerException if {@code providerId} or {@code tenantId} is {@code null}
     */
    public MonitoringGrpcProviderStubInterceptor(String providerId, String tenantId) {
        this.providerId = Objects.requireNonNull(providerId, "ProviderId must be provided.");
        this.tenantId = Objects.requireNonNull(tenantId, "TenantId must be provided.");
    }

    /**
     * Wraps the call created by the next channel into a call that installs a logging response listener.
     * Anything thrown while the call is being created is logged and then rethrown unchanged.
     *
     * @param method descriptor of the method being called
     * @param callOptions options of the call, passed to {@code next} as is
     * @param next channel used to create the underlying call
     * @return the wrapping call
     */
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                               CallOptions callOptions, Channel next) {
        try {
            return new BackendForwardingClientCall<>(method, next.newCall(method, callOptions), providerId, tenantId);
        } catch (Throwable e) {
            // The trailing Throwable is not bound to a placeholder, so SLF4J logs it as the exception.
            LOGGER.error("Error occurred while calling [{}] method for provider with id = [{}] and tenantId = [{}]",
                    method.getFullMethodName(), providerId, tenantId, e);
            throw e;
        }
    }

    /**
     * Response listener that forwards every callback to the wrapped listener and additionally logs
     * calls closed with a non-OK status.
     *
     * @param <R> response message type
     */
    private static class BackendListener<R> extends ClientCall.Listener<R> {

        private final String methodName;
        private final ClientCall.Listener<R> responseListener;
        private final String providerId;
        private final String tenantId;

        /**
         * @param methodName full name of the called method, used in log records
         * @param responseListener listener that receives all forwarded callbacks
         * @param providerId provider id, used in log records
         * @param tenantId tenant id, used in log records
         */
        protected BackendListener(String methodName, ClientCall.Listener<R> responseListener, String providerId,
                                  String tenantId) {
            super();
            this.methodName = methodName;
            this.responseListener = responseListener;
            this.providerId = providerId;
            this.tenantId = tenantId;
        }

        @Override
        public void onMessage(R message) {
            responseListener.onMessage(message);
        }

        @Override
        public void onHeaders(Metadata headers) {
            responseListener.onHeaders(headers);
        }

        /**
         * Logs the call if it was closed with a non-OK status, then forwards the callback.
         *
         * @param status final status of the call
         * @param trailers trailing metadata, forwarded as is
         */
        @Override
        public void onClose(Status status, Metadata trailers) {
            if (!status.isOk()) {
                String description = status.getDescription();
                String errorDescription = description != null ? description : "";
                // The cause goes last: when present SLF4J logs it as the exception, when it is null
                // the surplus argument is simply ignored by the formatter.
                LOGGER.error("Error occurred while calling [{}] method for provider with id = [{}] " +
                                "and tenantId = [{}]. Status = [{}]. Description = [{}].", methodName,
                        providerId, tenantId, status.getCode(), errorDescription, status.getCause());
            }
            responseListener.onClose(status, trailers);
        }

        @Override
        public void onReady() {
            responseListener.onReady();
        }
    }

    /**
     * Call wrapper that replaces the caller's response listener with a {@link BackendListener}
     * when the call is started; everything else is delegated to the wrapped call.
     *
     * @param <ReqT> request message type
     * @param <RespT> response message type
     */
    private static class BackendForwardingClientCall<ReqT, RespT>
            extends io.grpc.ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT> {

        private final String methodName;
        private final String providerId;
        private final String tenantId;

        /**
         * @param method descriptor of the called method, only its full name is kept
         * @param delegate underlying call that does the actual work
         * @param providerId provider id, used in log records
         * @param tenantId tenant id, used in log records
         */
        protected BackendForwardingClientCall(MethodDescriptor<ReqT, RespT> method, ClientCall<ReqT, RespT> delegate,
                                              String providerId, String tenantId) {
            super(delegate);
            this.methodName = method.getFullMethodName();
            this.providerId = providerId;
            this.tenantId = tenantId;
        }

        /**
         * Starts the underlying call with a logging listener wrapped around {@code responseListener}.
         *
         * @param responseListener caller's listener, receives all callbacks after logging
         * @param headers request headers, passed through as is
         */
        @Override
        public void start(Listener<RespT> responseListener, Metadata headers) {
            BackendListener<RespT> backendListener = new BackendListener<>(methodName, responseListener, providerId,
                    tenantId);
            super.start(backendListener, headers);
        }
    }
}
