package ru.yandex.stockpile.client.impl;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import com.google.common.collect.Range;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcServerRulePublic;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.grpc.utils.DefaultClientOptions;
import ru.yandex.grpc.utils.InProcessChannelFactory;
import ru.yandex.stockpile.api.EStockpileStatusCode;
import ru.yandex.stockpile.api.StockpileServiceGrpc;
import ru.yandex.stockpile.api.TServerStatusRequest;
import ru.yandex.stockpile.api.TServerStatusResponse;
import ru.yandex.stockpile.api.TShardStatus;
import ru.yandex.stockpile.client.StockpileClientOptions;

import static org.hamcrest.core.IsEqual.equalTo;
import static org.hamcrest.core.IsNull.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Tests for {@link Cluster} shard-state tracking and compress-format negotiation, run against two
 * in-process fake stockpile nodes ("alice" and "bob").
 *
 * @author Vladimir Gordiychuk
 */
public class ClusterTest {
    private static final Logger logger = LoggerFactory.getLogger(ClusterTest.class);

    private Cluster cluster;

    private Node alice;
    private Node bob;

    @Before
    public void setUp() throws Throwable {
        alice = new Node("alice");
        bob = new Node("bob");

        alice.setUp();
        bob.setUp();
        cluster = new Cluster(
                List.of(alice.server.getServerName(), bob.server.getServerName()),
                StockpileClientOptions.newBuilder(
                        DefaultClientOptions.newBuilder()
                            .setChannelFactory(new InProcessChannelFactory()))
                    .build());
    }

    /**
     * Releases the cluster and both fake servers. Each step runs even if an earlier one throws,
     * and null checks tolerate a partially failed {@link #setUp()} (JUnit still invokes
     * {@code @After} in that case).
     */
    @After
    public void tearDown() throws Exception {
        try {
            if (cluster != null) {
                cluster.close();
            }
        } finally {
            try {
                if (alice != null) {
                    alice.tearDown();
                }
            } finally {
                if (bob != null) {
                    bob.tearDown();
                }
            }
        }
    }

    @Test
    public void notInitializedShardsReturnsAsNull() {
        assertThat(cluster.getShard(1), nullValue());
    }

    @Test
    public void initShardState() {
        alice.addShard(3, false);
        cluster.forceClusterStatusUpdate().join();

        var shard = cluster.getShard(3);
        assertNotNull(shard);
        assertEquals(3, shard.getShardId());
        assertFalse(shard.isReady());
        assertEquals(alice.server.getServerName(), shard.getFqdn());
    }

    @Test
    public void updateShardState() {
        alice.addShard(42, false);
        cluster.forceClusterStatusUpdate().join();

        alice.addShard(42, true);
        cluster.forceClusterStatusUpdate().join();

        var shard = cluster.getShard(42);
        assertNotNull(shard);
        assertEquals(42, shard.getShardId());
        assertTrue(shard.isReady());
        assertEquals(alice.server.getServerName(), shard.getFqdn());
    }

    @Test
    public void reduceCountShardOnHost() throws InterruptedException {
        alice.addShard(1, true);
        alice.addShard(2, true);
        alice.addShard(3, true);

        cluster.forceClusterStatusUpdate().join();

        alice.removeShard(2);
        // NOTE(review): presumably ensures bob's shard gets a strictly later
        // System.currentTimeMillis() timestamp than alice's - confirm against Cluster.
        TimeUnit.MILLISECONDS.sleep(5);

        bob.addShard(2, true);

        cluster.forceClusterStatusUpdate().join();

        {
            var shard = cluster.getShard(1);
            assertNotNull(shard);
            assertEquals(1, shard.getShardId());
            assertTrue(shard.isReady());
            assertEquals(alice.server.getServerName(), shard.getFqdn());
        }

        {
            var shard = cluster.getShard(2);
            assertNotNull(shard);
            assertEquals(2, shard.getShardId());
            assertTrue(shard.isReady());
            assertEquals(bob.server.getServerName(), shard.getFqdn());
        }

        {
            var shard = cluster.getShard(3);
            assertNotNull(shard);
            assertEquals(3, shard.getShardId());
            assertTrue(shard.isReady());
            assertEquals(alice.server.getServerName(), shard.getFqdn());
        }
    }

