package ru.yandex.solomon.balancer;

import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.DoubleSummaryStatistics;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import io.grpc.Status;
import it.unimi.dsi.fastutil.IndirectPriorityQueue;
import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectHeapIndirectPriorityQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.ActorWithFutureRunner;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.balancer.dao.BalancerDao;
import ru.yandex.solomon.balancer.remote.RemoteCluster;
import ru.yandex.solomon.balancer.remote.RemoteNodeState;
import ru.yandex.solomon.balancer.remote.RemoteShardState;
import ru.yandex.solomon.balancer.snapshot.SnapshotAssignments;
import ru.yandex.solomon.balancer.snapshot.SnapshotNode;
import ru.yandex.solomon.balancer.snapshot.SnapshotShard;
import ru.yandex.solomon.util.collection.queue.ArrayListLockQueue;
import ru.yandex.solomon.util.time.DurationUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static ru.yandex.solomon.balancer.ClusterState.findBestNode;

/**
 * @author Vladimir Gordiychuk
 */
public class BalancerImpl implements Balancer {
    private static final Logger logger = LoggerFactory.getLogger(BalancerImpl.class);
    // Offset from construction time used to schedule the first snapshot flush (see the constructor).
    private static final long SNAPSHOT_FLUSH_INTERVAL_MILLIS = TimeUnit.MINUTES.toMillis(5);
    // An assignment younger than this is not re-sent to a node whose uptime keeps growing
    // but which does not report the shard yet (see Iteration#processStateFromNode).
    private static final long ASSIGNMENT_REMIND_DELAY_MILLIS = TimeUnit.SECONDS.toMillis(15);

    private final Clock clock;
    // Executor used to complete futures returned from runAsync.
    private final Executor executor;
    private final List<Resource> resources;
    private final RemoteCluster cluster;
    private final ShardsHolder shardsHolder;
    private final TotalShardCounter shardCounter;
    private final BalancerDao dao;
    // Runs act(); scheduled once per second by periodicAct and on demand by run/runAsync/close.
    private final ActorWithFutureRunner actor;
    private final ScheduledFuture<?> periodicAct;

    // Tasks submitted through run/runAsync; drained by Iteration#processRunnables.
    private final ArrayListLockQueue<Supplier<CompletableFuture<?>>> tasksQueue = new ArrayListLockQueue<>();
    // Known shards keyed by shard id.
    private final Map<String, ShardSummary> shards = new ConcurrentHashMap<>();
    // Known nodes keyed by node address.
    private final Map<String, NodeSummary> nodes = new ConcurrentHashMap<>();
    private volatile BalancerOptions opts;
    // Per-resource maximum, recomputed by updateMaximum().
    private final Resources maximum = new Resources();
    // Becomes true once Iteration#init has completed successfully.
    private boolean init = false;
    private volatile AssignmentSeqNo latestSeqNo;
    // Set to 0 by forceFlushSnapshot(); the consumer of this value is not visible in this part of the file.
    private long scheduledSnapshotFlushAt;
    private volatile boolean autoFreeze;
    private volatile String autoFreezeReason;
    // null while no rebalance is in progress.
    // NOTE(review): read by getRebalanceProgress() without going through the task queue and
    // the field is not volatile - confirm visibility is acceptable for callers on other threads.
    @Nullable
    private RebalanceProcess rebalanceProcess;
    // Set by close(); the next act() invocation releases all state.
    private volatile boolean closed;

    public BalancerImpl(
        Clock clock,
        long leaderSeqNo,
        List<Resource> resources,
        RemoteCluster cluster,
        ScheduledExecutorService timer,
        ExecutorService executorService,
        ShardsHolder shardsHolder,
        TotalShardCounter shardCounter,
        BalancerDao dao)
    {
        this.clock = clock;
        this.executor = executorService;
        this.resources = resources;
        this.cluster = cluster;
        this.latestSeqNo = new AssignmentSeqNo(leaderSeqNo, 0);
        this.opts = BalancerOptions.newBuilder().build();
        this.shardsHolder = shardsHolder;
        this.shardCounter = shardCounter;
        this.dao = dao;
        this.scheduledSnapshotFlushAt = clock.millis() + SNAPSHOT_FLUSH_INTERVAL_MILLIS;
        this.actor = new ActorWithFutureRunner(this::act, executorService);
        this.periodicAct = timer.scheduleAtFixedRate(actor::schedule, 0, 1_000, TimeUnit.MILLISECONDS);
    }

    /**
     * Actor body: runs one balancer iteration, or releases all state once the balancer is closed.
     *
     * @return future of the iteration; iteration failures are logged and converted to a normal completion
     */
    private CompletableFuture<?> act() {
        if (!closed) {
            logger.debug("start act");
            return new Iteration().run()
                .exceptionally(e -> {
                    logger.error("balancer iteration failed", e);
                    return null;
                });
        }

        // Closed: drop counters, nodes and shards, fail pending assignment waits, stop the timer.
        shardCounter.reset();
        nodes.values().forEach(NodeSummary::close);
        var notLeader = Status.UNAVAILABLE.withDescription("Not leader anymore").asRuntimeException();
        for (var shard : shards.values()) {
            shard.cancelAssignmentWaits(notLeader);
        }
        nodes.clear();
        shards.clear();
        periodicAct.cancel(false);
        return completedFuture(null);
    }

    /**
     * Recomputes the per-resource maximum: the configured limit when it is non-zero,
     * otherwise the value derived by the resource from current shards and nodes.
     */
    private void updateMaximum() {
        var limits = opts.getLimits();
        for (var resource : resources) {
            double configured = limits.get(resource);
            double effective = configured == 0.0
                ? resource.maximum(shards.values(), nodes.values())
                : configured;
            maximum.set(resource, effective);
        }
    }

    /** Resets the scheduled snapshot flush time to zero. */
    private void forceFlushSnapshot() {
        this.scheduledSnapshotFlushAt = 0L;
    }

    /**
     * Advances the assignment part of the latest sequence number by one.
     *
     * @return the newly issued sequence number
     */
    private AssignmentSeqNo nextSeqNo() {
        var current = latestSeqNo;
        var next = new AssignmentSeqNo(current.getLeaderSeqNo(), current.getAssignSeqNo() + 1);
        latestSeqNo = next;
        return next;
    }

