package ru.yandex.solomon.coremon.balancer;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.monitoring.coremon.EShardState;
import ru.yandex.solomon.core.conf.SolomonConfWithContext;
import ru.yandex.solomon.core.db.model.Shard;
import ru.yandex.solomon.coremon.balancer.cluster.CoremonHost;
import ru.yandex.solomon.coremon.balancer.cluster.CoremonHost.State;
import ru.yandex.solomon.coremon.balancer.db.ShardAssignments;
import ru.yandex.solomon.coremon.balancer.db.ShardAssignmentsDao;
import ru.yandex.solomon.coremon.balancer.db.ShardBalancerOptions;
import ru.yandex.solomon.coremon.balancer.state.LoadCalc;
import ru.yandex.solomon.coremon.balancer.state.ShardLoad;
import ru.yandex.solomon.coremon.balancer.state.ShardsLoadMap;
import ru.yandex.solomon.util.time.DurationUtils;

import static ru.yandex.misc.concurrent.CompletableFutures.join;
import static ru.yandex.misc.concurrent.CompletableFutures.joinAll;

/**
 * @author Sergey Polovko
 */
final class Routines {
    private static final Logger logger = LoggerFactory.getLogger(Routines.class);

    private static final double HOSTS_FOR_NEW_SHARDS_FRACTION = 0.1;
    private static final int HOSTS_FOR_NEW_SHARDS_MIN_COUNT = 3;

    private static CoremonHost findHostForNewShard(List<CoremonHost> onlineHosts, ShardBalancerOptions options) {
        // TODO: move common logic of calculating load into cluster
        final ShardsLoadMap[] hostShardsLoad = onlineHosts.stream()
            .map(h -> h.getState(true).getShards())
            .toArray(ShardsLoadMap[]::new);

        final var calc = new LoadCalc(
            hostShardsLoad,
            options.getCpuWeightFactor(),
            options.getMemoryWeightFactor(),
            options.getNetworkWeightFactor());

        return onlineHosts.get(calc.getOneOfLeastLoadedHostsIdx(HOSTS_FOR_NEW_SHARDS_FRACTION, HOSTS_FOR_NEW_SHARDS_MIN_COUNT));
    }

    /**
     * For force shards reassignment we use RoundRobin approach, because:
     *   - most of the time hosts are evenly loaded
     *   - it allows to involve more nodes to load shards
     *   - it is much cheaper
     */
    static class MoveShards {
        private final List<CoremonHost> onlineHosts;
        private final ShardAssignmentsDao assignmentsDao;

        MoveShards(List<CoremonHost> onlineHosts, ShardAssignmentsDao assignmentsDao) {
            this.onlineHosts = onlineHosts;
            this.assignmentsDao = assignmentsDao;
        }

        public ShardAssignments run(CoremonHost offlineHost) {
            final var offlineHostState = offlineHost.getState(false);
            if (offlineHostState.getAssignments().isEmpty()) {
                return ShardAssignments.EMPTY;
            }
            logger.info("move {} shards from {}", offlineHostState.getAssignments().size(), offlineHost.getFqdn());

            final IntSet[] hostShards = roundRobin(offlineHostState);
            final var futures = new ArrayList<CompletableFuture<ShardAssignments>>(hostShards.length);
            for (int i = 0; i < hostShards.length; i++) {
                final CoremonHost onlineHost = onlineHosts.get(i);
                if (onlineHost.getFqdn().equals(offlineHost.getFqdn())) {
                    // do not move kicked shards
                    continue;
                }

                final IntSet shardsForHost = hostShards[i];
                final var newAssignments = ShardAssignments.ofShardIds(onlineHost.getFqdn(), shardsForHost);
                logger.info("move {} shards to {}", newAssignments.size(), onlineHost.getFqdn());

                // (1) save new assignments map
                futures.add(assignmentsDao.save(newAssignments)
                    .thenCompose(aVoid -> {
                        // (2) assign to online host
                        return onlineHost.changeAssignments(shardsForHost, IntSets.EMPTY_SET);
                    })
                    .thenApply(aVoid -> {
                        // (3) unassign shards from offline host
                        // do not wait sync with peer, because it is offline right now
                        offlineHost.changeAssignments(IntSets.EMPTY_SET, shardsForHost);
                        return newAssignments;
                    }));
            }

            return ShardAssignments.combine(joinAll(futures));
        }

