package ru.yandex.solomon.coremon.balancer.state;

import java.util.Collection;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;

import ru.yandex.monitoring.coremon.EShardState;

/**
 * Ranks hosts by a weighted load score and proposes shard moves from an overloaded host to the
 * least loaded one.
 *
 * <p>The score of a load against a reference total is computed over three components: CPU time,
 * metrics count and network bytes. For each component, its share of the total is multiplied by a
 * per-component factor, and the three results are summed (see
 * {@link #loadScore(Load, Load, double, double, double)}). Host scores are computed against the
 * cluster-wide total and ranked once, in the constructor.
 *
 * @author Sergey Polovko
 */
public class LoadCalc {
    /** Shards whose uptime is at most this many milliseconds are never selected for a move. */
    private static final long UPTIME_THRESHOLD = TimeUnit.MINUTES.toMillis(1);

    private final ShardsLoadMap[] hostShardsLoad;
    private final double cpuFactor;
    private final double memoryFactor;
    private final double networkFactor;

    /** Sum of all host loads. */
    private final Load clusterLoad;
    /** Per-host sum of shard loads, indexed like {@code hostShardsLoad}. */
    private final Load[] hostLoad;
    /** Per-host score against {@link #clusterLoad}, indexed like {@code hostShardsLoad}. */
    private final double[] loadScore;
    /** Host indexes ordered by ascending {@link #loadScore}: first is least loaded, last is most loaded. */
    private final int[] indexes;

    /**
     * Computes per-host loads and scores and ranks hosts from least to most loaded.
     *
     * @param hostShardsLoad per-host shard loads; the array position is the host index used by
     *                       every other method of this class
     * @param cpuFactor      weight of the CPU-time component of the score
     * @param memoryFactor   weight of the metrics-count component of the score
     * @param networkFactor  weight of the network-bytes component of the score
     */
    public LoadCalc(ShardsLoadMap[] hostShardsLoad, double cpuFactor, double memoryFactor, double networkFactor) {
        this.hostShardsLoad = hostShardsLoad;
        this.cpuFactor = cpuFactor;
        this.memoryFactor = memoryFactor;
        this.networkFactor = networkFactor;

        int hostCount = hostShardsLoad.length;
        this.hostLoad = new Load[hostCount];
        this.loadScore = new double[hostCount];
        this.indexes = new int[hostCount];

        for (int i = 0; i < hostCount; i++) {
            hostLoad[i] = sumShards(hostShardsLoad[i].values());
            indexes[i] = i;
        }

        clusterLoad = sumLoads(hostLoad);
        for (int i = 0; i < hostCount; i++) {
            loadScore[i] = loadScore(hostLoad[i], clusterLoad, cpuFactor, memoryFactor, networkFactor);
        }

        // Indirect sort: permutes `indexes` so that loadScore[indexes[i]] is non-decreasing.
        DoubleArrays.quickSortIndirect(indexes, loadScore);
    }

    /**
     * Scores {@code load} relative to {@code totalLoad}.
     *
     * <p>Each component contributes {@code factor * load / total}. A component whose total is not
     * positive contributes {@code 0}, which also avoids division by zero.
     *
     * @param load          the load to score
     * @param totalLoad     the reference total the shares are taken against
     * @param cpuFactor     weight of the CPU-time share
     * @param memoryFactor  weight of the metrics-count share
     * @param networkFactor weight of the network-bytes share
     * @return the sum of the three weighted shares
     */
    public static double loadScore(
        Load load, Load totalLoad,
        double cpuFactor, double memoryFactor, double networkFactor)
    {
        double cpuScore = weightedShare(cpuFactor, load.cpuTimeNanos, totalLoad.cpuTimeNanos);
        double memoryScore = weightedShare(memoryFactor, load.metricsCount, totalLoad.metricsCount);
        double networkScore = weightedShare(networkFactor, load.networkBytes, totalLoad.networkBytes);
        return cpuScore + memoryScore + networkScore;
    }

    /** Returns {@code factor * part / total}, or {@code 0} when {@code total} is not positive. */
    private static double weightedShare(double factor, double part, double total) {
        return total > 0 ? (factor * part) / total : 0d;
    }

    /**
     * Returns the score gap between the most and the least loaded host.
     *
     * @return the gap, or {@code 0} when there are no hosts
     */
    public double getLoadScoreDiff() {
        if (indexes.length == 0) {
            // Without this guard getMostLoadedHostIdx() would read indexes[-1].
            return 0d;
        }
        return loadScore[getMostLoadedHostIdx()] - loadScore[getLeastLoadedHostIdx()];
    }