    /**
     * Records the shard as assigned to the node under a fresh sequence number and sends the
     * assign command to the node.
     * <p>
     * The outcome of the remote call is applied later through {@link #run(Runnable)}.
     * The returned future itself never fails. Does nothing when the balancer is closed.
     *
     * @param shard shard to assign
     * @param node  target node
     * @return future completed once the remote call has finished, successfully or not
     */
    private CompletableFuture<?> assignShardToNode(ShardSummary shard, NodeSummary node) {
        if (closed) {
            return completedFuture(null);
        }

        var seqNo = nextSeqNo();
        logger.info("Assign shard {} to node {} seqNo {}", shard.getShardId(), node.getAddress(), seqNo);
        node.assign(shard, seqNo);
        forceFlushSnapshot();
        long expiredAt = clock.millis() + opts.getAssignExpirationMillis();

        // Starting from a completed future turns a synchronous throw of the remote call into a failed future.
        return completedFuture(null)
            .thenCompose(ignore -> node.getRemote().assignShard(shard.getShardId(), seqNo, expiredAt))
            .handle((response, e) -> {
                if (e == null) {
                    logger.info("Assignment confirmed: {} {} on {}", shard.getShardId(), shard.getAssignmentSeqNo(), node.getAddress());
                    run(() -> node.successAssign(shard, clock.millis()));
                    return null;
                }

                logger.error("Assign shard {} {} on node {} failed", shard.getShardId(), shard.getAssignmentSeqNo(), node.getAddress(), e);
                run(() -> {
                    nextSeqNo();
                    node.failedAssign(shard, isAutoFreeze());
                });
                return null;
            });
    }

    /** Non-graceful unassign with the deadline taken from the force-unassign expiration option. */
    private CompletableFuture<?> forceUnassign(ShardSummary shard, NodeSummary node) {
        long deadline = clock.millis() + opts.getForceUnassignExpirationMillis();
        return unassign(shard, node, false, deadline);
    }

    /** Graceful unassign with the deadline taken from the graceful-unassign expiration option. */
    private CompletableFuture<?> gracefulUnassign(ShardSummary shard, NodeSummary node) {
        long deadline = clock.millis() + opts.getGracefulUnassignExpirationMillis();
        return unassign(shard, node, true, deadline);
    }

    /**
     * Detaches the shard from the node locally (only when the node is its current owner) and
     * sends the unassign command to the node. Does nothing when the balancer is closed.
     *
     * @param shard     shard to unassign
     * @param node      node that receives the command
     * @param graceful  whether the node is asked for a graceful unload
     * @param expiredAt deadline passed to the remote call, epoch millis
     * @return future of the remote call; it never completes exceptionally
     */
    private CompletableFuture<?> unassign(ShardSummary shard, NodeSummary node, boolean graceful, long expiredAt) {
        if (closed) {
            return completedFuture(null);
        }

        boolean ownedByNode = shard.getNode() == node;
        if (ownedByNode) {
            node.unassign(shard);
        }

        var shardId = shard.getShardId();
        return unassign(shardId, node, graceful, expiredAt);
    }

    /**
     * Sends the unassign command for the shard to the node under a fresh sequence number.
     * <p>
     * A failed request is only logged, so the returned future never completes exceptionally.
     *
     * @param shardId   id of the shard to unassign
     * @param node      node that receives the command
     * @param graceful  whether the node is asked for a graceful unload
     * @param expiredAt deadline passed to the remote call, epoch millis
     */
    private CompletableFuture<?> unassign(String shardId, NodeSummary node, boolean graceful, long expiredAt) {
        forceFlushSnapshot();
        // This is a friendly reminder to the previous owner that its ownership has expired.
        // The request is allowed to fail: after the reminder attempt the shard is
        // force-reassigned anyway.
        var seqNo = nextSeqNo();
        var request = CompletableFutures.safeCall(() -> node.getRemote().unassignShard(shardId, seqNo, graceful, expiredAt));
        return request.exceptionally((e) -> {
            logger.error("unassign shard {} from node {} graceful={} failed", shardId, node.getAddress(), graceful, e);
            return null;
        });
    }

    /** @return current value of the auto-freeze flag */
    private boolean isAutoFreeze() {
        return this.autoFreeze;
    }

    /**
     * @return the recorded auto-freeze reason while auto-freeze is on, otherwise {@code null}
     */
    @Nullable
    @Override
    public String getAutoFreezeReason() {
        return autoFreeze ? autoFreezeReason : null;
    }

    /**
     * Relative spread of usage between nodes: {@code (max - min) / max}.
     * <p>
     * Only nodes that are active, connected, up for at least five minutes and have a failed
     * command ratio below 0.5 take part in the calculation.
     *
     * @return the spread, or {@code 0} when no node qualifies or when the highest usage is zero
     *         (the ratio would otherwise be {@code NaN})
     */
    @Override
    public double getDispersion() {
        DoubleSummaryStatistics stat = nodes.values()
            .stream()
            .filter(node -> {
                if (!node.isActive()) {
                    return false;
                }

                if (node.getStatus() != NodeStatus.CONNECTED) {
                    return false;
                }

                if (node.getUptimeMillis() < TimeUnit.MINUTES.toMillis(5L)) {
                    return false;
                }

                return Double.compare(node.getFailCommandPercent(), 0.5) < 0;
            })
            .mapToDouble(node -> node.getUsage(maximum))
            .summaryStatistics();

        if (stat.getCount() == 0) {
            return 0;
        }

        double max = stat.getMax();
        if (max == 0.0) {
            // All qualifying nodes report zero usage: 0.0 / 0.0 would produce NaN.
            return 0;
        }

        return (max - stat.getMin()) / max;
    }

    /** @return the live map of nodes keyed by address (not a copy) */
    @Override
    public Map<String, NodeSummary> getNodes() {
        return this.nodes;
    }

    /** @return the most recently issued assignment sequence number */
    @Override
    public AssignmentSeqNo getLatestSeqNo() {
        return this.latestSeqNo;
    }

    /** @return the live map of shards keyed by shard id (not a copy) */
    @Override
    public Map<String, ShardSummary> getShards() {
        return this.shards;
    }

    /**
     * Marks the balancer as closed and wakes the actor up, so that the next {@code act()}
     * call releases nodes and shards and cancels the periodic timer.
     */
    @Override
    public void close() {
        this.closed = true;
        // The cleanup itself happens inside the actor, not on the calling thread.
        this.actor.schedule();
    }

    /**
     * Registers the shard in the shards holder, then waits inside the actor for its assignment.
     *
     * @param shardId id of the shard
     * @return future completed with the outcome of {@code ShardSummary.waitAssignment()}
     */
    @Override
    public CompletableFuture<String> getOrCreateAssignment(String shardId) {
        return shardsHolder.add(shardId).thenCompose(added -> {
            CompletableFuture<String> assignment = new CompletableFuture<>();
            run(() -> {
                ShardSummary shard = shards.computeIfAbsent(shardId, ShardSummary::new);
                CompletableFutures.whenComplete(shard.waitAssignment(), assignment);
            });
            return assignment;
        });
    }

    /**
     * Looks up the current owner of the shard inside the actor.
     *
     * @param shardId id of the shard
     * @return future with the owner address, or {@code null} when the shard is unknown or unassigned
     */
    @Override
    public CompletableFuture<String> getAssignment(String shardId) {
        return supplyAsync(() -> {
            var shard = shards.get(shardId);
            if (shard == null) {
                return null;
            }

            var owner = shard.getNode();
            if (owner == null) {
                return null;
            }

            return owner.getAddress();
        });
    }

