package ru.yandex.stockpile.client.impl;

import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.WillCloseWhenClosed;
import javax.annotation.concurrent.ThreadSafe;

import com.google.common.collect.Range;
import com.google.common.util.concurrent.MoreExecutors;
import io.netty.util.concurrent.DefaultThreadFactory;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.cluster.discovery.ClusterDiscovery;
import ru.yandex.cluster.discovery.ClusterDiscoveryImpl;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.selfmon.AvailabilityStatus;
import ru.yandex.stockpile.api.EStockpileStatusCode;
import ru.yandex.stockpile.client.StockpileClientOptions;

/**
 * Client-side view of a Stockpile cluster: the set of discovered nodes plus an
 * aggregated snapshot of shard placement, readiness and compatible compress format.
 *
 * <p>The snapshot is rebuilt from the per-node {@code ShardsState} once per second on an
 * internal single-threaded scheduler, and additionally on demand through
 * {@link #forceClusterStatusUpdate()} and {@link #shardError(int, EStockpileStatusCode)}.
 * Readers always observe one complete snapshot, published through an {@link AtomicReference}.
 *
 * @author Vladimir Gordiychuk
 */
@SuppressWarnings("all")
@ThreadSafe
@ParametersAreNonnullByDefault
final class Cluster implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(Cluster.class);

    /** Scheduler handed to discovery and to every node, also drives the periodic refresh; owned by this instance. */
    @WillCloseWhenClosed
    private final ScheduledExecutorService timer;
    private final StockpileClientOptions opts;
    /** Latest aggregated snapshot; replaced as a whole, never mutated in place. */
    private final AtomicReference<State> state = new AtomicReference<>(State.init());
    private final ClusterDiscovery<StockpileNode> discovery;
    /** Handle of the once-per-second refresh task, cancelled by {@link #close()}. */
    private final ScheduledFuture<?> scheduledRefresh;

    /**
     * Creates the cluster view, starts node discovery and schedules the periodic state refresh.
     *
     * @param addresses initial addresses handed to cluster discovery
     * @param opts      client options (RPC executor, discovery service, metadata expiration)
     */
    Cluster(List<String> addresses, StockpileClientOptions opts) {
        this.timer = Executors.newSingleThreadScheduledExecutor(new DefaultThreadFactory("stockpile-client-scheduler"));
        var executor = opts.getGrpcOptions().getRpcExecutor().orElseGet(MoreExecutors::directExecutor);
        var reloadInterval = TimeUnit.HOURS.toMillis(1);
        this.discovery = new ClusterDiscoveryImpl<StockpileNode>(
                address -> new StockpileNode(address, opts, timer),
                addresses,
                opts.getDiscoveryService(),
                timer,
                executor,
                reloadInterval);
        this.opts = opts;
        // Run through a guard: an exception escaping a fixed-rate task suppresses all
        // subsequent executions, which would silently freeze the snapshot forever.
        this.scheduledRefresh = timer.scheduleAtFixedRate(this::refreshStateSafely, 0, 1, TimeUnit.SECONDS);
    }

    /**
     * Returns the shard with the given id from the latest snapshot.
     *
     * @param shardId shard id
     * @return the shard, or {@code null} if no node has reported it yet
     */
    @Nullable
    Shard getShard(int shardId) {
        return state.get().shards.get(shardId);
    }

    /**
     * Returns the RPC client of the node with the given fqdn.
     *
     * @param fqdn node name as known to discovery
     * @return the node's client
     */
    NodeClient getClient(String fqdn) {
        return node(fqdn).getClient();
    }

    /** @return number of shards counted as ready in the latest snapshot */
    int getReadyShardsCount() {
        return state.get().readyCount;
    }

    /** @return total number of shards in the cluster according to the latest snapshot */
    int getTotalShardsCount() {
        return state.get().totalShards;
    }

    /** @return compress-format range supported by all nodes; may be an empty range */
    Range<Integer> getCompatibleCompressFormat() {
        return state.get().format;
    }

    /**
     * Forces a discovery update, then a state update on every discovered node, and finally
     * rebuilds the aggregated snapshot.
     *
     * @return future completed once the snapshot has been rebuilt
     */
    CompletableFuture<Void> forceClusterStatusUpdate() {
        return discovery.forceUpdate()
                .thenCompose(ignore -> nodes()
                        .map(StockpileNode::forceUpdate)
                        .collect(Collectors.collectingAndThen(Collectors.toList(), CompletableFutures::allOfVoid)))
                .thenAccept(ignore -> updateState());
    }

    /** @return transports of all currently discovered nodes */
    Stream<StockpileNode> nodes() {
        return discovery.getNodes()
                .stream()
                .map(discovery::getTransportByNode);
    }

    /**
     * @param fqdn node name as known to discovery
     * @return transport of that node
     */
    StockpileNode node(String fqdn) {
        return discovery.getTransportByNode(fqdn);
    }

    /**
     * Reacts to an error reported for a shard by refreshing the relevant part of the cluster state.
     *
     * <ul>
     *     <li>shard id {@code 0}: full cluster refresh;</li>
     *     <li>{@code NODE_UNAVAILABLE} / {@code SHARD_NOT_READY}: refresh only the node hosting the
     *     shard, or the whole cluster when the shard is unknown;</li>
     *     <li>{@code SHARD_ABSENT_ON_HOST} / {@code NOTE_ENOUGH_READY_SHARDS}: full cluster refresh;</li>
     *     <li>any other code: nothing to do.</li>
     * </ul>
     *
     * @param shardId id of the failed shard
     * @param code    status returned for it
     * @return future completed when the triggered refresh (if any) is done
     */
    CompletableFuture<Void> shardError(int shardId, EStockpileStatusCode code) {
        if (shardId == 0) {
            return forceClusterStatusUpdate();
        }

        switch (code) {
            case NODE_UNAVAILABLE:
            case SHARD_NOT_READY:
                Shard shard = getShard(shardId);
                if (shard == null) {
                    return forceClusterStatusUpdate();
                } else {
                    return node(shard.getFqdn())
                            .forceUpdate()
                            .thenAccept(ignore -> updateState());
                }
            case SHARD_ABSENT_ON_HOST:
            case NOTE_ENOUGH_READY_SHARDS:
                return forceClusterStatusUpdate();
            default:
                return CompletableFuture.completedFuture(null);
        }
    }

    /**
     * Stops the periodic refresh, closes discovery and shuts down the internal scheduler.
     * The scheduler is shut down even if closing discovery fails.
     */
    @Override
    public void close() {
        scheduledRefresh.cancel(false);
        try {
            discovery.close();
        } finally {
            timer.shutdownNow();
        }
    }

    /** @return comma-separated, sorted list of discovered nodes */
    @Override
    public String toString() {
        return discovery.getNodes()
            .stream()
            .sorted()
            .collect(Collectors.joining(", "));
    }

    /**
     * Reports availability as the fraction of ready shards.
     *
     * <p>Returns {@link AvailabilityStatus#AVAILABLE} when every shard is ready. Note this also
     * holds for the initial 0/0 snapshot that exists before the first refresh completes.
     *
     * @return availability status with a "ready/total" description when degraded
     */
    public AvailabilityStatus getAvailability() {
        var clusterState = state.get();
        int ready = clusterState.readyCount;
        if (ready == clusterState.totalShards) {
            return AvailabilityStatus.AVAILABLE;
        }

        String details = "stockpile: ready shards " + ready + "/" + clusterState.totalShards;
        double availability = (double) ready / (double) clusterState.totalShards;
        return new AvailabilityStatus(availability, details);
    }

    /**
     * Body of the periodic refresh task: delegates to {@link #updateState()} and logs any
     * failure instead of letting it escape and cancel the fixed-rate schedule.
     */
    private void refreshStateSafely() {
        try {
            updateState();
        } catch (Exception e) {
            logger.error("failed to refresh stockpile cluster state", e);
        }
    }

    /**
     * Builds a fresh snapshot and publishes it, unless a snapshot created later (by a
     * concurrent refresh) has already been published.
     */
    private void updateState() {
        State prev;
        State update = prepareNextState();
        do {
            prev = state.get();
            if (prev.createdAtNanos > update.createdAtNanos) {
                // a newer snapshot won the race; keep it
                return;
            }
        } while (!state.compareAndSet(prev, update));
    }

    /**
     * Aggregates the current per-node states into a new snapshot.
     *
     * <p>Node states are applied oldest-first (by {@code getCreatedAt}), so when several nodes
     * report the same shard the most recently created report wins. The total shard count is the
     * maximum reported by any node, falling back to 4096 when none reports it. A shard counts as
     * ready only if it is ready and its report is not older than
     * {@code max(2 * metadataExpireMs, 15s)}.
     *
     * @return the new snapshot, stamped with {@link System#nanoTime()} taken before aggregation
     */
    private State prepareNextState() {
        long createdAtNanos = System.nanoTime();
        var nodeStates = nodes()
            .map(StockpileNode::getState)
            .sorted(Comparator.comparingLong(ShardsState::getCreatedAt))
            .collect(Collectors.toList());

        Range<Integer> format = compatibleFormat(nodeStates);
        Int2ObjectMap<Shard> shards = new Int2ObjectOpenHashMap<>(state.get().shards.size());
        int totalShards = 0;
        for (var nodeState : nodeStates) {
            totalShards = Math.max(totalShards, nodeState.getTotalShards());
            for (var shard : nodeState.getShards()) {
                // later (newer) node states overwrite earlier ones
                shards.put(shard.getShardId(), shard);
            }
        }
        // TODO: remove it after upgrade all stockpile clusters
        if (totalShards == 0) {
            totalShards = 4096;
        }

        long nowMillis = System.currentTimeMillis();
        long expiration = Math.max(opts.getMetadataExpireMs() * 2, TimeUnit.SECONDS.toMillis(15));
        int readyCount = shards.values().stream()
            .mapToInt(value -> {
                if (!value.isReady()) {
                    return 0;
                }

                long old = nowMillis - value.getCreatedAt();
                if (old > expiration) {
                    // stale report: do not trust its readiness
                    return 0;
                }

                return 1;
            })
            .sum();
        return new State(shards, totalShards, readyCount, format, createdAtNanos);
    }

    /**
     * Intersects the non-empty compress-format ranges reported by the nodes.
     *
     * @param nodeStates per-node states
     * @return the common range; the empty range {@code (0, 0]} when no node reports a format or
     *         two ranges are not connected; possibly another empty range when connected ranges
     *         only touch at a boundary
     */
    private Range<Integer> compatibleFormat(List<ShardsState> nodeStates) {
        Range<Integer> format = null;
        for (var nodeState : nodeStates) {
            var nodeFormat = nodeState.getCompressFormat();
            if (nodeFormat.isEmpty()) {
                continue;
            }

            if (format == null) {
                format = nodeFormat;
                continue;
            }

            if (!format.isConnected(nodeFormat)) {
                format = null;
                break;
            }

            format = format.intersection(nodeFormat);
        }

        if (format == null) {
            return Range.openClosed(0, 0);
        }

        return format;
    }

    /** Immutable aggregated snapshot of the cluster state. */
    private static final class State {
        private final Int2ObjectMap<Shard> shards;
        private final int totalShards;
        private final Range<Integer> format;
        private final int readyCount;
        /** {@link System#nanoTime()} at which this snapshot started being built; orders snapshots. */
        private final long createdAtNanos;

        State(Int2ObjectMap<Shard> shards, int totalShards, int readyCount, Range<Integer> format, long now) {
            this.shards = shards;
            this.totalShards = totalShards;
            this.readyCount = readyCount;
            this.format = format;
            this.createdAtNanos = now;
        }

        /** @return the placeholder snapshot used before the first refresh: no shards, 0/0 ready */
        public static State init() {
            return new State(Int2ObjectMaps.emptyMap(), 0, 0, Range.closed(0, 0), System.nanoTime());
        }
    }
}