        private IntSet[] roundRobin(CoremonHost.State offlineHostState) {
            final IntSet[] hostShards = new IntSet[onlineHosts.size()];
            for (int i = 0; i < onlineHosts.size(); i++) {
                hostShards[i] = new IntOpenHashSet();
            }
            final var it = offlineHostState.getAssignments().getShards().iterator();
            for (int i = 0; it.hasNext(); i++) {
                hostShards[i % hostShards.length].add(it.nextInt());
            }
            return hostShards;
        }
    }

    /**
     * In case partial errors there can be shards which are not assigned to coremon host. This routine
     * will fix that.
     */
    static class AssignLostShards {

        private final SolomonConfWithContext conf;
        private final List<CoremonHost> onlineHosts;
        private final ShardAssignmentsDao assignmentsDao;
        private final ShardBalancerOptions options;

        AssignLostShards(
            SolomonConfWithContext conf,
            List<CoremonHost> onlineHosts,
            ShardAssignmentsDao assignmentsDao,
            ShardBalancerOptions options)
        {
            this.conf = conf;
            this.onlineHosts = onlineHosts;
            this.assignmentsDao = assignmentsDao;
            this.options = options;
        }

        public ShardAssignments run(ShardAssignments currentAssignments) {
            if (onlineHosts.isEmpty()) {
                return ShardAssignments.EMPTY;
            }

            final var shard2Host = new Int2ObjectOpenHashMap<String>();
            for (Shard s : conf.getAllRawShards()) {
                if (currentAssignments.get(s.getNumId()) == null) {
                    final String host = findHostForNewShard(onlineHosts, options).getFqdn();
                    shard2Host.put(s.getNumId(), host);
                    logger.warn("found lost assignment {} -> {}", Integer.toUnsignedString(s.getNumId()), host);
                }
            }

            if (shard2Host.isEmpty()) {
                return ShardAssignments.EMPTY;
            }

            ShardAssignments assignments = ShardAssignments.ownOf(shard2Host);
            join(assignmentsDao.save(assignments));
            return assignments;
        }
    }

    /**
     * Routine for rebalaning shards over hosts to make even load on them.
     */
    static class Rebalance {

        private final List<CoremonHost> onlineHosts;
        private final ShardAssignmentsDao assignmentsDao;
        private final ShardBalancerMetrics metrics;

        Rebalance(List<CoremonHost> onlineHosts, ShardAssignmentsDao assignmentsDao, ShardBalancerMetrics metrics) {
            this.onlineHosts = onlineHosts;
            this.assignmentsDao = assignmentsDao;
            this.metrics = metrics;
        }