    /**
     * Gracefully unassigns the shard from its current node; a no-op for an unknown or
     * unassigned shard.
     *
     * @param shardId id of the shard to kick
     */
    @Override
    public CompletableFuture<?> kickShard(String shardId) {
        return runAsync(() -> {
            var shard = shards.get(shardId);
            var owner = shard == null ? null : shard.getNode();
            if (owner == null) {
                return completedFuture(null);
            }

            logger.info("Graceful unassign shard {} from {} caused by kick", shard.getShardId(), owner.getAddress());
            return gracefulUnassign(shard, owner);
        });
    }

    /**
     * Gracefully unassigns every shard currently held by the node; a no-op for an unknown
     * or frozen node.
     *
     * @param address address of the node to kick
     */
    @Override
    public CompletableFuture<?> kickNode(String address) {
        return runAsync(() -> {
            var node = nodes.get(address);
            boolean kickable = node != null && !node.isFreeze();
            if (!kickable) {
                return completedFuture(null);
            }

            // Work on a copy: unassigning modifies the shard set of the node.
            var assigned = new HashMap<>(node.getShards());
            logger.info("Graceful unassign shards {} from node {} caused by kick", assigned.size(), node);
            return assigned.values()
                .stream()
                .map(shard -> {
                    logger.info("Graceful unassign shard {} from {} caused by kick", shard.getShardId(), node.getAddress());
                    return gracefulUnassign(shard, node);
                })
                .collect(collectingAndThen(toList(), CompletableFutures::allOfUnit));
        });
    }

    /**
     * Changes the active flag of a single node; a no-op when the node is unknown.
     *
     * @param address address of the node
     * @param flag    new value of the active flag
     */
    @Override
    public CompletableFuture<?> setActive(String address, boolean flag) {
        return runAsync(() -> {
            var node = nodes.get(address);
            if (node != null) {
                node.setActive(flag);
                forceFlushSnapshot();
            }
            return completedFuture(null);
        });
    }

    /**
     * Changes the active flag of every known node.
     *
     * @param flag new value of the active flag
     */
    @Override
    public CompletableFuture<?> setActive(boolean flag) {
        return runAsync(() -> {
            nodes.values().forEach(node -> node.setActive(flag));
            forceFlushSnapshot();
            return completedFuture(null);
        });
    }

    /**
     * Changes the freeze flag of every known node.
     *
     * @param flag new value of the freeze flag
     */
    @Override
    public CompletableFuture<?> setFreeze(boolean flag) {
        return runAsync(() -> {
            nodes.values().forEach(node -> node.setFreeze(flag));
            forceFlushSnapshot();
            return completedFuture(null);
        });
    }

    /**
     * Changes the freeze flag of a single node; a no-op when the node is unknown.
     *
     * @param address address of the node
     * @param value   new value of the freeze flag
     */
    @Override
    public CompletableFuture<?> setFreeze(String address, boolean value) {
        return runAsync(() -> {
            var node = nodes.get(address);
            if (node != null) {
                node.setFreeze(value);
                forceFlushSnapshot();
            }
            return completedFuture(null);
        });
    }

    /**
     * Persists the options and, once saved, makes them the current options inside the actor.
     *
     * @param opts new balancer options
     */
    @Override
    public CompletableFuture<?> setOptions(BalancerOptions opts) {
        Supplier<CompletableFuture<Void>> applyOptions = () -> {
            this.opts = opts;
            return completedFuture(null);
        };
        return dao.saveOptions(opts)
            .thenCompose(saved -> runAsync(applyOptions));
    }

    /** @return the options currently in effect */
    @Override
    public BalancerOptions getOptions() {
        return this.opts;
    }

    /** @return the live per-resource maximum object (not a copy) */
    @Override
    public Resources getResourceMaximum() {
        return this.maximum;
    }

    /** @return the resource list passed to the constructor */
    @Override
    public List<Resource> getResources() {
        return this.resources;
    }

    /** Starts a rebalance process unless one is already in progress. */
    @Override
    public CompletableFuture<?> rebalance() {
        return runAsync(() -> {
            boolean idle = rebalanceProcess == null;
            if (idle) {
                rebalanceProcess = newRebalanceProcess();
            }
            return completedFuture(null);
        });
    }

    /** Drops the current rebalance process, if any. */
    @Override
    public CompletableFuture<?> cancelRebalance() {
        return runAsync(() -> {
            this.rebalanceProcess = null;
            return completedFuture(null);
        });
    }

    /**
     * @return progress reported by the running rebalance process, or {@code 1} when none is running
     */
    @Override
    public double getRebalanceProgress() {
        // Read the field once: it may be reset to null between the check and the call.
        RebalanceProcess process = rebalanceProcess;
        return process == null ? 1 : process.getProgress();
    }

    /**
     * Queues the task for execution inside the actor on its next iteration and wakes the actor up.
     * <p>
     * An exception thrown by the task is returned as a failed future instead of propagating
     * out of the queue-draining loop. Otherwise one throwing task would prevent every task
     * dequeued after it in the same batch from being started.
     *
     * @param task task to execute inside the actor
     */
    @Override
    public void run(Runnable task) {
        tasksQueue.enqueue(() -> {
            try {
                task.run();
                return completedFuture(null);
            } catch (Throwable e) {
                return failedFuture(e);
            }
        });
        actor.schedule();
    }

    /**
     * Queues an asynchronous task for execution inside the actor and wakes the actor up.
     * <p>
     * The returned future is completed on {@code executor} with the outcome of the future
     * produced by the supplier. A supplier that throws, or that returns {@code null}, results
     * in an exceptionally completed future.
     *
     * @param supplier produces the future to wait for; invoked inside the actor
     * @param <T>      result type
     * @return future mirroring the outcome of the supplied future
     */
    @Override
    public <T> CompletableFuture<T> runAsync(Supplier<CompletableFuture<T>> supplier) {
        CompletableFuture<T> doneFuture = new CompletableFuture<>();
        tasksQueue.enqueue(() -> {
            CompletableFuture<T> result;
            try {
                result = supplier.get();
                if (result == null) {
                    // A null future would otherwise raise NPE below: the caller's future would
                    // never complete and the rest of the dequeued batch would be skipped.
                    result = failedFuture(new NullPointerException("task returned null instead of a future"));
                }
            } catch (Throwable e) {
                result = failedFuture(e);
            }

            result.whenCompleteAsync((r, e) -> {
                if (e != null) {
                    doneFuture.completeExceptionally(e);
                } else {
                    doneFuture.complete(r);
                }
            }, executor);
            return doneFuture;
        });
        actor.schedule();
        return doneFuture;
    }

    /**
     * Evaluates the supplier inside the actor and exposes its value as a future.
     *
     * @param supplier computation to run inside the actor
     * @param <T>      result type
     */
    @Override
    public <T> CompletableFuture<T> supplyAsync(Supplier<T> supplier) {
        return runAsync(() -> {
            T value = supplier.get();
            return completedFuture(value);
        });
    }

    /** Copies address, active and freeze flags of every known node into snapshot records. */
    private List<SnapshotNode> prepareSnapshotNodes() {
        List<SnapshotNode> snapshot = new ArrayList<>();
        for (var node : nodes.values()) {
            var record = new SnapshotNode();
            record.address = node.getAddress();
            record.active = node.isActive();
            record.freeze = node.isFreeze();
            snapshot.add(record);
        }
        return snapshot;
    }

