package ru.yandex.solomon.name.resolver.client.grpc;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcServerRulePublic;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.grpc.utils.DefaultClientOptions;
import ru.yandex.grpc.utils.InProcessChannelFactory;
import ru.yandex.solomon.name.resolver.client.FindRequest;
import ru.yandex.solomon.name.resolver.client.FindResponse;
import ru.yandex.solomon.name.resolver.client.NameResolverClient;
import ru.yandex.solomon.name.resolver.client.Resource;
import ru.yandex.solomon.name.resolver.protobuf.ResourceServiceGrpc;
import ru.yandex.solomon.name.resolver.protobuf.ServerStatusRequest;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static ru.yandex.solomon.name.resolver.client.ResourcesTestSupport.staticResource;

/**
 * @author Vladimir Gordiychuk
 */
public class GrpcNameResolverClientTest {
    private static final Logger logger = LoggerFactory.getLogger(GrpcNameResolverClientTest.class);

    /** Applies a 30 second timeout to every test method of this class. */
    @Rule
    public Timeout timeout = Timeout.builder()
            .withTimeout(30, TimeUnit.SECONDS)
            .build();

    /** Manually advanced clock; it backs {@link #timer} and is moved forward by the tests. */
    private ManualClock clock;
    /** Client under test, created in {@link #setUp()} and closed in {@link #tearDown()}. */
    private NameResolverClient client;
    /** Manually driven scheduler passed to the client options as the timer. */
    private ManualScheduledExecutorService timer;
    /** Fake cluster state shared by all nodes; assigned once and never replaced. */
    private final Cluster cluster = new Cluster();

    /** The three fake server nodes whose addresses are given to the client. */
    private Node alice;
    private Node bob;
    private Node eva;

    /**
     * Creates and starts the three nodes, declares alice the cluster leader and builds the
     * client under test against the server names of all three nodes.
     */
    @Before
    public void setUp() throws Throwable {
        alice = new Node("alice");
        bob = new Node("bob");
        eva = new Node("eva");

        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);

        // Options are built before the servers start, exactly as the nodes are started afterwards.
        var clientOptions = DefaultClientOptions.newBuilder()
                .setChannelFactory(new InProcessChannelFactory())
                .setRpcExecutor(ForkJoinPool.commonPool())
                .setTimer(timer)
                .build();

        for (Node node : new Node[]{alice, bob, eva}) {
            node.setUp();
        }