    /** Index of the least loaded host; requires at least one host. */
    private int getLeastLoadedHostIdx() {
        return indexes[0];
    }

    /** Index of the most loaded host; requires at least one host. */
    private int getMostLoadedHostIdx() {
        return indexes[indexes.length - 1];
    }

    /**
     * Picks a host uniformly at random among the least loaded ones.
     *
     * <p>The candidate pool is the {@code max(minCount, (int) (bottomFraction * hostCount))} least
     * loaded hosts, clamped to {@code [1, hostCount]}.
     *
     * @param bottomFraction fraction of hosts, counted from the least loaded, to choose from
     * @param minCount       minimum size of the candidate pool
     * @return the index of the chosen host
     * @throws IllegalArgumentException if there are no hosts
     */
    public int getOneOfLeastLoadedHostsIdx(double bottomFraction, int minCount) {
        int hostCount = indexes.length;
        int requested = Math.max(minCount, (int) (bottomFraction * hostCount));
        // Keep the pool non-empty whenever a host exists: ThreadLocalRandom.nextInt(0) throws.
        int upper = Math.min(hostCount, Math.max(1, requested));
        return indexes[ThreadLocalRandom.current().nextInt(upper)];
    }

    /**
     * Proposes shards to move from an overloaded host to the least loaded host.
     *
     * <p>Hosts are examined from most to least loaded.
     * <ul>
     *   <li>The search ends with an empty result at the first host whose score exceeds the least
     *       loaded host's score by less than {@code scoreDiffThreshold}.</li>
     *   <li>A shard is movable when it is {@link EShardState#READY}, its uptime exceeds one minute,
     *       and its score against its own host's load is positive.</li>
     *   <li>A host with fewer than two movable shards is skipped.</li>
     * </ul>
     *
     * <p>For the first host that is not skipped, movable shards are taken in ascending order of
     * that score. At most {@code count} are taken, and at least one movable shard is always left
     * behind. Taking stops at the first of the stop conditions documented in the loop.
     *
     * @param count              upper bound on the number of shards to move; non-positive values
     *                           select no shards
     * @param scoreDiffThreshold score gap below which two hosts are considered balanced
     * @return the shards to move, the source and target host indexes, and the score gap observed
     *         for the source host. When no host qualifies, the set is empty and both indexes are
     *         {@code 0}.
     */
    public Reassign getShardsToMove(int count, double scoreDiffThreshold) {
        for (int idx = indexes.length - 1; idx > 0; idx--) {
            int fromHostIdx = indexes[idx];
            int toHostIdx = getLeastLoadedHostIdx();
            double dispersion = loadScore[fromHostIdx] - loadScore[toHostIdx];
            if (dispersion < scoreDiffThreshold) {
                // Hosts are visited in descending score order, so every remaining host is closer still.
                return new Reassign(IntSets.EMPTY_SET, 0, 0, dispersion);
            }

            Load fromHostLoad = hostLoad[fromHostIdx];
            ShardsLoadMap shardsLoadMap = hostShardsLoad[fromHostIdx];
            int capacity = shardsLoadMap.size();
            int[] shardIds = new int[capacity];
            Load[] shardLoad = new Load[capacity];
            double[] shardWeights = new double[capacity];
            int[] shardIndexes = new int[capacity];
            int size = 0;

            for (ShardLoad s : shardsLoadMap.values()) {
                if (s.getState() != EShardState.READY) {
                    // do not move shards that are not ready
                    continue;
                }
                if (s.getUptimeMillis() <= UPTIME_THRESHOLD) {
                    // do not move recently assigned shards
                    continue;
                }
                Load load = toLoad(s);
                double weight = loadScore(load, fromHostLoad, cpuFactor, memoryFactor, networkFactor);
                if (weight <= 0d) {
                    // do not move empty shards
                    continue;
                }
                shardIds[size] = s.getId();
                shardLoad[size] = load;
                shardWeights[size] = weight;
                shardIndexes[size] = size;
                size++;
            }

            if (size <= 1) {
                // At least one movable shard always stays, so this host has nothing to offer.
                continue;
            }

            // Indirect sort of the first `size` entries: lightest shards come first.
            DoubleArrays.quickSortIndirect(shardIndexes, shardWeights, 0, size);

            // Clamp count to 0: a negative value would make IntOpenHashSet's constructor throw.
            int maxToMove = Math.min(Math.max(count, 0), size - 1);
            IntSet toMove = new IntOpenHashSet(maxToMove);
            Load projectedFromLoad = fromHostLoad;
            Load projectedToLoad = hostLoad[toHostIdx];
            for (int i = 0; i < maxToMove; i++) {
                int shardIdx = shardIndexes[i];
                Load shard = shardLoad[shardIdx];
                // NOTE(review): the shard is added before the checks below. The shard that trips the
                // first or second check is therefore still part of the result — confirm this is intended.
                toMove.add(shardIds[shardIdx]);

                // Stop when the target would exceed the original score of the host ranked second lowest.
                Load nextToLoad = sumLoads(projectedToLoad, shard);
                double nextToScore = loadScore(nextToLoad, clusterLoad, cpuFactor, memoryFactor, networkFactor);
                if (nextToScore > loadScore[indexes[1]]) {
                    break;
                }
                projectedToLoad = nextToLoad;

                // Stop when the source would drop below the original score of the next host down the ranking.
                Load nextFromLoad = minus(projectedFromLoad, shard);
                double nextFromScore = loadScore(nextFromLoad, clusterLoad, cpuFactor, memoryFactor, networkFactor);
                if (nextFromScore < loadScore[indexes[idx - 1]]) {
                    break;
                }
                projectedFromLoad = nextFromLoad;

                // Stop once the projected gap is within the threshold. Both projected scores were
                // just computed for exactly these loads, so they are reused.
                if (nextFromScore - nextToScore < scoreDiffThreshold) {
                    break;
                }
            }
            return new Reassign(toMove, fromHostIdx, toHostIdx, dispersion);
        }
        return new Reassign(IntSets.EMPTY_SET, 0, 0, getLoadScoreDiff());
    }