    /**
     * Copies id, sequence number, owner address, status and resources of every known shard
     * into snapshot records. An unassigned shard gets a {@code null} node.
     */
    private List<SnapshotShard> prepareSnapshotShards() {
        List<SnapshotShard> snapshot = new ArrayList<>();
        for (var shard : shards.values()) {
            var owner = shard.getNode();

            var record = new SnapshotShard();
            record.shardId = shard.getShardId();
            record.assignmentSeqNo = shard.getAssignmentSeqNo();
            if (owner != null) {
                record.node = owner.getAddress();
            } else {
                record.node = null;
            }
            record.status = shard.getStatus();
            // Resources are copied so the snapshot does not share them with the live shard.
            record.resources = new Resources(shard.getResources());
            snapshot.add(record);
        }
        return snapshot;
    }

    /** Builds a snapshot of all nodes and shards known at the moment of the call. */
    private SnapshotAssignments prepareAssignmentSnapshot() {
        var snapshotNodes = prepareSnapshotNodes();
        var snapshotShards = prepareSnapshotShards();
        return new SnapshotAssignments(snapshotNodes, snapshotShards);
    }

    /**
     * Counts shards whose status is not {@code READY}, and among them those whose uptime
     * exceeds five minutes.
     */
    private NotReadyShards getNotReadyShards() {
        final long longLoadingThresholdMillis = TimeUnit.MINUTES.toMillis(5);

        int notReady = 0;
        int longLoading = 0;
        for (var shard : this.shards.values()) {
            if (shard.getStatus() == ShardStatus.READY) {
                continue;
            }

            notReady++;
            if (shard.getUptimeMillis() > longLoadingThresholdMillis) {
                longLoading++;
            }
        }
        return new NotReadyShards(notReady, longLoading);
    }

    /**
     * Builds a rebalance process from the current cluster state.
     * <p>
     * Every node gets an index. Each node that holds at least one shard also gets a queue of
     * its shards (ordered by {@link #prioritize}) together with the node usage, and only such
     * nodes are put into the priority queue.
     */
    private RebalanceProcess newRebalanceProcess() {
        var allNodes = this.nodes.values().toArray(NodeSummary[]::new);

        var indexByNode = new Object2IntOpenHashMap<>(allNodes.length);
        indexByNode.defaultReturnValue(-1);

        var queues = new NodeQueue[allNodes.length];
        var queuedIndices = new int[allNodes.length];
        var queued = 0;
        var shardsTotal = 0;

        for (int idx = 0; idx < allNodes.length; idx++) {
            var node = allNodes[idx];
            indexByNode.put(node, idx);

            var nodeShards = node.getShards().values().toArray(ShardSummary[]::new);
            if (nodeShards.length > 0) {
                shardsTotal += nodeShards.length;
                // Slots of nodes without shards stay null and are never referenced by the queue.
                queues[idx] = new NodeQueue(prioritize(nodeShards), node.getUsage(maximum));
                queuedIndices[queued++] = idx;
            }
        }

        var pq = new ObjectHeapIndirectPriorityQueue<>(queues, queuedIndices, queued);

        return new RebalanceProcess(pq, queues, indexByNode, shardsTotal);
    }

    /**
     * Orders shards by their usage relative to the current maximum, highest usage first.
     *
     * @param shards shards to order; the array itself is not modified
     * @return new list with the same shards in descending usage order
     */
    private List<ShardSummary> prioritize(ShardSummary[] shards) {
        int count = shards.length;

        int[] order = new int[count];
        double[] usageByIndex = new double[count];
        for (int i = 0; i < count; i++) {
            order[i] = i;
            usageByIndex[i] = shards[i].getUsage(maximum);
        }

        // Ascending indirect sort; the result is read backwards to get descending order.
        DoubleArrays.quickSortIndirect(order, usageByIndex);

        List<ShardSummary> result = new ArrayList<>(count);
        int cursor = count;
        while (cursor-- > 0) {
            result.add(shards[order[cursor]]);
        }
        return result;
    }

    private class Iteration {
        private final long startedAt;

        /** Remembers the clock time at which this iteration was created. */
        public Iteration() {
            startedAt = clock.millis();
        }

        /**
         * Executes all stages of one balancer iteration strictly one after another.
         * A failure of any stage skips the remaining stages and fails the returned future.
         */
        public CompletableFuture<?> run() {
            // Stage 1: bring local state up to date (init, limits, queued tasks, shard list).
            CompletableFuture<?> prepared = completedFuture(null)
                .thenCompose(unused -> init())
                .thenAccept(unused -> updateMaximum())
                .thenCompose(unused -> processRunnables())
                .thenCompose(unused -> actualizeShardList());

            // Stage 2: membership, global freeze, and states reported by nodes.
            CompletableFuture<?> observed = prepared
                .thenAccept(unused -> actualizeClusterMembers())
                .thenAccept(unused -> actualizeGlobalFreeze())
                .thenCompose(unused -> processPings());

            // Stage 3: act on the observed state and persist assignments.
            return observed
                .thenCompose(unused -> unassignExpired())
                .thenCompose(unused -> assignUnassignedShards())
                .thenCompose(unused -> actRebalance())
                .thenCompose(unused -> flushAssignments());
        }

        /**
         * One-time initialization: creates the schema, reloads the shard list, restores
         * assignments and options, then continues the assignment sequence from the highest
         * restored sequence number (or from the current clock millis when none was restored).
         * A no-op once initialization has succeeded.
         */
        private CompletableFuture<?> init() {
            if (init) {
                return completedFuture(null);
            }

            logger.debug("Initialize balancer");
            return dao.createSchema()
                .thenCompose(o -> shardsHolder.reload())
                .thenCompose(o -> loadAssignments())
                .thenCompose(o -> loadOptions())
                .thenRun(() -> {
                    Optional<AssignmentSeqNo> newest = shards.values()
                            .stream()
                            .map(ShardSummary::getAssignmentSeqNo)
                            .filter(Objects::nonNull)
                            .max(Comparator.naturalOrder());

                    long fallback = clock.millis();
                    long lastAssignSeqNo = newest.isPresent()
                            ? newest.get().getAssignSeqNo()
                            : fallback;

                    latestSeqNo = new AssignmentSeqNo(latestSeqNo.getLeaderSeqNo(), lastAssignSeqNo);
                    init = true;
                    updateMaximum();
                });
        }

        /** Replaces the current options with the ones stored in the DAO. */
        private CompletableFuture<?> loadOptions() {
            var stored = dao.getOptions();
            return stored.thenAccept(loaded -> {
                opts = loaded;
                logger.info("opts restore: {}", opts);
            });
        }