    @Test
    public void statusErrorInvalidateShardsOnHost() throws InterruptedException {
        alice.addShard(42, true);
        cluster.forceClusterStatusUpdate().join();
        alice.predefineStatus = EStockpileStatusCode.NODE_UNAVAILABLE;

        // NOTE(review): presumably separates shard timestamps so the newer owner wins - confirm.
        TimeUnit.MILLISECONDS.sleep(5);
        bob.addShard(42, true);
        cluster.forceClusterStatusUpdate().join();

        {
            var shard = cluster.getShard(42);
            assertNotNull(shard);
            assertEquals(42, shard.getShardId());
            assertTrue(shard.isReady());
            assertEquals(bob.server.getServerName(), shard.getFqdn());
        }

        bob.predefineStatus = EStockpileStatusCode.NODE_UNAVAILABLE;
        alice.predefineStatus = EStockpileStatusCode.OK;
        TimeUnit.MILLISECONDS.sleep(5);
        cluster.forceClusterStatusUpdate().join();

        {
            var shard = cluster.getShard(42);
            assertNotNull(shard);
            assertEquals(42, shard.getShardId());
            assertTrue(shard.isReady());
            assertEquals(alice.server.getServerName(), shard.getFqdn());
        }
    }

    @Test
    public void compressFormatCompatible() throws InterruptedException {
        alice.format = Range.closed(3, 6);
        bob.format = Range.closed(5, 8);
        TimeUnit.MILLISECONDS.sleep(20);
        cluster.forceClusterStatusUpdate().join();
        Range<Integer> format = cluster.getCompatibleCompressFormat();

        // The compatible format is the intersection [3, 6] ∩ [5, 8] = [5, 6].
        assertThat(format.upperEndpoint(), equalTo(6));
        assertThat(format.lowerEndpoint(), equalTo(5));
    }

    @Test
    public void compressFormatIncompatible() throws InterruptedException {
        alice.format = Range.closed(3, 6);
        bob.format = Range.closed(7, 9);
        TimeUnit.MILLISECONDS.sleep(20);
        cluster.forceClusterStatusUpdate().join();
        Range<Integer> format = cluster.getCompatibleCompressFormat();
        // [3, 6] and [7, 9] do not overlap, so no common format exists.
        assertTrue(format.isEmpty());
    }

    /**
     * Fake stockpile node served in-process. It answers {@code serverStatus} with a configurable
     * status, supported binary-version range and set of shards.
     */
    private static class Node extends StockpileServiceGrpc.StockpileServiceImplBase {
        private GrpcServerRulePublic server;
        private final String name;
        private final ConcurrentMap<Integer, Shard> shards = new ConcurrentHashMap<>();
        private volatile EStockpileStatusCode predefineStatus = EStockpileStatusCode.OK;
        private volatile Range<Integer> format = Range.closed(0, 0);

        public Node(String name) {
            this.name = name;
        }

        void setUp() throws Throwable {
            server = new GrpcServerRulePublic();
            server.before();
            logger.info("{}: address {}", name, server.getServerName());
            server.getServiceRegistry().addService(this);
        }

        /** Marks the node unavailable and stops its server if {@link #setUp()} created one. */
        void tearDown() {
            predefineStatus = EStockpileStatusCode.NODE_UNAVAILABLE;
            if (server != null) {
                server.after();
            }
        }

        /** Adds or replaces a shard; all three readiness flags take {@code ready}. */
        public void addShard(int id, boolean ready) {
            logger.info("{} shard{id={}, ready={}}", name, id, ready);
            shards.put(id, new Shard(server.getServerName(), id, ready, ready, ready, System.currentTimeMillis()));
        }

        public void removeShard(int id) {
            logger.info("{} remove shard shard{id={}}", name, id);
            shards.remove(id);
        }

        /**
         * Reports this node's status. Shards are included only when the status is OK.
         */
        @Override
        public void serverStatus(TServerStatusRequest request, StreamObserver<TServerStatusResponse> responseObserver) {
            // Snapshot the volatile fields once. Test threads may change them concurrently, and
            // re-reading them could mix two states in one response, e.g. a non-OK status with
            // shards, or endpoints from two different ranges.
            final EStockpileStatusCode status = predefineStatus;
            final Range<Integer> supported = format;

            var response = TServerStatusResponse.newBuilder()
                .setOlderSupportBinaryVersion(supported.lowerEndpoint())
                .setLatestSupportBinaryVersion(supported.upperEndpoint())
                .setTotalShardCount(4096)
                .setStatus(status);

            if (status == EStockpileStatusCode.OK) {
                response.addAllShardStatus(shards.values().stream()
                    .map(shard -> TShardStatus.newBuilder()
                        .setReady(shard.isReady())
                        .setReadyWrite(shard.isReadyWrite())
                        .setReadyRead(shard.isReadyRead())
                        .setShardId(shard.getShardId())
                        .build())
                    .collect(Collectors.toList()));
            }

            responseObserver.onNext(response.build());
            responseObserver.onCompleted();
        }
    }
}
