package ru.yandex.solomon.coremon.balancer.cluster;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntSupplier;

import javax.annotation.concurrent.Immutable;

import it.unimi.dsi.fastutil.ints.IntSet;

import ru.yandex.solomon.coremon.balancer.state.ShardIds;
import ru.yandex.solomon.coremon.balancer.state.ShardsLoadMap;
import ru.yandex.solomon.selfmon.ng.ProcSelfMon;
import ru.yandex.solomon.selfmon.ng.linux.ProcPidStat;
import ru.yandex.solomon.util.host.HostUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static ru.yandex.solomon.selfmon.ng.ProcSelfMon.ticksToMillis;

/**
 * {@link CoremonHost} that represents the current (local) process.
 *
 * <p>Assignment changes are applied directly to the in-process {@link ShardInfoProvider}, so every
 * returned future is already completed. Process statistics reported through {@link State} are
 * obtained from {@link ProcPidStat#readProcSelfStat()}.
 *
 * @author Sergey Polovko
 */
public class LocalCoremonHost implements CoremonHost {

    /** Monotonic reference point for uptime, captured when this class is initialized. */
    private static final long START_NANOS = System.nanoTime();

    private final ShardInfoProvider shardsInfoProvider;
    private final AtomicReference<ShardIds> assignments = new AtomicReference<>(ShardIds.EMPTY);
    /**
     * Serializes assignment updates so that {@link #assignments} and the provider always observe
     * updates in the same order (otherwise two concurrent updates could leave the provider holding
     * a stale shard set).
     */
    private final Object assignmentsLock = new Object();
    private final ScheduledExecutorService timer;
    /** The currently scheduled ping task, or {@code null} when pinging is stopped. */
    private final AtomicReference<ScheduledFuture<?>> pingFuture = new AtomicReference<>();

    public LocalCoremonHost(ShardInfoProvider shardsInfoProvider, ScheduledExecutorService timer) {
        this.shardsInfoProvider = shardsInfoProvider;
        this.timer = timer;
    }

    @Override
    public String getFqdn() {
        return HostUtils.getFqdn();
    }

    /** Always returns the current wall-clock time: the local process is alive by definition. */
    @Override
    public long getSeenAliveTimeMillis() {
        return System.currentTimeMillis();
    }

    /**
     * Starts pushing the total assignment count to the shard info provider every 1000 ms,
     * beginning immediately. A previously started ping task is cancelled, so repeated calls never
     * leave duplicate tasks running.
     *
     * @param leaderSeqNo not used by the local host
     * @param totalAssignmentCount polled on every tick
     */
    @Override
    public void startPinging(long leaderSeqNo, IntSupplier totalAssignmentCount) {
        ScheduledFuture<?> future = timer.scheduleAtFixedRate(
            () -> pingOnce(totalAssignmentCount), 0, 1000, TimeUnit.MILLISECONDS);
        ScheduledFuture<?> previous = pingFuture.getAndSet(future);
        if (previous != null) {
            previous.cancel(false);
        }
    }

    /** One tick of the ping task; never lets a runtime exception escape. */
    private void pingOnce(IntSupplier totalAssignmentCount) {
        try {
            setTotalShardCount(totalAssignmentCount.getAsInt());
        } catch (RuntimeException ignored) {
            // Best effort: an exception escaping a scheduleAtFixedRate task would suppress all
            // subsequent runs, stopping pinging forever. Swallow it so the next tick retries.
        }
    }

    /** Cancels the ping task started by {@link #startPinging}, if any. */
    @Override
    public void stopPinging() {
        ScheduledFuture<?> future = pingFuture.getAndSet(null);
        if (future != null) {
            future.cancel(false);
        }
    }

    /**
     * Takes a snapshot of this process's state.
     *
     * @param refreshShardsStatus when {@code true}, shard load is fetched from the provider;
     *                            otherwise {@link ShardsLoadMap#EMPTY} is reported
     */
    @Override
    public State getState(boolean refreshShardsStatus) {
        ShardsLoadMap shardsStatus = refreshShardsStatus
            ? shardsInfoProvider.getShardsLoad()
            : ShardsLoadMap.EMPTY;
        return new LocalState(assignments.get(), shardsStatus);
    }

    /** Replaces the whole assignment set and publishes it to the provider. */
    @Override
    public CompletableFuture<Void> setAssignments(IntSet shardIds) {
        ShardIds ids = ShardIds.ofWholeShards(shardIds);
        synchronized (assignmentsLock) {
            assignments.set(ids);
            shardsInfoProvider.setLocalShardIds(ids);
        }
        return completedFuture(null);
    }

    /** Adds and removes shards from the current assignment set and publishes the result. */
    @Override
    public CompletableFuture<Void> changeAssignments(IntSet shardIdsAdd, IntSet shardIdsRemove) {
        synchronized (assignmentsLock) {
            ShardIds newValue = assignments.get().addRemoveShards(shardIdsAdd, shardIdsRemove);
            assignments.set(newValue);
            shardsInfoProvider.setLocalShardIds(newValue);
        }
        return completedFuture(null);
    }

    public void setTotalShardCount(int totalShardCount) {
        shardsInfoProvider.setTotalShardCount(totalShardCount);
    }

    /** Stops the ping task. The injected timer is not owned by this host and is left running. */
    @Override
    public void close() {
        stopPinging();
    }

    /**
     * LOCAL STATE
     */
    @Immutable
    private static final class LocalState implements State {
        private final long uptimeMillis;
        private final long cpuTimeNanos;
        private final long memoryBytes;
        private final long networkBytes;
        private final ShardIds assignments;
        private final ShardsLoadMap shardsLoadMap;

        LocalState(ShardIds assignments, ShardsLoadMap shardsLoadMap) {
            ProcPidStat stat = ProcPidStat.readProcSelfStat();
            this.memoryBytes = ProcSelfMon.pagesToBytes(stat.getRss());
            // NOTE(review): only stat.getUtime() (presumably user-mode ticks) is counted here;
            // confirm whether system-mode time should be included as well.
            this.cpuTimeNanos = TimeUnit.MILLISECONDS.toNanos(ticksToMillis(stat.getUtime()));
            // Monotonic clock: wall-clock adjustments must not distort (or negate) uptime.
            this.uptimeMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - START_NANOS);
            this.networkBytes = 0; // TODO: get from sysfs
            this.assignments = assignments;
            this.shardsLoadMap = shardsLoadMap;
        }

        @Override
        public long getUptimeMillis() {
            return uptimeMillis;
        }

        @Override
        public long getCpuTimeNanos() {
            return cpuTimeNanos;
        }

        @Override
        public long getMemoryBytes() {
            return memoryBytes;
        }

        @Override
        public long getNetworkBytes() {
            return networkBytes;
        }

        @Override
        public ShardIds getAssignments() {
            return assignments;
        }

        @Override
        public ShardsLoadMap getShards() {
            return shardsLoadMap;
        }

        /**
         * Returns {@code true} when the reported shard load covers exactly the assigned shards.
         * The hash comparison is a cheap pre-check before the full set comparison.
         */
        @Override
        public boolean isSynced() {
            return assignments.getHash() == shardsLoadMap.getIdsHash() &&
                assignments.getShards().equals(shardsLoadMap.getIds());
        }
    }
}