        /**
         * Restores nodes and shard assignments from the stored snapshot.
         * <p>
         * Nodes missing locally are created; freeze/active flags are taken from the snapshot.
         * A shard that already has an owner is left untouched. Otherwise its resources are
         * restored and, when the stored owner is a known node, it is assigned to that node
         * with {@code assignedAt} set to 0.
         */
        private CompletableFuture<?> loadAssignments() {
            return dao.getAssignments()
                .thenAccept(snapshot -> {
                    logger.debug("Started loadAssignments");
                    if (snapshot.isEmpty()) {
                        logger.info("Skip restore assignments, because absent");
                        return;
                    }
                    logger.debug("Assignments: {}", snapshot);

                    for (var storedNode : snapshot.nodes) {
                        var node = nodes.get(storedNode.address);
                        if (node == null) {
                            node = newNode(storedNode.address);
                            logger.info("Node {} was included(from assignments) to cluster, will expired at {}", storedNode.address, Instant.ofEpochMilli(node.getExpiredAt()));
                            nodes.put(node.getAddress(), node);
                        }

                        node.setFreeze(storedNode.freeze);
                        node.setActive(storedNode.active);
                    }

                    for (var storedShard : snapshot.shards) {
                        var shard = shards.computeIfAbsent(storedShard.shardId, ShardSummary::new);
                        if (shard.getNode() != null) {
                            // the node has already reported the actual state of this shard
                            continue;
                        }

                        shard.restoreResources(storedShard);

                        var owner = storedShard.node == null
                                ? null
                                : nodes.get(storedShard.node);
                        if (owner != null) {
                            owner.assign(shard, storedShard.assignmentSeqNo);
                            shard.setAssignedAt(0);
                        }
                    }

                    List<String> restored = new ArrayList<>();
                    for (var shard : shards.values()) {
                        var owner = shard.getNode();
                        String address = owner == null ? null : owner.getAddress();
                        restored.add(shard.getShardId() + "=>" + (address == null ? "none" : address));
                    }
                    logger.info("assignments restored: {}", restored);
                });
        }

        /**
         * Takes every task queued so far, starts each of them in order and returns a future
         * that combines all started tasks.
         */
        private CompletableFuture<?> processRunnables() {
            var pending = tasksQueue.dequeueAll();
            return pending.stream()
                .map(Supplier::get)
                .collect(collectingAndThen(toList(), CompletableFutures::allOfUnit));
        }

        /**
         * Takes the latest state from every node and applies it. Nodes without a pending
         * state are skipped.
         *
         * @return future combining all commands issued while applying the states
         */
        private CompletableFuture<?> processPings() {
            logger.debug("process node states");
            List<CompletableFuture<?>> reactions = new ArrayList<>();
            for (var node : nodes.values()) {
                var state = node.getRemote().takeState();
                if (state != null) {
                    reactions.addAll(processStateFromNode(node, state));
                } else {
                    logger.debug("{} state is null", node.getAddress());
                }
            }

            return CompletableFutures.allOfVoid(reactions);
        }

        /**
         * Applies a state reported by a node and reminds the node about shards it should hold
         * but did not report.
         *
         * @param node  node that reported the state
         * @param state reported state
         * @return futures of the re-assign commands issued for unreported shards
         */
        private List<CompletableFuture<?>> processStateFromNode(NodeSummary node, RemoteNodeState state) {
            long prevUptime = node.getUptimeMillis();
            boolean wasUnknown = node.getStatus() == NodeStatus.UNKNOWN;
            node.updateNodeState(state);
            node.setExpiredAt(state.receivedAt + opts.getHeartbeatExpirationMillis());

            // First report from a node without local assignments: just learn where the shards are.
            if (wasUnknown && node.getShards().isEmpty()) {
                rememberShardLocation(node, state.shards);
                return List.of();
            }

            Map<String, ShardSummary> unknownShards = updateShardsOnNode(node, state.shards);
            /*
             * A mismatch between the shards assigned to the node on the leader and the shards
             * reported by the node can mean that:
             * 1) the node process failed but restarted fast enough;
             * 2) the node unloaded an assigned shard because the table generation changed -
             *    manually, or by another leader.
             * In all cases the already assigned shards are assigned to the node again.
             */
            if (unknownShards.isEmpty()) {
                logger.debug("unknownShards is empty for {}", node.getAddress());
                return List.of();
            }

            List<CompletableFuture<?>> futures = new ArrayList<>(unknownShards.size());
            for (ShardSummary shard : unknownShards.values()) {
                boolean stillLoading = node.getUptimeMillis() > prevUptime
                    && shard.getAssignedAt() + ASSIGNMENT_REMIND_DELAY_MILLIS >= state.receivedAt;
                if (!stillLoading) {
                    logger.info("Remind shard {} assigned on node {}", shard.getShardId(), node.getAddress());
                    node.unassign(shard);
                    futures.add(assignShardToNode(shard, node));
                } else {
                    // loading a shard is slow, so do not remind about the assignment too often
                    logger.debug("skip remind assignment shard {} on node {}", shard.getShardId(), node.getAddress());
                }
            }
            return futures;
        }

        /**
         * Adopts the shard locations reported by a node.
         * <p>
         * A shard without an owner is assigned to the reporting node. A shard already owned by
         * that node only gets its resources refreshed. A shard owned by another node is left as is.
         *
         * @param node        node that reported the shards
         * @param shardsState shard states reported by the node
         */
        private void rememberShardLocation(NodeSummary node, List<RemoteShardState> shardsState) {
            for (var state : shardsState) {
                ShardSummary shard = shards.get(state.shardId);
                if (shard == null) {
                    shard = new ShardSummary(state.shardId);
                    shards.put(state.shardId, shard);
                }

                var owner = shard.getNode();
                if (owner == node) {
                    logger.debug("Shard {} located at {} as previously, will be refreshed",
                            shard.getShardId(),
                            node.getAddress()
                    );
                    node.updateResources(shard.updateResources(state, resources));
                    continue;
                }

                if (owner != null) {
                    logger.info("Shard {} not located at {} anymore, new location {}",
                        shard.getShardId(),
                        node.getAddress(),
                        owner.getAddress()
                    );
                    continue;
                }

                logger.info("Shard {} located at {}", shard.getShardId(), node.getAddress());
                node.assign(shard, nextSeqNo());
                shard.setAssignedAt(clock.millis());
                node.updateResources(shard.updateResources(state, resources));
            }
        }

        /**
         * Reconciles the shards reported by a node with the shards assigned to it on the leader.
         * <p>
         * Reported shards that are assigned to the node get their resources refreshed. For
         * reported shards that are not assigned to it, a force unassign is sent to the node.
         *
         * @param node        node that reported the shards
         * @param shardsState shard states reported by the node
         * @return shards assigned to the node that the node did not report, keyed by shard id
         */
        private Map<String, ShardSummary> updateShardsOnNode(NodeSummary node, List<RemoteShardState> shardsState) {
            var assigned = node.getShards();
            var reported = new HashSet<String>(assigned.size());

            for (var state : shardsState) {
                ShardSummary shard = assigned.get(state.shardId);
                if (shard != null) {
                    node.updateResources(shard.updateResources(state, resources));
                    reported.add(state.shardId);
                    continue;
                }

                // The node holds a shard it does not own from the leader's point of view.
                ShardSummary known = shards.get(state.shardId);
                ShardSummary stranger = known != null ? known : new ShardSummary(state.shardId);

                var owner = stranger.getNode();
                String newLocation = owner != null ? owner.getAddress() : "none";

                logger.info("Shard {} not located at {} anymore, new location {}", state.shardId, node.getAddress(), newLocation);
                forceUnassign(stranger, node);
            }

            if (reported.size() == assigned.size()) {
                return Map.of();
            }

            Map<String, ShardSummary> missing = new HashMap<>(assigned.size() - reported.size());
            for (var entry : assigned.entrySet()) {
                var shardId = entry.getKey();
                if (!reported.contains(shardId)) {
                    missing.put(shardId, entry.getValue());
                }
            }

            return missing;
        }

