package ru.yandex.stockpile.client.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.net.HostAndPort;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test helper that starts several {@link InMemoryStockpileService} instances as gRPC servers
 * inside the current JVM and spreads a shuffled range of shard ids between them.
 *
 * <p>Servers are either bound to an ephemeral port on {@code localhost} (the default) or
 * exposed through the gRPC in-process transport (see {@link Builder#inProcess()}).
 * All servers share one fixed-size executor that is released by {@link #stop()}.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class InMemoryStockpileCluster {
    private static final Logger logger = LoggerFactory.getLogger(InMemoryStockpileCluster.class);

    private final ExecutorService executorService;
    private final ConcurrentMap<HostAndPort, Server> addressToServer;
    private final int[] shards;
    private final Map<HostAndPort, InMemoryStockpileService> addressToService;
    private final boolean portBinding;

    /**
     * Starts {@code builder.countServer} servers and assigns each of them a slice of the shuffled shard ids.
     *
     * @throws IOException if one of the servers fails to start; servers started before the failure
     *                     and the shared executor are shut down before the exception is propagated
     */
    private InMemoryStockpileCluster(Builder builder) throws IOException {
        this.portBinding = builder.portBinding;

        // Split shards before allocating any resources, so a bad configuration cannot leak them.
        this.shards = IntStream.rangeClosed(builder.firstShardId, builder.lastShardId).toArray();
        shuffle(shards);
        int[][] splittedShards = splitShard(shards, builder.countServer);

        this.executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(),
                new DefaultThreadFactory("stockpile-server-grpc")
        );
        this.addressToServer = new ConcurrentHashMap<>();
        this.addressToService = new HashMap<>();

        try {
            for (int serverNum = 0; serverNum < builder.countServer; serverNum++) {
                InMemoryStockpileService service = new InMemoryStockpileService(splittedShards[serverNum], shards.length, builder.writeToDevNull);
                final String host = portBinding
                        ? "localhost"
                        : "in-memory-stockpile-server-" + serverNum;

                // Port 0 asks for an ephemeral port; the value is ignored by the in-process transport.
                final Server server = startServer(host, 0, service);
                final int port = server.getPort();

                // A server that is not listening on a TCP port reports -1, then the address is host-only.
                final HostAndPort address = port == -1
                        ? HostAndPort.fromString(host)
                        : HostAndPort.fromParts(host, port);

                addressToServer.put(address, server);
                addressToService.put(address, service);
                logger.debug("Startup node {} with shards {}", address, splittedShards[serverNum]);
            }
        } catch (IOException | RuntimeException e) {
            // Do not leave a half-started cluster behind: release everything that was already started.
            addressToServer.values().forEach(Server::shutdownNow);
            executorService.shutdownNow();
            throw e;
        }
    }

    /**
     * @return a builder with defaults: one server, shards 1..42, port binding enabled, writes are kept
     */
    public static Builder newBuilder() {
        return new Builder();
    }

    /**
     * @return a new list with the addresses of all servers of the cluster
     */
    public List<HostAndPort> getServerList() {
        return new ArrayList<>(addressToServer.keySet());
    }

    /**
     * @return a copy of all shard ids served by the cluster, in shuffled order
     */
    public int[] getShards() {
        // Defensive copy: callers must not be able to modify the cluster's own array.
        return shards.clone();
    }

    /**
     * Finds the address of a server whose service reports that it owns the given shard.
     *
     * @param shardId shard id to look up
     * @return address of the first matching server
     * @throws IllegalArgumentException if no server of the cluster owns the shard
     */
    public HostAndPort getServerWithShard(int shardId) {
        return addressToService.entrySet().stream()
                .filter(entry -> entry.getValue().hasShard(shardId))
                .map(Map.Entry::getKey)
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("ShardId not found in cluster: " + shardId));
    }

    /**
     * Shuts the server down immediately and waits until it is terminated.
     *
     * @param address address of a server of this cluster
     * @throws IllegalArgumentException if the address does not belong to this cluster
     * @throws InterruptedException if the thread is interrupted while waiting for termination
     */
    public void forceStopServer(HostAndPort address) throws InterruptedException {
        logger.debug("Force stop node with address {}", address);
        Server server = getServer(address);
        server.shutdownNow();
        server.awaitTermination();
    }

    /**
     * Stops the server if it is still running and starts a new one on the same address
     * backed by the same service instance.
     *
     * @param address address of a server of this cluster
     * @throws IllegalArgumentException if the address does not belong to this cluster
     * @throws InterruptedException if the thread is interrupted while waiting for the old server to terminate
     * @throws IOException if the new server fails to start
     */
    public void restartServer(HostAndPort address) throws InterruptedException, IOException {
        Server server = getServer(address);
        if (!server.isTerminated()) {
            forceStopServer(address);
        }

        logger.debug("Restart node with address {}", address);
        // In-process addresses carry no port, so getPort() would throw; the value is unused in that mode.
        Server newServer = startServer(address.getHost(), address.getPortOrDefault(-1), addressToService.get(address));
        addressToServer.replace(address, newServer);
    }

    /**
     * Shuts down all servers, waits up to one second for each of them to terminate
     * and then shuts down the shared executor.
     *
     * @throws InterruptedException if the thread is interrupted while waiting for termination;
     *                              the executor is shut down in this case as well
     */
    public void stop() throws InterruptedException {
        logger.debug("Shutdown stockpile cluster: {}", addressToServer.keySet());
        try {
            addressToServer.values().forEach(Server::shutdownNow);
            for (Map.Entry<HostAndPort, Server> entry : addressToServer.entrySet()) {
                if (!entry.getValue().awaitTermination(1, TimeUnit.SECONDS)) {
                    logger.warn("Node {} was not terminated within 1 second", entry.getKey());
                }
            }
        } finally {
            // Release the executor threads even if waiting for the servers was interrupted.
            executorService.shutdownNow();
        }
    }

    /**
     * Looks up the server registered for the address.
     *
     * @throws IllegalArgumentException if the address does not belong to this cluster
     */
    private Server getServer(HostAndPort address) {
        Server server = addressToServer.get(address);
        if (server == null) {
            throw new IllegalArgumentException("Server not found in cluster: " + address);
        }
        return server;
    }

    /**
     * Builds and starts a server for the service using the transport selected for this cluster.
     *
     * @param host in-process server name; used only when port binding is disabled
     * @param port TCP port to listen on; used only when port binding is enabled
     */
    private Server startServer(String host, int port, InMemoryStockpileService service) throws IOException {
        final ServerBuilder<?> serverBuilder;
        if (portBinding) {
            serverBuilder = NettyServerBuilder.forPort(port);
        } else {
            serverBuilder = InProcessServerBuilder.forName(host);
        }

        return serverBuilder
                .addService(service)
                .executor(executorService)
                .build()
                .start();
    }

    /**
     * Shuffles the array in place by swapping every position, from the last one down,
     * with a randomly chosen position that is not greater than it.
     */
    private static void shuffle(int[] shards) {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        for (int index = shards.length; index > 1; index--) {
            swap(shards, index - 1, random.nextInt(index));
        }
    }

    /**
     * Splits shards into {@code serverCount} consecutive slices. Every server receives the remaining
     * shards divided by the number of remaining servers (rounded down), so slice sizes differ by at
     * most one and the last server takes everything that is left.
     */
    private static int[][] splitShard(int[] shards, int serverCount) {
        int[][] result = new int[serverCount][];

        int pos = 0;
        for (int index = 0; index < serverCount; index++) {
            int countNotBalancedShards = shards.length - pos;
            int countServersWithoutLoad = serverCount - index;
            // Both operands are non-negative and the divisor is at least 1,
            // so integer division already rounds down and never exceeds the remaining shards.
            int shardsPerServer = countNotBalancedShards / countServersWithoutLoad;

            int to = pos + shardsPerServer;
            result[index] = Arrays.copyOfRange(shards, pos, to);

            pos = to;
        }

        return result;
    }

    /**
     * Exchanges two elements of the array.
     */
    private static void swap(int[] shards, int from, int to) {
        int temp = shards[to];
        shards[to] = shards[from];
        shards[from] = temp;
    }

    /**
     * Collects settings for {@link InMemoryStockpileCluster}; obtain an instance via
     * {@link InMemoryStockpileCluster#newBuilder()}.
     */
    public static class Builder {
        private int countServer = 1;
        private int firstShardId = 1;
        private int lastShardId = 42;
        private boolean portBinding = true;
        private boolean writeToDevNull = false;

        private Builder() {

        }

        /**
         * @param countServer number of servers to start, by default 1
         */
        public Builder serverCount(int countServer) {
            this.countServer = countServer;
            return this;
        }

        /**
         * @param from first shard id, inclusive, by default 1
         * @param to last shard id, inclusive, by default 42
         */
        public Builder shardRange(int from, int to) {
            firstShardId = from;
            lastShardId = to;
            return this;
        }

        /**
         * Use the gRPC in-process transport instead of binding servers to local ports.
         */
        public Builder inProcess() {
            portBinding = false;
            return this;
        }

        /**
         * Enables the write-to-dev-null flag passed to every {@link InMemoryStockpileService}.
         */
        public Builder devNullWrite() {
            writeToDevNull = true;
            return this;
        }

        /**
         * Starts the cluster with the collected settings.
         *
         * @throws IOException if one of the servers fails to start
         */
        public InMemoryStockpileCluster build() throws IOException {
            return new InMemoryStockpileCluster(this);
        }
    }
}
