package ru.yandex.grpc.utils.client.interceptors;

import java.util.concurrent.atomic.AtomicReference;

import javax.annotation.Nullable;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;

import ru.yandex.concurrency.limits.actors.OperationPermit;
import ru.yandex.concurrency.limits.actors.OperationStatus;
import ru.yandex.grpc.utils.client.FailingClientCall;

/**
 * @author Vladimir Gordiychuk
 */
public class LimiterClientInterceptor implements ClientInterceptor {
    private final LimiterProvider limiterProvider;

    public LimiterClientInterceptor(LimiterProvider limiterProvider) {
        this.limiterProvider = limiterProvider;
    }

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        var limiter = limiterProvider.provide(method, callOptions);
        var permit = limiter.acquire();
        if (permit == null) {
            limiter.addStatus(OperationStatus.REJECT);
            return new FailingClientCall<>(Status.RESOURCE_EXHAUSTED
                    .withDescription("call inflight limit"));
        }

        try {
            return new LimiterClientCall<>(next.newCall(method, callOptions), permit, method.getType());
        } catch (Throwable e) {
            permit.release(OperationStatus.IGNORE);
            throw new RuntimeException(e);
        }
    }

    private static class LimiterClientCall<ReqT, REspT> extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, REspT> {
        private final OperationPermit permit;
        private final MethodType methodType;

        protected LimiterClientCall(ClientCall<ReqT, REspT> delegate, OperationPermit permit, MethodType methodType) {
            super(delegate);
            this.permit = permit;
            this.methodType = methodType;
        }

        @Override
        public void start(Listener<REspT> listener, Metadata headers) {
            try {
                super.start(new LimiterClientCallListener<>(listener, permit, methodType), headers);
            } catch (Throwable e) {
                permit.release(OperationStatus.IGNORE);
                throw new RuntimeException(e);
            }
        }

        @Override
        public void cancel(@Nullable String message, @Nullable Throwable cause) {
            permit.release(OperationStatus.IGNORE);
            super.cancel(message, cause);
        }
    }

    private static class LimiterClientCallListener<REspT> extends SimpleForwardingClientCallListener<REspT> {
        private final OperationPermit permit;
        private final MethodType methodType;

        protected LimiterClientCallListener(Listener<REspT> delegate, OperationPermit permit, MethodType methodType) {
            super(delegate);
            this.permit = permit;
            this.methodType = methodType;
        }

        @Override
        public void onHeaders(Metadata headers) {
            super.onHeaders(headers);
        }

        @Override
        public void onMessage(REspT message) {
            switch (methodType) {
                case BIDI_STREAMING, SERVER_STREAMING, UNARY -> permit.release(OperationStatus.SUCCESS);
            }

            super.onMessage(message);
        }


        @Override
        public void onClose(Status status, Metadata trailers) {
            permit.release(OperationStatus.SUCCESS);
            super.onClose(status, trailers);
        }
    }
}