        /**
         * Synchronizes the list of nodes with the discovery service: nodes missing from the
         * discovery list are closed and removed, newly discovered nodes are added.
         * <p>
         * This allows scaling the cluster in and out without a restart.
         */
        private void actualizeClusterMembers() {
            Set<String> discovered = cluster.getNodes();
            if (discovered.equals(nodes.keySet())) {
                return;
            }

            logger.debug("actualize cluster memberships");

            // excludes
            for (var it = nodes.values().iterator(); it.hasNext(); ) {
                var node = it.next();
                if (discovered.contains(node.getAddress())) {
                    continue;
                }

                logger.info("Node {} was excluded from cluster caused by absence into discovery list, shards on node {}", node.getAddress(), node.getShards().keySet());
                node.close();
                it.remove();
                forceFlushSnapshot();
            }

            // includes
            for (String address : discovered) {
                if (!nodes.containsKey(address)) {
                    var node = newNode(address);
                    logger.info("Node {} was included to cluster, will expired at {}", address, Instant.ofEpochMilli(node.getExpiredAt()));
                    nodes.put(address, node);
                }
            }
        }

        /**
         * Creates a summary for the node at the given address, with its heartbeat expiration
         * counted from now.
         */
        private NodeSummary newNode(String address) {
            long leaderSeqNo = latestSeqNo.getLeaderSeqNo();
            var node = new NodeSummary(cluster.create(address, leaderSeqNo));

            long heartbeatDeadline = clock.millis() + opts.getHeartbeatExpirationMillis();
            node.setExpiredAt(heartbeatDeadline);
            return node;
        }

        /**
         * Deletes the shard from the shards holder, then removes it from the local map
         * (only when the map still holds this exact instance).
         */
        private CompletableFuture<Void> deleteShard(ShardSummary shard) {
            var shardId = shard.getShardId();
            return shardsHolder.delete(shardId)
                    .thenRun(() -> {
                        // a new seqNo reports the changed shard count to all nodes
                        nextSeqNo();
                        shards.remove(shardId, shard);
                        forceFlushSnapshot();
                    });
        }

        /**
         * Deletes the shard, force unassigning it from its node first when it has one.
         * <p>
         * For a shard without a node, pending assignment waits are cancelled with
         * {@code Status.ABORTED} before the delete.
         *
         * @param shard shard to remove
         * @return future of the delete step
         */
        private CompletableFuture<Void> unassignAndDeleteShard(ShardSummary shard) {
            var owner = shard.getNode();
            if (owner != null) {
                return forceUnassign(shard, owner).thenCompose(ignore -> deleteShard(shard));
            }

            shard.cancelAssignmentWaits(Status.ABORTED.asRuntimeException());
            return deleteShard(shard);
        }

        /**
         * Synchronizes the local shard map with the shard list reported by {@code shardsHolder}.
         * <p>
         * Shards that are tracked locally but absent from the holder are unassigned and deleted;
         * shards known to the holder but not tracked yet are added to the assignment list.
         * Nothing is done while the holder reports an empty list.
         *
         * @return future that completes when all started unassign-and-delete operations complete
         */
        private CompletableFuture<?> actualizeShardList() {
            Set<String> allShards = shardsHolder.getShards();
            if (allShards.isEmpty()) {
                return completedFuture(null);
            }

            List<CompletableFuture<?>> futures = new ArrayList<>();

            // unassign removed shards
            for (var entry : shards.entrySet()) {
                var shardId = entry.getKey();
                var shard = entry.getValue();
                if (!allShards.contains(shardId)) {
                    logger.info("Unassign and delete shard {}", shardId);
                    futures.add(unassignAndDeleteShard(shard));
                }
            }

            // add new shards to assignment list
            for (String shardId : allShards) {
                // Request a snapshot flush only when the list actually grew; an unconditional
                // request here would be issued for every known shard on every pass.
                if (!shards.containsKey(shardId)) {
                    shards.computeIfAbsent(shardId, ShardSummary::new);
                    forceFlushSnapshot();
                }
            }

            // actualize total count in one place
            shardCounter.setTotalShardCount(shards.size());

            return CompletableFutures.allOfVoid(futures);
        }

        /**
         * Recomputes the {@code autoFreeze} flag and, when the flag is raised, its reason.
         * <p>
         * The flag is cleared when auto freeze is disabled in options or when at most two nodes are
         * active. Otherwise it is raised by the first matching condition, checked in this order:
         * some active nodes have unknown status, cluster uptime is below 10 minutes, or the number
         * of nodes near expiration reaches half of the active nodes (rounded down).
         * {@code autoFreezeReason} is left untouched when the flag is cleared.
         */
        private void actualizeGlobalFreeze() {
            if (opts.isDisableAutoFreeze()) {
                autoFreeze = false;
                return;
            }

            var active = nodes.values()
                    .stream()
                    .filter(NodeSummary::isActive)
                    .collect(Collectors.toList());

            // With so few active nodes, never block unassign-by-expiration
            if (active.size() <= 2) {
                autoFreeze = false;
                return;
            }

            // Each later check is evaluated only when all earlier ones did not trigger
            String reason = null;
            int unknownCount = ClusterState.countUnknown(active);
            if (unknownCount > 0) {
                // Hold reassignments while the state of some nodes is still unknown
                reason = "Unknown status for " + unknownCount + " nodes";
            } else {
                long uptimeMillis = ClusterState.uptimeMillis(active);
                if (uptimeMillis < TimeUnit.MINUTES.toMillis(10)) {
                    // Hold reassignments: cluster was restarted recently and is not warmed up yet
                    reason = "Cluster uptime " + DurationUtils.formatDurationMillis(uptimeMillis) + " < 10m";
                } else {
                    // Hold reassignments when about half of the cluster looks unavailable
                    long nearExpireMillis = Math.round(opts.getHeartbeatExpirationMillis() / 2.5);
                    int nodeLimit = Math.floorDiv(active.size(), 2);
                    int nearExpireCount = ClusterState.countNearExpire(active, nearExpireMillis, startedAt);
                    if (nearExpireCount >= nodeLimit) {
                        reason = nearExpireCount + " nodes near expire or already expired, latest ping success more then " + DurationUtils.formatDurationMillis(nearExpireMillis);
                    }
                }
            }

            if (reason == null) {
                autoFreeze = false;
            } else {
                autoFreeze = true;
                autoFreezeReason = reason;
            }
        }

