package ru.yandex.solomon.coremon.balancer.cluster;

import java.util.Map;
import java.util.OptionalInt;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import javax.annotation.Nonnull;

import com.google.common.net.HostAndPort;
import io.grpc.Server;
import io.grpc.netty.NettyServerBuilder;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.junit.Test;

import ru.yandex.grpc.utils.DefaultClientOptions;
import ru.yandex.monitoring.coremon.EShardState;
import ru.yandex.solomon.coremon.balancer.state.ShardIds;
import ru.yandex.solomon.coremon.balancer.state.ShardLoad;
import ru.yandex.solomon.coremon.balancer.state.ShardsLoadMap;
import ru.yandex.solomon.coremon.meta.service.MetabaseTotalShardCounter;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
 * @author Sergey Polovko
 */
public class RemoteCoremonHostTest {

    @Test
    public void receiveShardsLoadChanges() throws Exception {
        ShardLoad shard1 = new ShardLoad(1, EShardState.NEW, 10, 20, 30, 40, 0);
        ShardLoad shard2 = new ShardLoad(2, EShardState.LOADING, 11, 21, 31, 41, 0);
        ShardLoad shard3 = new ShardLoad(3, EShardState.INDEXING, 12, 22, 32, 42, 0);

        var executor = Executors.newScheduledThreadPool(4);
        var localHost = new LocalCoremonHost(new AwaitableShardLocatorImpl() {
            private final AtomicInteger id = new AtomicInteger(0);

            @Override
            public ShardsLoadMap getShardsLoad() {
                // will emulate shards changing during execution
                switch (id.incrementAndGet()) {
                    case 1: return ShardsLoadMap.copyOf(Map.of(1, shard1));
                    case 2: return ShardsLoadMap.copyOf(Map.of(2, shard2));
                    case 3: return ShardsLoadMap.copyOf(Map.of(2, shard2, 3, shard3));
                    default: return ShardsLoadMap.copyOf(Map.of(1, shard1, 2, shard2, 3, shard3));
                }
            }
        }, executor);

        try (var server = new TestServer(localHost, executor, executor)) {
            {
                server.awaitNextPing();
                assertTrue(Math.abs(System.currentTimeMillis() - server.getSeenAliveTimeMillis()) < 1000);

                CoremonHost.State state = server.getState();
                assertEquals(ShardIds.EMPTY, state.getAssignments());
                assertEquals(ShardsLoadMap.copyOf(Map.of(1, shard1)), state.getShards());
                System.out.println("1st ping OK");
            }
            {
                server.awaitNextPing();
                assertTrue(Math.abs(System.currentTimeMillis() - server.getSeenAliveTimeMillis()) < 1000);

                CoremonHost.State state = server.getState();
                assertEquals(ShardIds.EMPTY, state.getAssignments());
                assertEquals(ShardsLoadMap.copyOf(Map.of(2, shard2)), state.getShards());
                System.out.println("2nd ping OK");
            }
            {
                server.awaitNextPing();
                assertTrue(Math.abs(System.currentTimeMillis() - server.getSeenAliveTimeMillis()) < 1000);

                CoremonHost.State state = server.getState();
                assertEquals(ShardIds.EMPTY, state.getAssignments());
                assertEquals(ShardsLoadMap.copyOf(Map.of(2, shard2, 3, shard3)), state.getShards());
                System.out.println("3rd ping OK");
            }
            {
                server.awaitNextPing();
                assertTrue(Math.abs(System.currentTimeMillis() - server.getSeenAliveTimeMillis()) < 1000);

                CoremonHost.State state = server.getState();
                assertEquals(ShardIds.EMPTY, state.getAssignments());
                assertEquals(ShardsLoadMap.copyOf(Map.of(1, shard1, 2, shard2, 3, shard3)), state.getShards());
                System.out.println("4th ping OK");
            }

            System.out.println("emulate fail and wait 10s");
            server.emulateFail();
            Thread.sleep(10_000);

            // after fail we should not see host anymore
            assertTrue(Math.abs(System.currentTimeMillis() - server.getSeenAliveTimeMillis()) >= 10_000);

            // but we still can get previous state of the host
            CoremonHost.State state = server.getState();
            assertEquals(ShardIds.EMPTY, state.getAssignments());
            assertEquals(ShardsLoadMap.copyOf(Map.of(1, shard1, 2, shard2, 3, shard3)), state.getShards());
        } finally {
            executor.shutdown();
            assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS));
        }
    }

    @Test
    public void updateAssignments() throws Exception {
        var locator = new AwaitableShardLocatorImpl();
        var executor = Executors.newScheduledThreadPool(4);
        var localHost = new LocalCoremonHost(locator, executor);

        try (var server = new TestServer(localHost, executor, executor)) {
            {
                // {} --> {1, 2, 3}
                server.setAssignments(intSet(1, 2, 3));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(1, 2, 3), shardIds.getShards());
            }
            {
                // {1, 2, 3} --> {3, 4, 5}
                server.setAssignments(intSet(3, 4, 5));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(3, 4, 5), shardIds.getShards());
            }
            {
                // {3, 4, 5} + {1} - {5} --> {1, 3, 4}
                server.changeAssignments(intSet(1), intSet(5));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(1, 3, 4), shardIds.getShards());
            }
            {
                // {1, 3, 4} + {2} - {4} --> {1, 2, 3}
                server.changeAssignments(intSet(2), intSet(4));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(1, 2, 3), shardIds.getShards());
            }
        } finally {
            executor.shutdown();
            assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS));
        }
    }

    @Test
    public void dontForceAssignPreviousState() throws Exception {
        var locator = new AwaitableShardLocatorImpl() {
            private final Map<Integer, ShardLoad> loadMap = new ConcurrentHashMap<>();

            @Override
            public void setLocalShardIds(@Nonnull ShardIds ids) {
                loadMap.clear();
                loadMap.putAll(loadEntries(ids.getShards().toIntArray()));
                super.setLocalShardIds(ids);
            }

            @Override
            public ShardsLoadMap getShardsLoad() {
                return ShardsLoadMap.copyOf(loadMap);
            }
        };
        var clock = new ManualClock();
        var executor = new ManualScheduledExecutorService(1, clock);
        var localHost = new LocalCoremonHost(locator, executor);

        TestServer alice = null;
        TestServer bob = null;
        try {
            alice = new TestServer(localHost, executor, executor);
            {
                // {} --> {1, 2, 3}
                alice.setAssignments(intSet(1, 2, 3));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(1, 2, 3), shardIds.getShards());
            }
            alice.awaitNextPing();
            alice.remoteHost.stopPinging();

            bob = new TestServer(localHost, executor, executor);
            {
                // {1, 2, 3} --> {3, 4, 5}
                bob.setAssignments(intSet(3, 4, 5));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(3, 4, 5), shardIds.getShards());
            }
            bob.awaitNextPing();
            System.out.println("bob become member");
            bob.remoteHost.stopPinging();

            // alice become leader
            System.out.println("alice become leader");
            alice.remoteHost.startPinging(TestServer.leaderSeqNo++, () -> 5);
            alice.awaitNextPing();
            alice.awaitNextPing();
            {
                ShardIds shardIds = locator.getLastUpdate();
                assertNull(shardIds);
            }
            {
                // {1, 2, 3} --> {3, 4, 5}
                alice.setAssignments(intSet(3, 4, 5));
                ShardIds shardIds = locator.getLastUpdate();
                assertNotNull(shardIds);
                assertEquals(intSet(3, 4, 5), shardIds.getShards());
            }
        } finally {
            if (alice != null) {
                alice.close();
            }
            if (bob != null) {
                bob.close();
            }
            executor.shutdown();
            assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS));
        }
    }

    private static Map<Integer, ShardLoad> loadEntries(int... ids) {
        return IntStream.of(ids)
            .boxed()
            .collect(Collectors.toMap(Function.identity(), ShardLoad::inactive));
    }

    private static IntOpenHashSet intSet(int... ids) {
        return new IntOpenHashSet(ids);
    }

    /**
     * AWAITABLE BALANCER SHARD LOCATOR
     */
    private static class AwaitableShardLocatorImpl implements ShardInfoProvider {
        private final LinkedBlockingQueue<ShardIds> shardIdsChanges = new LinkedBlockingQueue<>();
        private final AtomicInteger totalShardCount = new AtomicInteger(MetabaseTotalShardCounter.SHARD_COUNT_UNKNOWN);

        @Override
        public ShardsLoadMap getShardsLoad() {
            return ShardsLoadMap.EMPTY;
        }

        @Override
        public void setLocalShardIds(@Nonnull ShardIds ids) {
            shardIdsChanges.offer(ids);
        }

        public ShardIds getLastUpdate() {
            return shardIdsChanges.poll();
        }

        @Override
        public void setTotalShardCount(int totalShardCount) {
            this.totalShardCount.set(totalShardCount);
        }

        @Override
        public OptionalInt getTotalShardCount() {
            int total = totalShardCount.get();
            if (total < 0) {
                return OptionalInt.empty();
            }
            return OptionalInt.of(total);
        }
    };

    /**
     * TEST SERVER
     */
    private static final class TestServer implements AutoCloseable {
        private final Server server;
        private final RemoteCoremonHost remoteHost;
        private static int leaderSeqNo = 12;

        TestServer(LocalCoremonHost localHost, ScheduledExecutorService timer, ExecutorService executor) throws Exception {
            this(new RemoteCoremonHostPeer(localHost), timer, executor);
        }

        TestServer(RemoteCoremonHostPeer peer, ScheduledExecutorService timer, ExecutorService executor) throws Exception {
            this.server = NettyServerBuilder.forPort(0)
                .addService(peer)
                .executor(executor)
                .build()
                .start();

            HostAndPort address = HostAndPort.fromParts("localhost", server.getPort());
            this.remoteHost = new RemoteCoremonHost(address, DefaultClientOptions.empty(), timer);
            this.remoteHost.startPinging(leaderSeqNo++, () -> 42);
        }

        @Override
        public void close() throws Exception {
            remoteHost.close();
            server.shutdown();
            assertTrue(server.awaitTermination(10, TimeUnit.SECONDS));
        }

        long getSeenAliveTimeMillis() {
            return remoteHost.getSeenAliveTimeMillis();
        }

        void awaitNextPing() {
            remoteHost.awaitNextPing();
        }

        void setAssignments(IntSet ids) {
            remoteHost.setAssignments(ids).join();
        }

        void changeAssignments(IntSet add, IntSet remove) {
            remoteHost.changeAssignments(add, remove).join();
        }

        CoremonHost.State getState() {
            return remoteHost.getState(true);
        }

        void emulateFail() throws InterruptedException {
            server.shutdownNow();
            assertTrue(server.awaitTermination(10, TimeUnit.SECONDS));
        }
    }
}
