package ru.yandex.metabase.client.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.WillClose;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Status;
import io.netty.util.concurrent.DefaultThreadFactory;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntMaps;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.cluster.discovery.ClusterDiscovery;
import ru.yandex.cluster.discovery.ClusterDiscoveryImpl;
import ru.yandex.metabase.client.MetabaseClientOptions;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.labels.shard.ShardKey;
import ru.yandex.solomon.selfmon.AvailabilityStatus;

/**
 * Client-side view of the metabase cluster.
 *
 * <p>Tracks metabase nodes through {@link ClusterDiscovery} and periodically merges the shard snapshots
 * reported by every node into a single {@link State}, which is published through an {@link AtomicReference}.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
class MetabaseCluster implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(MetabaseCluster.class);

    private static final State EMPTY_STATE = new State(Object2IntMaps.emptyMap(), Int2ObjectMaps.emptyMap(), Map.of(), 0, OptionalInt.empty(), 0);
    // Minimal fraction of (discovered + inactive) partitions out of the reported total required to be "available".
    private static final double ALL_SHARDS_DISCOVERED_FRACTION_THRESHOLD = 0.999;

    private final MetabaseClientOptions opts;
    private final ClusterDiscovery<MetabaseNode> discovery;
    // Latest merged snapshot; createdAtNanos == 0 (EMPTY_STATE) means nothing has been published yet.
    private final AtomicReference<State> state = new AtomicReference<>(EMPTY_STATE);

    @WillClose
    private final ScheduledExecutorService timer;
    private final ScheduledFuture<?> scheduledRefresh;

    MetabaseCluster(List<String> addresses, MetabaseClientOptions opts) {
        this(addresses, MetabaseNodeImpl::new, opts);
    }

    /**
     * Creates the cluster view.
     *
     * <p>The constructor starts a daemon timer and schedules {@link #updateState()} to run first after 5s,
     * then every 15s. It also calls {@link #forceUpdateClusterState()} without waiting for the returned future.
     *
     * @param addresses           initial addresses handed to cluster discovery
     * @param metabaseNodeFactory creates a node transport for each discovered address
     * @param opts                client options; the RPC executor falls back to a direct executor when absent
     */
    @VisibleForTesting
    MetabaseCluster(List<String> addresses, MetabaseNodeFactory metabaseNodeFactory, MetabaseClientOptions opts) {
        this.opts = opts;
        // TODO: use common
        ThreadFactory threadFactory = new DefaultThreadFactory("metabase-client-scheduler", true);
        this.timer = Executors.newSingleThreadScheduledExecutor(threadFactory);
        var executor = opts.getGrpcOptions().getRpcExecutor().orElseGet(MoreExecutors::directExecutor);
        var reloadInterval = TimeUnit.HOURS.toMillis(1);
        this.discovery = new ClusterDiscoveryImpl<>(
                address -> metabaseNodeFactory.create(address, opts, executor, timer),
                addresses,
                this.opts.getDiscoveryService(),
                timer,
                executor,
                reloadInterval);
        scheduledRefresh = timer.scheduleAtFixedRate(this::updateState, 5_000, 15_000, TimeUnit.MILLISECONDS);
        forceUpdateClusterState();
    }

    /**
     * Returns the latest published cluster state.
     *
     * @throws RuntimeException a gRPC {@code UNAVAILABLE} status exception in either of these cases:
     *         the cluster has not been discovered yet, or too few partitions are known
     *         (see {@link #getAvailability()})
     */
    public State getState() {
        State current = this.state.get();
        ensureMostPartitionsDiscovered(current);
        return current;
    }

    /**
     * Reports whether the shard with the given numId may accept new entries.
     *
     * <p>Unlike {@link #getState()}, this reads the current snapshot without checking discovery completeness.
     *
     * @param numId shard numeric id
     * @return {@code false} for a shard unknown to the current snapshot; otherwise {@code MetabaseShard#isAllowNew()}
     */
    boolean isAllowCreateNew(int numId) {
        var shard = state.get().shardByNumId(numId);
        return shard != null && shard.isAllowNew();
    }

    /**
     * Returns the client of the node transport registered in discovery under {@code fqdn}.
     * NOTE(review): discovery may throw for an unknown node — confirm against ClusterDiscoveryImpl.
     */
    public MetabaseNodeClient getNode(String fqdn) {
        return discovery.getTransportByNode(fqdn).getClient();
    }

    /**
     * Forces a discovery refresh, then a refresh of every known node, then rebuilds the cluster state.
     *
     * <p>Per-node failures are logged and ignored. A failure of the discovery refresh is logged as well.
     * The state is recomputed in every case, and the returned future completes normally.
     */
    public CompletableFuture<Void> forceUpdateClusterState() {
        return CompletableFuture.completedFuture(discovery)
                .thenCompose(ClusterDiscovery::forceUpdate)
                .thenCompose(ignore -> discovery.getNodes()
                        .stream()
                        .map(fqdn -> CompletableFuture.completedFuture(fqdn)
                                .thenApply(discovery::getTransportByNode)
                                .thenCompose(MetabaseNode::forceUpdate)
                                .exceptionally(t -> {
                                    logger.warn("force update metabase node {} failed", fqdn, t);
                                    return null;
                                }))
                        .collect(Collectors.collectingAndThen(Collectors.toList(), CompletableFutures::allOfVoid)))
                .handle((ignore, e) -> {
                    if (e != null) {
                        logger.error("exception while updating node states", e);
                    }
                    updateState();
                    return null;
                });
    }

    /**
     * Stops the periodic refresh, closes discovery and shuts down the internal timer.
     */
    @Override
    public void close() {
        this.scheduledRefresh.cancel(false);
        this.discovery.close();
        this.timer.shutdownNow();
    }

    /**
     * Returns the current availability of shard location data.
     *
     * @return zero availability until the first snapshot is published; otherwise the result of
     *         the partition discovery check
     */
    public AvailabilityStatus getAvailability() {
        State currentState = state.get();
        if (currentState.createdAtNanos == 0) {
            return new AvailabilityStatus(0, "shard location not initialized");
        }

        return checkMostPartitionsDiscovered(currentState);
    }

    /**
     * Compares the number of known partitions (discovered + inactive) with the total reported by the balancer.
     *
     * @return {@link AvailabilityStatus#AVAILABLE} when the known count reaches
     *         {@link #ALL_SHARDS_DISCOVERED_FRACTION_THRESHOLD} of the total. Otherwise the known/total fraction,
     *         or zero when the total is not known yet.
     */
    private static AvailabilityStatus checkMostPartitionsDiscovered(State currentState) {
        // Sum in long: adding the two int counts as int could overflow silently.
        long discoveredPartitions = (long) currentState.totalDiscoveredPartitions() + currentState.totalInactivePartitionCount;
        int totalPartitions = currentState.totalPartitionCount.orElse(-1);
        if (totalPartitions < 0) {
            return new AvailabilityStatus(0, "total shard count is not known yet");
        }
        if (discoveredPartitions < ALL_SHARDS_DISCOVERED_FRACTION_THRESHOLD * totalPartitions) {
            return new AvailabilityStatus(((double) discoveredPartitions) / totalPartitions,
                    "not all metabase shards are discovered yet (" + discoveredPartitions + '/' + totalPartitions + ')');
        }
        return AvailabilityStatus.AVAILABLE;
    }

    // We can demand that only partitions for request's shards are fully discovered
    private static void ensureMostPartitionsDiscovered(State currentState) {
        if (currentState.createdAtNanos == 0) {
            throw Status.UNAVAILABLE.withDescription("Cluster is not discovered yet").asRuntimeException();
        }
        var status = checkMostPartitionsDiscovered(currentState);
        if (Double.compare(status.getAvailability(), 1.0) < 0) {
            throw Status.UNAVAILABLE.withDescription(status.getDetails()).asRuntimeException();
        }
    }

    /**
     * Exception-safe wrapper around {@link #updateStateUnsafe()}.
     *
     * <p>This is the periodic task body. An exception escaping a {@code scheduleAtFixedRate} task suppresses
     * all further executions, so every throwable is logged here instead of propagated.
     */
    private void updateState() {
        try {
            this.updateStateUnsafe();
        } catch (Throwable t) {
            logger.error("Exception in updateState", t);
        }
    }

    /**
     * Rebuilds the merged cluster state from the current node snapshots and publishes it if anything changed.
     *
     * <p>Discovery may throw and break the update loop, so call this only through {@link #updateState()}.
     *
     * @see ru.yandex.cluster.discovery.ClusterDiscoveryImpl#getTransportByNode
     */
    private void updateStateUnsafe() {
        long startNanos = System.nanoTime();
        // Collect into an explicit ArrayList because the list is sorted in place
        // (Collectors.toList() gives no mutability guarantee).
        List<ShardsState> nodeStates = discovery.getNodes()
                .stream()
                .map(fqdn -> discovery.getTransportByNode(fqdn).getState())
                .collect(Collectors.toCollection(ArrayList::new));
        // Newest snapshots first. The total partition count is taken from the first node reporting it,
        // and duplicate shard keys keep the first entry encountered.
        nodeStates.sort(Comparator.comparingLong(ShardsState::getCreatedAt).reversed());

        State oldState = state.get();

        // totalPartitionCount (or totalShardCount previously) is singleton on balancer, which sends to all nodes with TPingRequest
        OptionalInt totalPartitionCount = nodeStates.stream()
                .filter(x -> x.getTotalPartitionCount().isPresent())
                .map(ShardsState::getTotalPartitionCount)
                .findFirst()
                .orElse(OptionalInt.empty());

        // dead nodes contribute to totalInactiveShardCount by their obsolete state
        // proper fix is to use single metabase state from leader, see SOLOMON-8659
        int totalInactivePartitionCount = nodeStates.stream()
                .mapToInt(ShardsState::getInactivePartitionCount)
                .sum();

        // A newly reported total that differs from the published one is a change on its own.
        // NOTE(review): when no node reports a total, the previously known value is not carried over.
        // If a fresh shard snapshot is published below, the total becomes empty again. Confirm this is intended.
        boolean hasDiff = totalPartitionCount.isPresent()
                && !totalPartitionCount.equals(oldState.totalPartitionCount);

        Object2IntMap<ShardKey> shardsKeyToNumId = oldState.shardsKeyToNumId;
        Int2ObjectMap<MetabaseShard> shardsByNumId = oldState.shardsByNumId;
        Map<String, List<MetabaseShard>> shardsByProject = oldState.shardsByProject;

        if (hasFreshState(nodeStates, oldState.createdAtNanos)) {
            shardsByNumId = nodeStates.stream().flatMap(ShardsState::getShards)
                    .collect(Collectors.toMap(
                            MetabaseShard::getNumId,
                            Function.identity(),
                            MetabaseShard::mergeShardInfoFromServers,
                            Int2ObjectOpenHashMap::new
                    ));
            shardsKeyToNumId = shardsByNumId.int2ObjectEntrySet().stream().collect(
                    Collectors.toMap(
                            entry -> entry.getValue().getKey(),
                            Int2ObjectMap.Entry::getIntKey,
                            (l, r) -> l,
                            Object2IntOpenHashMap::new
                    ));
            shardsByProject = shardsByNumId.values().stream().collect(Collectors.groupingBy(shard -> shard.getKey().getProject()));
            hasDiff = true;
        }

        if (!hasDiff) {
            return;
        }

        State update = new State(shardsKeyToNumId, shardsByNumId, shardsByProject, startNanos, totalPartitionCount, totalInactivePartitionCount);
        // Publish unless a refresh that started later has already published its snapshot.
        // updateAndGet re-reads and retries on contention, so a concurrent writer cannot make this update vanish.
        state.updateAndGet(prev -> prev.createdAtNanos > startNanos ? prev : update);
    }

    /**
     * Reports whether any node snapshot was created at or after {@code sinceNanos}.
     * NOTE(review): assumes ShardsState#getCreatedAt uses the same System.nanoTime() clock as
     * State#createdAtNanos. Confirm this.
     */
    private static boolean hasFreshState(List<ShardsState> nodeStates, long sinceNanos) {
        return nodeStates.stream().anyMatch(nodeState -> nodeState.getCreatedAt() >= sinceNanos);
    }

    /**
     * Merged snapshot of the shard layout across all metabase nodes.
     *
     * <p>All fields are final, and this file never mutates the contained maps after construction.
     */
    static class State {
        final Object2IntMap<ShardKey> shardsKeyToNumId;
        final Int2ObjectMap<MetabaseShard> shardsByNumId;
        // TODO change to project,cluster,service bitmap index?
        final Map<String, List<MetabaseShard>> shardsByProject;
        // System.nanoTime() at the start of the refresh that built this snapshot; 0 for the initial empty state.
        final long createdAtNanos;
        final OptionalInt totalPartitionCount;
        final int totalInactivePartitionCount;

        State(
                Object2IntMap<ShardKey> shardsKeyToNumId,
                Int2ObjectMap<MetabaseShard> shardsByNumId,
                Map<String, List<MetabaseShard>> shardsByProject,
                long now,
                OptionalInt totalPartitionCount,
                int totalInactivePartitionCount
        )
        {
            this.shardsKeyToNumId = shardsKeyToNumId;
            this.shardsByNumId = shardsByNumId;
            this.shardsByProject = shardsByProject;
            this.createdAtNanos = now;
            this.totalPartitionCount = totalPartitionCount;
            this.totalInactivePartitionCount = totalInactivePartitionCount;
        }

        /** Looks up a shard by key; returns {@code null} when the key is unknown. */
        MetabaseShard shardByShardKey(ShardKey shardKey) {
            return shardsByNumId.get(shardNumId(shardKey));
        }

        /** Looks up a shard by numeric id; returns {@code null} when the id is unknown. */
        MetabaseShard shardByNumId(int numId) {
            return shardsByNumId.get(numId);
        }

        /** Returns the numeric id for the key, or {@code -1} when the key is unknown. */
        int shardNumId(ShardKey shardKey) {
            return shardsKeyToNumId.getOrDefault(shardKey, -1);
        }

        /** Returns the shards of the given project, or an empty collection when none are known. */
        Collection<MetabaseShard> projectShards(String project) {
            var res = shardsByProject.get(project);
            return res == null ? Collections.emptyList() : res;
        }

        /** Streams every known shard. */
        Stream<MetabaseShard> allShards() {
            return shardsByNumId.values().stream();
        }

        /** Sums {@code MetabaseShard#getTotalPartitions()} over every known shard. */
        int totalDiscoveredPartitions() {
            return shardsByNumId.values().stream().mapToInt(MetabaseShard::getTotalPartitions).sum();
        }

    }
}