        /**
         * Processes nodes whose heartbeat expiration time is not later than {@code startedAt}.
         * <p>
         * An expired node that the cluster no longer contains is removed and closed. Any other
         * expired node is marked expired, and all its shards are force unassigned unless auto
         * freeze is active, the node is frozen, or the node holds no shards.
         *
         * @return future that completes when all started force unassign operations complete
         */
        private CompletableFuture<?> unassignExpired() {
            List<CompletableFuture<?>> pending = new ArrayList<>();
            for (var cursor = nodes.values().iterator(); cursor.hasNext(); ) {
                var candidate = cursor.next();
                boolean heartbeatAlive = candidate.getExpiredAt() > startedAt;
                if (heartbeatAlive) {
                    continue;
                }

                // log the expiration only for nodes not already in EXPIRED status
                if (candidate.getStatus() != NodeStatus.EXPIRED) {
                    logger.info("{} heartbeat expired at {}", candidate.getAddress(), Instant.ofEpochMilli(candidate.getExpiredAt()));
                }

                boolean stillMember = cluster.hasNode(candidate.getAddress());
                if (!stillMember) {
                    cursor.remove();
                    logger.info("{} excluded from cluster", candidate.getAddress());
                    candidate.close();
                    continue;
                }

                candidate.markExpired();
                boolean keepShards = isAutoFreeze() || candidate.isFreeze() || candidate.isEmpty();
                if (keepShards) {
                    continue;
                }

                // iterate over a copy of the node's shards while they are being unassigned
                for (ShardSummary shard : List.copyOf(candidate.getShards().values())) {
                    logger.info("Force unassign shard {} from {} caused by heartbeat expiration", shard.getShardId(), candidate.getAddress());
                    pending.add(forceUnassign(shard, candidate));
                }
            }
            return CompletableFutures.allOfVoid(pending);
        }

        /**
         * Assigns shards that currently have no node and are still listed by {@code shardsHolder}.
         * <p>
         * Nothing is assigned while there are no nodes or while any active node still has
         * {@code UNKNOWN} status. When no node can be found for a shard, the actor is scheduled
         * again and the shard is left unassigned for now.
         *
         * @return future that completes when all started assignments complete
         */
        private CompletableFuture<?> assignUnassignedShards() {
            if (nodes.isEmpty()) {
                return completedFuture(null);
            }

            // assignment is allowed only after every active node has a known status
            for (var candidate : nodes.values()) {
                if (candidate.isActive() && candidate.getStatus() == NodeStatus.UNKNOWN) {
                    logger.debug("skip assign unassigned shards, because not all nodes send heartbeat");
                    return completedFuture(null);
                }
            }

            var knownShardIds = shardsHolder.getShards();
            var orphans = shards.values().stream()
                .filter(orphan -> orphan.getNode() == null && knownShardIds.contains(orphan.getShardId()))
                .toArray(ShardSummary[]::new);
            if (orphans.length == 0) {
                return completedFuture(null);
            }

            var targets = prioritize(orphans);
            logger.debug("start assign unassigned shards: {}", targets.size());

            List<CompletableFuture<?>> pending = new ArrayList<>(targets.size());
            for (var shard : targets) {
                var destination = findBestNode(shard, nodes.values(), maximum);
                if (destination != null) {
                    logger.info("Assign shard {} to node {}", shard.getShardId(), destination.getAddress());
                    pending.add(assignShardToNode(shard, destination));
                } else {
                    // no suitable node right now: request another actor run
                    actor.schedule();
                }
            }
            return CompletableFutures.allOfVoid(pending);
        }

        /**
         * Runs one rebalance step.
         * <p>
         * An unfinished rebalance process is continued. Otherwise, when auto rebalance is enabled,
         * a new process is started if the effective in-flight limit is positive and the current
         * dispersion is at or above the configured threshold.
         *
         * @return future of the rebalance iteration, or an already completed future when idle
         */
        private CompletableFuture<?> actRebalance() {
            if (rebalanceProcess != null) {
                if (!rebalanceProcess.isDone()) {
                    return rebalanceProcess.act();
                }
                // finished process is dropped, a new one may be started below
                rebalanceProcess = null;
            }

            if (!opts.isEnableAutoRebalance()) {
                return completedFuture(null);
            }

            int inFlightBudget = getNotReadyShards().getEffectiveMaxInFlight(opts);
            // dispersion is evaluated only when there is budget for moves
            boolean shouldStart = inFlightBudget > 0
                && getDispersion() >= opts.getAutoRebalanceDispersionThreshold();
            if (!shouldStart) {
                return completedFuture(null);
            }

            rebalanceProcess = newRebalanceProcess();
            return rebalanceProcess.act();
        }

        /**
         * Saves the assignment snapshot through {@code dao} when a flush is due.
         * <p>
         * The flush is skipped while the scheduled flush time is later than {@code startedAt} or
         * after close. After a successful save the next flush is scheduled at a random delay
         * in the range [15s, {@code SNAPSHOT_FLUSH_INTERVAL_MILLIS}) from the current clock time.
         *
         * @return future of the save, or an already completed future when skipped
         */
        private CompletableFuture<?> flushAssignments() {
            boolean notDueYet = startedAt < scheduledSnapshotFlushAt;
            if (notDueYet || closed) {
                return completedFuture(null);
            }

            Runnable scheduleNextFlush = () -> {
                long jitteredDelayMillis = ThreadLocalRandom.current().nextLong(15_000, SNAPSHOT_FLUSH_INTERVAL_MILLIS);
                scheduledSnapshotFlushAt = clock.millis() + jitteredDelayMillis;
            };
            return dao.saveAssignments(prepareAssignmentSnapshot()).thenRun(scheduleNextFlush);
        }
    }

    /**
     * State of one rebalance run: per-node shard queues, an indirect priority queue over them
     * and a counter of shards already taken from the queues.
     * <p>
     * Each {@link #act()} call takes shards from the queue at the head of {@code pq} and starts
     * moving those for which a better node was found, limited by the effective in-flight limit.
     */
    private class RebalanceProcess {
        private final IndirectPriorityQueue<NodeQueue> pq;
        private final NodeQueue[] nodeQueues;
        private final Object2IntMap<Object> indexByNode;

        private final int shardsTotal;
        private int shardsProcessed;

        /**
         * @param pq indirect priority queue over indexes of {@code nodeQueues}
         * @param nodeQueues per-node shard queues; a slot is nulled once its queue is drained
         * @param indexByNode maps a node to the index of its queue in {@code nodeQueues}
         * @param shardsTotal number of shards this run is expected to process
         */
        RebalanceProcess(
            IndirectPriorityQueue<NodeQueue> pq,
            NodeQueue[] nodeQueues,
            Object2IntMap<Object> indexByNode,
            int shardsTotal)
        {
            this.pq = pq;
            this.nodeQueues = nodeQueues;
            this.indexByNode = indexByNode;
            this.shardsTotal = shardsTotal;
        }

