package ru.yandex.grpc.utils.client.interceptors;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.health.v1.HealthCheckRequest;
import io.grpc.health.v1.HealthCheckResponse;
import io.grpc.health.v1.HealthGrpc;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.testing.GrpcServerRulePublic;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.concurrency.limits.actors.Limiter;
import ru.yandex.concurrency.limits.actors.LimiterImpl;
import ru.yandex.grpc.utils.SingleResponseStreamObserver;

import static io.grpc.stub.ClientCalls.asyncUnaryCall;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;

/**
 * Tests for {@link LimiterClientInterceptor}.
 *
 * <p>Every call made through the intercepted channel is expected to occupy a slot in the shared
 * {@link Limiter} while in flight, to be rejected with {@code RESOURCE_EXHAUSTED} once no slot is
 * free, and to release its slot however it finishes: success, server error, client cancellation
 * or deadline expiry.
 *
 * @author Vladimir Gordiychuk
 */
public class LimiterClientInterceptorTest {
    @Rule
    public GrpcServerRulePublic server = new GrpcServerRulePublic();
    @Rule
    public Timeout timeout = Timeout.builder()
            .withTimeout(30, TimeUnit.SECONDS)
            .build();

    private HealthServiceStub service;
    private ManagedChannel channel;
    private Limiter limiter;

    @Before
    public void setUp() throws Exception {
        service = new HealthServiceStub();
        server.getServiceRegistry().addService(service);
        // initLimit/maxLimit of 2: limitInflight expects the third concurrent call to be rejected
        limiter = LimiterImpl.newBuilder().minLimit(1).initLimit(2).maxLimit(2).build();
        channel = InProcessChannelBuilder.forName(server.getServerName())
                .usePlaintext()
                .intercept(new LimiterClientInterceptor(new LimiterProvider() {
                    @Override
                    public <ReqT, RespT> Limiter provide(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
                        // one limiter for every method, so the tests can observe its inflight counter
                        return limiter;
                    }
                }))
                .build();
    }

    @After
    public void tearDown() throws Exception {
        if (channel != null) {
            // wait (bounded) for the channel to terminate so its resources are released before the next test
            channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
        }
    }

    @Test
    public void limitInflight() throws InterruptedException {
        assertEquals(0, limiter.inflight());

        var futureOne = healthCheck();
        assertFalse(futureOne.isDone());
        assertEquals(1, limiter.inflight());

        var futureTwo = healthCheck();
        assertFalse(futureTwo.isDone());
        assertEquals(2, limiter.inflight());

        // limit reached: the next request is rejected and must not take a slot
        var rejected = healthCheck().join();
        assertEquals(rejected.toString(), Code.RESOURCE_EXHAUSTED, rejected.getCode());
        assertEquals(2, limiter.inflight());

        // complete inflight requests one by one; each completion frees one slot
        service.calls.take().complete();
        CompletableFuture.anyOf(futureOne, futureTwo).join();
        assertEquals(1, limiter.inflight());

        service.calls.take().complete();
        CompletableFuture.allOf(futureOne, futureTwo).join();
        assertEquals(0, limiter.inflight());
    }

    @Test
    public void errorReleaseInflight() throws InterruptedException {
        assertEquals(0, limiter.inflight());

        var future = healthCheck();
        assertFalse(future.isDone());
        assertEquals(1, limiter.inflight());

        service.calls.take().error(Status.NOT_FOUND);

        var status = future.join();
        assertEquals(status.toString(), Code.NOT_FOUND, status.getCode());
        assertEquals(0, limiter.inflight());
    }

    @Test
    public void cancelCallReleaseInflight() {
        assertEquals(0, limiter.inflight());

        // keep the ClientCall handle so the client side can cancel it
        var call = newCheckCall(CallOptions.DEFAULT);
        var future = start(call);

        assertEquals(1, limiter.inflight());

        call.cancel("cancel by client", null);

        var status = future.join();
        assertEquals(status.toString(), Code.CANCELLED, status.getCode());
        assertEquals(0, limiter.inflight());
    }

    @Test
    public void timeoutReleaseInflight() {
        assertEquals(0, limiter.inflight());

        var future = healthCheck(CallOptions.DEFAULT.withDeadlineAfter(1, TimeUnit.NANOSECONDS));
        var status = future.join();
        assertEquals(status.toString(), Code.DEADLINE_EXCEEDED, status.getCode());
        assertEquals(0, limiter.inflight());
    }

    /**
     * Starts a health check with default call options.
     *
     * @return future completed with {@link Status#OK} on success, or the failure status otherwise
     */
    private CompletableFuture<Status> healthCheck() {
        return healthCheck(CallOptions.DEFAULT);
    }

    /**
     * Starts a health check with the given call options.
     *
     * @param callOptions options for the new call (e.g. a deadline)
     * @return future completed with {@link Status#OK} on success, or the failure status otherwise
     */
    private CompletableFuture<Status> healthCheck(CallOptions callOptions) {
        return start(newCheckCall(callOptions));
    }

    /**
     * Creates, but does not start, a {@code Health/Check} call on the intercepted channel.
     *
     * @param callOptions options for the new call
     * @return the not-yet-started call
     */
    private ClientCall<HealthCheckRequest, HealthCheckResponse> newCheckCall(CallOptions callOptions) {
        assertNotNull(channel);
        return channel.newCall(HealthGrpc.getCheckMethod(), callOptions);
    }

    /**
     * Starts {@code call} as a unary call with a default request and maps its outcome to a status.
     *
     * @param call a call obtained from {@link #newCheckCall(CallOptions)}
     * @return future completed with {@link Status#OK} on success, or the status extracted from the
     *         failure otherwise; it never completes exceptionally
     */
    private static CompletableFuture<Status> start(ClientCall<HealthCheckRequest, HealthCheckResponse> call) {
        var future = new CompletableFuture<HealthCheckResponse>();
        asyncUnaryCall(call, HealthCheckRequest.getDefaultInstance(), new SingleResponseStreamObserver<>(future));
        return future.thenApply(ignore -> Status.OK).exceptionally(Status::fromThrowable);
    }
}