        ShardAssignments run(ShardAssignments currentAssignments, ShardBalancerOptions options) {
            final State[] hostShardsState = onlineHosts.stream()
                .map(h -> h.getState(true))
                .toArray(State[]::new);

            ShardsLoadMap[] hostShardsLoad = new ShardsLoadMap[hostShardsState.length];

            int countNotReady = 0;
            for (int index = 0; index < hostShardsState.length; index++) {
                var state = hostShardsState[index];
                var shardsLoadMap = state.getShards();
                hostShardsLoad[index] = shardsLoadMap;
                for (ShardLoad s : hostShardsLoad[index].values()) {
                    if (s.getState() == EShardState.INACTIVE) {
                        continue;
                    }

                    if (s.getState() != EShardState.READY) {
                        if (s.getUptimeMillis() > TimeUnit.MINUTES.toMillis(5L)) {
                            continue;
                        }

                        countNotReady++;
                        if (options.getRebalaceShardsInFlight() >= countNotReady) {
                            continue;
                        }

                        logger.info("skip rebalancing: have not ready shard (id={}, state={}, uptime={}, fqdn={})",
                                Integer.toUnsignedLong(s.getId()),
                                s.getState(),
                                DurationUtils.formatDurationMillis(s.getUptimeMillis()),
                                onlineHosts.get(index).getFqdn());
                        return currentAssignments;
                    }
                }
            }

            final var calc = new LoadCalc(
                hostShardsLoad,
                options.getCpuWeightFactor(),
                options.getMemoryWeightFactor(),
                options.getNetworkWeightFactor());

            if (calc.getLoadScoreDiff() < options.getRebalaceThreshold()) {
                metrics.dispersion.set(calc.getLoadScoreDiff());
                logger.info("skip rebalancing: weight difference is OK");
                return currentAssignments;
            }

            final int maxInflight = options.getRebalaceShardsInFlight() - countNotReady;
            final var reassign = calc.getShardsToMove(maxInflight, options.getRebalaceThreshold());
            metrics.dispersion.set(reassign.dispersion());
            if (reassign.toMove().isEmpty()) {
                logger.info("skip rebalancing: nothing to move");
                return currentAssignments;
            }

            if (!hostShardsState[reassign.fromIdx()].isSynced()) {
                logger.info("skip rebalancing: have not synced hosts (fqdn={})", onlineHosts.get(reassign.fromIdx()).getFqdn());
                return currentAssignments;
            }

            if (!hostShardsState[reassign.toIdx()].isSynced()) {
                logger.info("skip rebalancing: have not synced hosts (fqdn={})", onlineHosts.get(reassign.toIdx()).getFqdn());
                return currentAssignments;
            }

            final CoremonHost fromHost = onlineHosts.get(reassign.fromIdx());
            final CoremonHost toHost = onlineHosts.get(reassign.toIdx());
            final var newAssignments = ShardAssignments.ofShardIds(toHost.getFqdn(), reassign.toMove());
            logger.info("rebalancing: move {} shads from {} to {}",
                    IntStream.of(reassign.toMove().toIntArray())
                            .mapToObj(Integer::toUnsignedLong)
                            .collect(Collectors.toList()),
                    fromHost.getFqdn(),
                    toHost.getFqdn());

            metrics.rebalanceShards.add(reassign.toMove().size());
            var future = assignmentsDao.save(newAssignments)
                .thenCompose(aVoid -> fromHost.changeAssignments(IntSets.EMPTY_SET, reassign.toMove()))
                .thenCompose(aVoid -> toHost.changeAssignments(reassign.toMove(), IntSets.EMPTY_SET))
                .thenApply(aVoid -> newAssignments);

            return join(future);
        }
    }

    /**
     * Assigns newly created (not yet assigned) shard to a host.
     */
    static final class AssignNewShard {
        private final ShardAssignmentsDao assignmentsDao;
        private final List<CoremonHost> onlineHosts;
        private final ShardBalancerOptions options;

        AssignNewShard(
            ShardAssignmentsDao assignmentsDao,
            List<CoremonHost> onlineHosts,
            ShardBalancerOptions options)
        {
            this.assignmentsDao = assignmentsDao;
            this.onlineHosts = onlineHosts;
            this.options = options;
        }

        CompletableFuture<String> run(int shardNumId) {
            CoremonHost coremonHost = findHostForNewShard(onlineHosts, options);
            var assignment = ShardAssignments.singleton(shardNumId, coremonHost.getFqdn());
            logger.info("assign new shard {} to {}", Integer.toUnsignedLong(shardNumId), coremonHost.getFqdn());
            return assignmentsDao.save(assignment)
                .thenCompose(aVoid -> coremonHost.changeAssignments(IntSets.singleton(shardNumId), IntSets.EMPTY_SET))
                .thenApply(aVoid -> coremonHost.getFqdn());
        }
    }

    static final class KickShard {
        private final ShardAssignmentsDao assignmentsDao;
        private final List<CoremonHost> onlineHosts;
        private final ShardBalancerOptions options;

        KickShard(
                ShardAssignmentsDao assignmentsDao,
                List<CoremonHost> onlineHosts,
                ShardBalancerOptions options)
        {
            this.assignmentsDao = assignmentsDao;
            this.onlineHosts = onlineHosts;
            this.options = options;
        }

        CompletableFuture<ShardAssignments> run(int shardNumId) {
            CoremonHost coremonHost = findHostForNewShard(onlineHosts, options);
            var assignment = ShardAssignments.singleton(shardNumId, coremonHost.getFqdn());
            logger.info("kick shard {} to {}", Integer.toUnsignedLong(shardNumId), coremonHost.getFqdn());
            return assignmentsDao.save(assignment)
                    .thenCompose(aVoid -> coremonHost.changeAssignments(IntSets.singleton(shardNumId), IntSets.EMPTY_SET))
                    .thenApply(aVoid -> assignment);
        }
    }
}