        /**
         * @return {@code true} once the processed counter reached the total
         */
        boolean isDone() {
            return shardsProcessed >= shardsTotal;
        }

        /**
         * @return processed fraction; {@code 1.0} for a run without shards, which
         * {@link #isDone()} already reports as done (plain division would give NaN)
         */
        double getProgress() {
            if (shardsTotal <= 0) {
                return 1.0;
            }
            return (double) shardsProcessed / shardsTotal;
        }

        /**
         * Performs one rebalance iteration.
         * <p>
         * The run is finished early when a positive dispersion target is configured and the
         * current dispersion is at or below it. The iteration is skipped when the effective
         * in-flight limit is zero.
         *
         * @return future that completes when all graceful unassigns started here complete
         */
        CompletableFuture<?> act() {
            var dispersionTarget = opts.getRebalanceDispersionTarget();
            if (dispersionTarget > 0) {
                var dispersion = getDispersion();
                logger.debug("dispersion before rebalance iteration: {}", dispersion);
                if (dispersion <= dispersionTarget) {
                    // mark the run as done
                    shardsProcessed = shardsTotal;
                    logger.info(
                        "finish rebalance process because dispersion target has been reached: {} <= {}",
                        dispersion,
                        dispersionTarget);
                    return completedFuture(null);
                }
            }

            var notReadyShards = getNotReadyShards();
            var maxInFlight = notReadyShards.getEffectiveMaxInFlight(opts);
            if (maxInFlight == 0) {
                logger.debug("skip rebalance iteration, not ready shards {}", notReadyShards);
                return completedFuture(null);
            }

            List<CompletableFuture<?>> futures = new ArrayList<>(maxInFlight);
            // every shard taken from a queue counts as processed, including skipped ones
            for (; !pq.isEmpty() && futures.size() < maxInFlight; shardsProcessed++) {
                var nodeQueue = nodeQueues[pq.first()];
                var shard = nodeQueue.remove();
                if (nodeQueue.isEmpty()) {
                    // drained queue leaves the priority queue and its slot is released
                    nodeQueues[pq.dequeue()] = null;
                }

                if (isAutoFreeze()) {
                    continue;
                }

                if (shard.getNode() != null && shard.getNode().isFreeze()) {
                    logger.debug("skip rebalance shard {} because on freeze node {}", shard.getShardId(), shard.getNode().getAddress());
                    continue;
                }

                if (shard.getStatus() != ShardStatus.READY) {
                    logger.debug("skip rebalance shard {} because in state {}", shard.getShardId(), shard.getStatus());
                    continue;
                }

                var bestNode = findBestNode(shard, nodes.values(), maximum);
                if (bestNode == null) {
                    logger.debug("skip rebalance shard {} because absent node ready to accept it", shard.getShardId());
                    continue;
                }

                if (shard.getNode() == bestNode) {
                    logger.debug("skip rebalance shard {}, because already best location {}", shard.getShardId(), shard.getNode().getAddress());
                    continue;
                }

                var fromNode = shard.getNode();
                if (fromNode == null) {
                    // The node is treated as nullable above; without this guard the move below
                    // would fail with NullPointerException and abort the whole iteration.
                    logger.debug("skip rebalance shard {} because it is not assigned to any node", shard.getShardId());
                    continue;
                }

                logger.info("move shard {} caused by rebalance from {} -> {}", shard.getShardId(), fromNode.getAddress(), bestNode.getAddress());
                updateNodeUsage(fromNode, fromNode.getUsageWithout(shard, maximum));

                var shardUsage = new Resources(shard.getResources());
                // Consider future assign into chosen best node on next iteration
                bestNode.updateResources(shardUsage);
                updateNodeUsage(bestNode, bestNode.getUsage(maximum));

                // NOTE(review): thenRun discards the future returned by runAsync, so the future
                // collected here covers the unassign step only — confirm this is intended.
                var future = gracefulUnassign(shard, fromNode)
                    .thenRun(() -> runAsync(() -> {
                        // Remove added usage to best node, to avoid add it twice after assign
                        var delta = new Resources();
                        delta.minus(shardUsage);
                        bestNode.updateResources(delta);

                        // ensure that shard not reassigned yet
                        if (shard.getNode() == null) {
                            return assignShardToNode(shard, bestNode);
                        } else {
                            return completedFuture(null);
                        }
                    }));
                futures.add(future);
            }

            logger.info("shards rebalance progress: {}%", Math.round(getProgress() * 100));
            return CompletableFutures.allOfVoid(futures);
        }

        /**
         * Stores the new usage for the node's queue and notifies the priority queue, but only
         * while that queue is still present in the priority queue.
         */
        private void updateNodeUsage(NodeSummary node, double nodeUsage) {
            // presumably the map's default return value is negative for unknown nodes — verify
            var index = indexByNode.getInt(node);
            if (index >= 0 && pq.contains(index)) {
                nodeQueues[index].setNodeUsage(nodeUsage);
                pq.changed(index);
            }
        }
    }

    /**
     * Forward-only queue over a node's shards together with the node usage used for ordering.
     * <p>
     * Natural ordering is descending by {@code nodeUsage}: a queue with higher usage compares
     * as smaller.
     */
    private static class NodeQueue implements Comparable<NodeQueue> {
        private final List<ShardSummary> shards;
        private int cursor;

        private double nodeUsage;

        private NodeQueue(List<ShardSummary> shards, double nodeUsage) {
            this.shards = shards;
            this.nodeUsage = nodeUsage;
        }

        /**
         * Returns the next shard and advances the cursor.
         *
         * @throws NoSuchElementException when all shards were already taken
         */
        ShardSummary remove() {
            if (cursor >= shards.size()) {
                throw new NoSuchElementException();
            }

            int position = cursor++;
            return shards.get(position);
        }

        void setNodeUsage(double nodeUsage) {
            this.nodeUsage = nodeUsage;
        }

        /**
         * @return {@code true} when the cursor moved past the last shard
         */
        public boolean isEmpty() {
            return shards.size() <= cursor;
        }

        @SuppressWarnings("CompareToUsesNonFinalVariable")
        @Override
        public int compareTo(NodeQueue other) {
            // arguments are swapped on purpose: higher usage sorts first
            return Double.compare(other.nodeUsage, nodeUsage);
        }
    }

    /**
     * Counters of shards that are not ready, and of those among them counted as long loading.
     */
    private record NotReadyShards(int notReady, int longLoading) {

        /**
         * Computes how many reassignments may be started now.
         * <p>
         * Long loading shards, up to the configured maximum to ignore, are excluded from the
         * not ready count; the rest is subtracted from the configured in-flight maximum.
         *
         * @param opts balancer options providing both limits
         * @return remaining in-flight budget, never negative
         */
        int getEffectiveMaxInFlight(BalancerOptions opts) {
            int ignoredLongLoading = Math.min(longLoading, opts.getMaxLongLoadingShardsToIgnore());
            int inFlightLimit = opts.getMaxReassignInFlight();
            int countedNotReady = notReady - ignoredLongLoading;
            return Math.max(0, inFlightLimit - countedNotReady);
        }
    }
}