        String leaderAddress = alice.server.getServerName();
        cluster.setLeader(leaderAddress);
        client = new GrpcNameResolverClient(
                List.of(
                        leaderAddress,
                        bob.server.getServerName(),
                        eva.server.getServerName()),
                clientOptions);
    }

    /**
     * Stops the manual timer, shuts down the three nodes in declaration order and finally
     * closes the client if one was created.
     */
    @After
    public void tearDown() {
        timer.shutdownNow();

        for (Node node : new Node[]{alice, bob, eva}) {
            node.tearDown();
        }

        // setUp may have failed before the client was constructed.
        if (client == null) {
            return;
        }
        client.close();
    }

    /**
     * After the cluster state is synced, a find for a cloud id that has no shard assigned
     * must complete with an empty, non-truncated response.
     */
    @Test
    public void emptyResultForNotExistShard() throws InterruptedException {
        cluster.assignShard("another_shard", alice);
        awaitClusterStateSync();

        var request = FindRequest.newBuilder()
                .cloudId("not_exist_cloud")
                .build();
        FindResponse found = client.find(request).join();

        assertEquals(List.of(), found.resources);
        assertFalse(found.truncated);
    }

    /**
     * Resources added to a shard hosted on bob must all be returned by a find for that
     * shard's cloud id, without truncation and with an empty next page token.
     */
    @Test
    public void findResources() throws InterruptedException {
        final String cloud = "test";
        cluster.assignShard(cloud, bob);
        awaitClusterStateSync();

        List<Resource> stored = List.of(
                resource(cloud).setResourceId("a").setName("name-a"),
                resource(cloud).setResourceId("c").setName("name-b"),
                resource(cloud).setResourceId("b").setName("name-c"));
        cluster.addResources(cloud, stored);

        var request = FindRequest.newBuilder()
                .cloudId(cloud)
                .build();
        var found = client.find(request).join();

        assertFalse(found.truncated);
        assertEquals("", found.nextPageToken);
        // Compared as sets: the order of the returned resources is not asserted.
        assertEquals(Set.copyOf(stored), Set.copyOf(found.resources));
    }

    /**
     * With a single shard assigned to bob, the client must report exactly that shard id.
     */
    @Test
    public void getShards() throws InterruptedException {
        final String cloud = "test";
        cluster.assignShard(cloud, bob);
        awaitClusterStateSync();

        List<Resource> stored = List.of(
                resource(cloud).setResourceId("a").setName("name-a"),
                resource(cloud).setResourceId("c").setName("name-b"),
                resource(cloud).setResourceId("b").setName("name-c"));
        cluster.addResources(cloud, stored);

        var shardIds = client.getShardIds().join();

        assertEquals(Set.of(cloud), shardIds.ids());
    }

    /**
     * Assigns one shard to each node and checks that a find for every cloud id returns the
     * single resource stored in the corresponding shard.
     */
    @Test
    public void dispatchFindByNode() throws InterruptedException {
        cluster.assignShard("alice", alice);
        awaitClusterStateSync();
        cluster.assignShard("bob", bob);
        cluster.assignShard("eva", eva);
        awaitClusterStateSync();

        for (String cloudId : new String[]{"alice", "bob", "eva"}) {
            cluster.addResources(cloudId, List.of(resource(cloudId).setResourceId(cloudId)));

            var request = FindRequest.newBuilder()
                    .cloudId(cloudId)
                    .build();
            var found = client.find(request).join();

            // The cloud id doubles as the assertion message to identify the failing node.
            assertFalse(cloudId, found.truncated);
            var expected = List.of(resource(cloudId).setResourceId(cloudId));
            assertEquals(cloudId, expected, found.resources);
        }
    }

    /**
     * Without any prior cluster state sync, a find must fail with {@code UNAVAILABLE}.
     * A successful completion is mapped to {@code Status.OK} so that it fails the assertion
     * instead of being mistaken for the expected error.
     */
    @Test
    public void unavailableWhenClusterStateNotSync() {
        var status = client.find(FindRequest.newBuilder()
                .cloudId("not_exist_cloud")
                .build())
                .thenApply(ignore -> Status.OK)
                .exceptionally(Status::fromThrowable)
                .join();

        // JUnit's argument order is (message, expected, actual); keeping it right makes the
        // failure message report UNAVAILABLE as the expected code.
        assertEquals(status.toString(), Status.Code.UNAVAILABLE, status.getCode());
    }

    /**
     * Repeatedly advances the manual clock by 15 seconds and then waits up to 5 ms on the
     * cluster's current status latch, returning as soon as the latch has reached zero.
     */
    private void awaitClusterStateSync() throws InterruptedException {
        while (true) {
            clock.passedTime(15, TimeUnit.SECONDS);
            // The latch field is re-read on every pass because assignShard replaces it.
            if (cluster.serverStatusSync.await(5, TimeUnit.MILLISECONDS)) {
                return;
            }
        }
    }

    /**
     * Returns the result of {@code staticResource()} with its cloud id set to the given shard id.
     */
    private static Resource resource(String shardId) {
        var template = staticResource();
        return template.setCloudId(shardId);
    }

    /**
     * Fake resource service node. Each node runs its own gRPC server and answers requests
     * from the cluster state shared through the enclosing test instance.
     */
    private class Node extends ResourceServiceGrpc.ResourceServiceImplBase {
        private GrpcServerRulePublic server;
        private final String name;

        public Node(String name) {
            this.name = name;
        }

        /** Starts the gRPC server for this node and registers this service on it. */
        void setUp() throws Throwable {
            server = new GrpcServerRulePublic();
            server.before();
            logger.info("{}: address {}", name, server.getServerName());
            server.getServiceRegistry().addService(this);
        }

        /** Stops the gRPC server, if it was ever created. */
        void tearDown() {
            if (server == null) {
                return;
            }
            server.after();
        }

        /** Replies with the current cluster status and completes the call. */
        @Override
        public void serverStatus(ServerStatusRequest request, StreamObserver<ru.yandex.solomon.name.resolver.protobuf.ServerStatusResponse> responseObserver) {
            var status = cluster.serverStatus();
            responseObserver.onNext(Proto.toProto(status));
            responseObserver.onCompleted();
        }

        /**
         * Serves a find request from the shard matching the requested cloud id. Fails with
         * {@code NOT_FOUND} when the shard is unknown or is assigned to a different node.
         */
        @Override
        public void find(ru.yandex.solomon.name.resolver.protobuf.FindRequest request, StreamObserver<ru.yandex.solomon.name.resolver.protobuf.FindResponse> responseObserver) {
            var query = Proto.fromProto(request);
            var shard = cluster.shardById.get(query.cloudId);

            // Work out why the request has to be rejected, if at all.
            final String rejection;
            if (shard == null) {
                rejection = "shard not exist";
            } else if (!Objects.equals(shard.node, server.getServerName())) {
                rejection = "shard absent on host";
            } else {
                rejection = null;
            }

            if (rejection != null) {
                responseObserver.onError(Status.NOT_FOUND.withDescription(rejection).asRuntimeException());
                return;
            }

            responseObserver.onNext(Proto.toProto(shard.find(query)));
            responseObserver.onCompleted();
        }
    }

    /**
     * Fake cluster state shared by all nodes: the shard to node assignment, a state hash that
     * grows with every assignment, the current leader and a latch used by the tests to wait
     * for status requests.
     */
    private static class Cluster {
        private final Map<String, Shard> shardById = new ConcurrentHashMap<>();
        private final AtomicInteger stateHash = new AtomicInteger();
        private volatile String leader = "";
        /** Counted down once per {@link #serverStatus()} call; replaced by {@link #assignShard}. */
        private volatile CountDownLatch serverStatusSync = new CountDownLatch(3);

        /**
         * Assigns the shard (creating it if absent) to the given node, bumps the state hash and
         * installs a fresh latch that opens after two further {@link #serverStatus()} calls.
         */
        public synchronized void assignShard(String shardId, Node node) {
            var shard = shardById.computeIfAbsent(shardId, Shard::new);
            shard.node = node.server.getServerName();
            stateHash.incrementAndGet();
            serverStatusSync = new CountDownLatch(2);
        }

        public void setLeader(String leader) {
            this.leader = leader;
        }

        /** Stores the resources in the shard keyed by resource id, creating the shard if absent. */
        public void addResources(String shardId, List<Resource> resources) {
            var shard = shardById.computeIfAbsent(shardId, Shard::new);
            for (var resource : resources) {
                shard.resourceById.put(resource.resourceId, resource);
            }
        }

        /**
         * Builds a status snapshot from the current hash, leader and shard assignment.
         * The latch is counted down in {@code finally}, so every call counts even if building
         * the snapshot throws.
         */
        public synchronized ServerStatusResponse serverStatus() {
            try {
                var result = new ServerStatusResponse(stateHash.get(), leader);
                for (var shard : shardById.values()) {
                    result.addShard(shard.node, shard.shardId);
                }
                return result;
            } finally {
                serverStatusSync.countDown();
            }
        }
    }

    /**
     * Fake shard: an id, the server name of the node it is currently assigned to and the
     * resources stored in it, keyed by resource id.
     */
    private static class Shard {
        private final String shardId;
        private volatile String node;
        private final Map<String, Resource> resourceById = new ConcurrentHashMap<>();

        public Shard(String shardId) {
            this.shardId = shardId;
        }

        /**
         * Returns a non-truncated response holding a snapshot of every stored resource.
         * The request itself is not inspected.
         */
        public FindResponse find(FindRequest request) {
            List<Resource> snapshot = List.copyOf(resourceById.values());
            return new FindResponse(snapshot, false);
        }
    }
}