    /** Copies the CPU, metrics-count and network-bytes figures of a shard into a {@link Load}. */
    private static Load toLoad(ShardLoad s) {
        return new Load(s.getCpuTimeNanos(), s.getMetricsCount(), s.getNetworkBytes());
    }

    /**
     * Sums the loads of the given shards component by component.
     *
     * @param shards shards to sum; an empty collection yields a zero load
     * @return the component-wise sum
     */
    public static Load sumShards(Collection<? extends ShardLoad> shards) {
        double cpu = 0.0;
        double metrics = 0.0;
        double network = 0.0;
        for (ShardLoad shard : shards) {
            cpu += shard.getCpuTimeNanos();
            metrics += shard.getMetricsCount();
            network += shard.getNetworkBytes();
        }
        return new Load(cpu, metrics, network);
    }

    /**
     * Sums loads component by component.
     *
     * @param loads loads to sum; no arguments yields a zero load
     * @return the component-wise sum
     */
    public static Load sumLoads(Load... loads) {
        double cpu = 0.0;
        double metrics = 0.0;
        double network = 0.0;
        for (Load l : loads) {
            cpu += l.cpuTimeNanos;
            metrics += l.metricsCount;
            network += l.networkBytes;
        }
        return new Load(cpu, metrics, network);
    }

    /**
     * Subtracts {@code right} from {@code left} component by component.
     *
     * <p>No clamping is applied, so the result may have negative components.
     */
    public static Load minus(Load left, Load right) {
        return new Load(
                left.cpuTimeNanos - right.cpuTimeNanos,
                left.metricsCount - right.metricsCount,
                left.networkBytes - right.networkBytes);
    }

    /**
     * Immutable triple of CPU time (nanoseconds), metrics count and network bytes.
     */
    public static final class Load {
        public final double cpuTimeNanos;
        public final double metricsCount;
        public final double networkBytes;

        Load(double cpuTimeNanos, double metricsCount, double networkBytes) {
            this.cpuTimeNanos = cpuTimeNanos;
            this.metricsCount = metricsCount;
            this.networkBytes = networkBytes;
        }
    }

    /**
     * Result of {@link #getShardsToMove(int, double)}.
     *
     * @param toMove     ids of the shards to move; may be empty
     * @param fromIdx    index of the source host, or {@code 0} when no host qualified
     * @param toIdx      index of the target (least loaded) host, or {@code 0} when no host qualified
     * @param dispersion score gap between the source host and the least loaded host. When every
     *                   host was skipped, this is the gap between the most and the least loaded host.
     */
    public record Reassign(IntSet toMove, int fromIdx, int toIdx, double dispersion) {
    }
}
