package ru.yandex.solomon.coremon.balancer.cluster;

import java.time.Instant;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntSupplier;

import javax.annotation.concurrent.Immutable;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HostAndPort;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.grpc.utils.GrpcClientOptions;
import ru.yandex.grpc.utils.GrpcTransport;
import ru.yandex.monitoring.coremon.CoremonBalancerServiceGrpc;
import ru.yandex.monitoring.coremon.TChangeAssignmentsRequest;
import ru.yandex.monitoring.coremon.TChangeAssignmentsResponse;
import ru.yandex.monitoring.coremon.THostLoad;
import ru.yandex.monitoring.coremon.TPingRequest;
import ru.yandex.monitoring.coremon.TPingResponse;
import ru.yandex.monitoring.coremon.TShardsLoad;
import ru.yandex.solomon.coremon.balancer.state.ShardIds;
import ru.yandex.solomon.coremon.balancer.state.ShardsLoadMap;
import ru.yandex.solomon.coremon.meta.service.MetabaseTotalShardCounter;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;

/**
 * @author Sergey Polovko
 */
final class RemoteCoremonHost implements CoremonHost {
    // Class-wide logger.
    private static final Logger logger = LoggerFactory.getLogger(RemoteCoremonHost.class);

    // gRPC transport created for the remote host address; every RPC of this class goes through it.
    private final GrpcTransport transport;
    // Default request timeout from the client options; added to the current time
    // to fill the expiredAt field of outgoing requests.
    private final long requestTimeoutMillis;
    // Executor on which delayed pings are scheduled.
    private final ScheduledExecutorService timer;
    // Latest known state of the remote host; replaced atomically through compare-and-set.
    private final AtomicReference<RemoteState> state = new AtomicReference<>(RemoteState.EMPTY);
    // Supplier of the total shard count sent with pings; yields SHARD_COUNT_UNKNOWN
    // until startPinging() installs a real supplier.
    private final AtomicReference<IntSupplier> totalShardCount = new AtomicReference<>(() -> MetabaseTotalShardCounter.SHARD_COUNT_UNKNOWN);

    // Leader sequence number put into every request; 0 means that pinging is stopped.
    private volatile long leaderSeqNo;
    // Becomes true after setAssignments() completes a successful remote call;
    // reset to false by startPinging().
    private volatile boolean initialized;
    // System.currentTimeMillis() of the last successful ping or change-assignments response.
    private volatile long seenAliveTimeMillis;
    // Handle of the most recently scheduled ping task; cancelled by stopPinging().
    private volatile ScheduledFuture<?> pingFuture;
    private volatile CountDownLatch pingAwaiter; // for testing

    // for manager ui
    // NOTE(review): both fields are plain (non-volatile), are written by saveLastError()
    // and are never read in this file - confirm how the UI is expected to observe them.
    private Throwable lastError;
    private Instant lastErrorTime;

    RemoteCoremonHost(
        HostAndPort address,
        GrpcClientOptions options,
        ScheduledExecutorService timer)
    {
        this.transport = new GrpcTransport(address, options);
        this.requestTimeoutMillis = options.getDefaultTimeoutMillis();
        this.timer = timer;
    }

    /**
     * Starts pinging the remote host on behalf of the given leader.
     * <p>
     * The host is marked as not initialized, the leader sequence number and the shard
     * count supplier are remembered, and the first ping is scheduled with a random
     * delay taken from the range [0, 1000) milliseconds.
     *
     * @param leaderSeqNo     sequence number of the current leader, must not be 0
     * @param totalShardCount supplier of the total shard count sent with every ping
     * @throws IllegalArgumentException if {@code leaderSeqNo} is 0
     */
    @Override
    public void startPinging(long leaderSeqNo, IntSupplier totalShardCount) {
        checkArgument(leaderSeqNo != 0, "leaderSeqNo must not be 0");

        final long firstPingMinDelayMillis = 0;
        final long firstPingMaxDelayMillis = 1000;

        initialized = false;
        this.leaderSeqNo = leaderSeqNo;
        this.totalShardCount.set(totalShardCount);

        pingAfterMillis(firstPingMinDelayMillis, firstPingMaxDelayMillis);
    }

    /**
     * Stops the ping loop: resets the leader sequence number to 0 (the "stopped" marker
     * checked by {@link #isStopped()}) and cancels the last scheduled ping task, if any,
     * without interrupting it when it is already running.
     */
    @Override
    public void stopPinging() {
        leaderSeqNo = 0;

        final ScheduledFuture<?> scheduledPing = pingFuture;
        if (scheduledPing == null) {
            return;
        }
        scheduledPing.cancel(false);
    }

    /**
     * @return {@code true} when no leader sequence number is set, i.e. pinging is stopped
     */
    private boolean isStopped() {
        final long currentSeqNo = leaderSeqNo;
        return currentSeqNo == 0;
    }

    /**
     * Schedules the next ping after a random delay from the range
     * [{@code minDelayMillis}, {@code maxDelayMillis}). Does nothing when pinging is stopped.
     *
     * @param minDelayMillis lower bound of the delay (inclusive)
     * @param maxDelayMillis upper bound of the delay (exclusive)
     */
    private void pingAfterMillis(long minDelayMillis, long maxDelayMillis) {
        if (isStopped()) {
            return;
        }

        final long delayMillis = ThreadLocalRandom.current().nextLong(minDelayMillis, maxDelayMillis);
        if (logger.isDebugEnabled()) {
            logger.debug("schedule next ping of {} after {} ms", getFqdn(), delayMillis);
        }

        final Runnable pingTask = this::pingOnce;
        pingFuture = timer.schedule(pingTask, delayMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Body of the scheduled ping task: sends one ping unless pinging was stopped in the
     * meantime. When sending fails synchronously, the error is remembered and the ping
     * is retried after 1-3 seconds.
     */
    private void pingOnce() {
        if (isStopped()) {
            return;
        }
        try {
            pingRpc().whenComplete(this::onPingComplete);
        } catch (Throwable t) {
            logger.error("unhandled exception while ping {}", getFqdn(), t);
            saveLastError(t);
            pingAfterMillis(1000, 3000);
        }
    }

    /**
     * Handles completion of a ping RPC and schedules the next ping.
     * <p>
     * On RPC failure the error is remembered and the ping is retried after 1-3 seconds.
     * On success the host is marked as seen alive, the reported load is merged into the
     * state, and then:
     * <ul>
     *   <li>if assignments and reported shards are in sync, the next ping is scheduled
     *       after 5-8 seconds;</li>
     *   <li>if the host is not initialized yet, synchronization is skipped and the next
     *       ping is scheduled after 1-3 seconds;</li>
     *   <li>otherwise the assignments are pushed to the host again and the next ping is
     *       scheduled after 1-3 seconds.</li>
     * </ul>
     * If processing of a successful response throws, the error is logged and remembered
     * and the next ping is still scheduled (after 1-3 seconds), so one bad response
     * cannot silently stop the ping loop.
     *
     * @param response ping response, meaningful only when {@code error} is {@code null}
     * @param error    failure of the ping RPC, or {@code null} on success
     */
    private void onPingComplete(TPingResponse response, Throwable error) {
        if (error != null) {
            logger.warn("ping {} failed, reason: {}", getFqdn(), error.getMessage());
            saveLastError(error);
            pingAfterMillis(1000, 3000);
            return;
        }

        seenAliveTimeMillis = System.currentTimeMillis();

        // delay window of the next ping; stays short unless the host is confirmed to be in sync
        long minDelayMillis = 1000;
        long maxDelayMillis = 3000;
        boolean processed = false;

        try {
            RemoteState oldState, newState;
            boolean initialized = this.initialized;
            do {
                oldState = state.get();
                newState = oldState.updateLoad(response.getHostLoad(), response.getShardsLoad());
            } while (!state.compareAndSet(oldState, newState));

            if (newState.isSynced()) {
                minDelayMillis = 5000;
                maxDelayMillis = 8000;
            } else if (!initialized) {
                logger.info("skip assignments ({} shards) synchronization on {} because not initialized yet", newState.getAssignments().size(), getFqdn());
            } else {
                logger.info("force assignments ({} shards) synchronization on {}", newState.getAssignments().size(), getFqdn());
                syncAssignments(newState.getAssignments());
            }
            processed = true;
        } catch (Throwable t) {
            logger.error("cannot update state in onPingComplete()", t);
            saveLastError(t);
        }

        // scheduled outside of the try block on purpose: the ping loop must survive
        // any failure while processing the response
        pingAfterMillis(minDelayMillis, maxDelayMillis);

        if (processed) {
            // wake up a test waiting in awaitNextPing() only after a fully processed ping
            CountDownLatch pingAwaiter = this.pingAwaiter;
            if (pingAwaiter != null) {
                pingAwaiter.countDown();
            }
        }
    }

    /**
     * @return host part of the address the transport is connected to
     */
    @Override
    public String getFqdn() {
        var address = transport.getAddress();
        return address.getHost();
    }

    /**
     * @return {@code System.currentTimeMillis()} value recorded at the last successful
     *         ping or change-assignments response, or 0 if there was none yet
     */
    @Override
    public long getSeenAliveTimeMillis() {
        final long lastSeenMillis = seenAliveTimeMillis;
        return lastSeenMillis;
    }

    /**
     * Returns the latest known state of the remote host.
     *
     * @param refreshShardsStatus ignored by this implementation
     * @return current immutable state snapshot
     */
    @Override
    public State getState(boolean refreshShardsStatus) {
        final RemoteState snapshot = state.get();
        return snapshot;
    }

    /**
     * Replaces the whole set of shards assigned to the remote host.
     * <p>
     * The remote call is skipped when the host has already been initialized and the
     * requested assignments equal the ones currently stored in the state. Otherwise the
     * state is updated first and then the full set of assignments is sent to the host;
     * the host is marked as initialized only after that call succeeds.
     *
     * @param shardIds ids of all shards that must be assigned to the host
     * @return future completed when the remote host accepted the assignments, or
     *         completed exceptionally when the update or the remote call failed
     */
    @Override
    public CompletableFuture<Void> setAssignments(IntSet shardIds) {
        final ShardIds assignments = ShardIds.ofWholeShards(shardIds);

        final boolean nothingToDo = initialized && state.get().getAssignments().equals(assignments);
        if (nothingToDo) {
            return completedFuture(null);
        }

        try {
            state.updateAndGet(current -> current.updateAssignments(assignments));

            final CompletableFuture<Void> synced = syncAssignments(assignments);
            return synced.whenComplete((ignored, failure) -> {
                if (failure != null) {
                    return;
                }
                // the host counts as initialized only after a successful remote call
                initialized = true;
            });
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }

    /**
     * Sends the complete set of assignments to the remote host.
     *
     * @param assignments shards that the remote host must own after the call
     * @return future completed when the change-assignments RPC finishes
     */
    private CompletableFuture<Void> syncAssignments(ShardIds assignments) {
        var builder = TChangeAssignmentsRequest.newBuilder();
        builder.setLeaderSeqNo(leaderSeqNo);
        builder.setExpiredAt(System.currentTimeMillis() + requestTimeoutMillis);
        builder.addAllShardIdsSet(assignments.getShards());

        final CompletableFuture<TChangeAssignmentsResponse> rpc = changeAssignmentsRpc(builder.build());
        return rpc
            .whenComplete(this::onChangeAssignmentsComplete)
            .thenAccept(ignored -> { });
    }

    /**
     * Incrementally changes assignments of the remote host.
     * <p>
     * When both sets are empty nothing is done. Otherwise the locally stored
     * assignments are updated first and then only the delta is sent to the host.
     *
     * @param shardIdsAdd    ids of shards to assign to the host
     * @param shardIdsRemove ids of shards to take away from the host
     * @return future completed when the change-assignments RPC finishes, or completed
     *         exceptionally when the state update or the request failed
     */
    @Override
    public CompletableFuture<Void> changeAssignments(IntSet shardIdsAdd, IntSet shardIdsRemove) {
        final boolean emptyDelta = shardIdsAdd.isEmpty() && shardIdsRemove.isEmpty();
        if (emptyDelta) {
            return completedFuture(null);
        }

        try {
            state.updateAndGet(current -> {
                ShardIds updated = current.getAssignments().addRemoveShards(shardIdsAdd, shardIdsRemove);
                return current.updateAssignments(updated);
            });

            var builder = TChangeAssignmentsRequest.newBuilder();
            builder.setLeaderSeqNo(leaderSeqNo);
            builder.setExpiredAt(System.currentTimeMillis() + requestTimeoutMillis);
            builder.addAllShardIdsAdd(shardIdsAdd);
            builder.addAllShardIdsRemove(shardIdsRemove);

            final CompletableFuture<TChangeAssignmentsResponse> rpc = changeAssignmentsRpc(builder.build());
            return rpc
                .whenComplete(this::onChangeAssignmentsComplete)
                .thenAccept(ignored -> { });
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }

    private void onChangeAssignmentsComplete(TChangeAssignmentsResponse reponse, Throwable error) {
        if (error != null) {
            logger.warn("changeAssignments {} failed, reason: {}", getFqdn(), error.getMessage());
            saveLastError(error);
        } else {
            seenAliveTimeMillis = System.currentTimeMillis();
        }
    }

    /**
     * Remembers the given error together with the moment it was recorded.
     *
     * @param throwable error to remember
     */
    private void saveLastError(Throwable throwable) {
        this.lastError = throwable;
        final Instant recordedAt = Instant.now();
        this.lastErrorTime = recordedAt;
    }

    /**
     * Issues the change-assignments unary call through the transport.
     *
     * @param request request to send
     * @return future with the response of the remote host
     */
    private CompletableFuture<TChangeAssignmentsResponse> changeAssignmentsRpc(TChangeAssignmentsRequest request) {
        var method = CoremonBalancerServiceGrpc.getChangeAssignmentsMethod();
        return transport.unaryCall(method, request);
    }

    /**
     * Builds and sends a ping request carrying the leader sequence number, the request
     * expiration time and the hash of shard ids currently stored in the state. The total
     * shard count is attached only when the supplier reports a non-negative value.
     *
     * @return future with the ping response of the remote host
     */
    private CompletableFuture<TPingResponse> pingRpc() {
        var builder = TPingRequest.newBuilder();
        builder.setLeaderSeqNo(leaderSeqNo);
        builder.setExpiredAt(System.currentTimeMillis() + requestTimeoutMillis);
        builder.setShardIdsHash(state.get().getShards().getIdsHash());

        final int shardCount = totalShardCount.get().getAsInt();
        if (shardCount >= 0) {
            builder.setTotalShardCountKnown(true);
            builder.setTotalShardCount(shardCount);
        }

        var method = CoremonBalancerServiceGrpc.getPingMethod();
        return transport.unaryCall(method, builder.build());
    }

    /**
     * Stops the ping loop and then closes the underlying transport.
     */
    @Override
    public void close() {
        this.stopPinging();
        this.transport.close();
    }

    /**
     * Blocks the calling thread until the next ping response has been processed by
     * {@code onPingComplete()}. Intended for tests only.
     *
     * @throws RuntimeException if the waiting thread is interrupted; the interrupt flag
     *                          of the thread is restored before the exception is thrown
     */
    @VisibleForTesting
    void awaitNextPing() {
        CountDownLatch pingAwaiter = new CountDownLatch(1);
        this.pingAwaiter = pingAwaiter;
        try {
            pingAwaiter.await();
        } catch (InterruptedException e) {
            // restore the interrupt flag so that code up the stack can still observe the interruption
            Thread.currentThread().interrupt();
            throw new RuntimeException("cannot await ping", e);
        }
    }

    /**
     * @return short description containing the address of the remote host
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("{address=");
        text.append(transport.getAddress());
        text.append('}');
        return text.toString();
    }

    /**
     * Immutable snapshot of everything known about the remote host: the load numbers
     * from the latest ping, the shards assigned to the host and the per-shard load the
     * host reported. Every update produces a new instance.
     */
    @Immutable
    private static final class RemoteState implements State {

        /** State with zero load, no assignments and no reported shards. */
        static final RemoteState EMPTY = new RemoteState(0, 0, 0, 0, ShardIds.EMPTY, ShardsLoadMap.EMPTY);

        // host load as reported by the latest ping
        private final long uptimeMillis;
        private final long cpuTimeNanos;
        private final long memoryBytes;
        private final long networkBytes;
        // shards that should be located on the host
        private final ShardIds assignments;
        // load of shards reported by the host
        private final ShardsLoadMap shards;

        private RemoteState(
            long uptimeMillis,
            long cpuTimeNanos,
            long memoryBytes,
            long networkBytes,
            ShardIds assignments,
            ShardsLoadMap shards)
        {
            this.uptimeMillis = uptimeMillis;
            this.cpuTimeNanos = cpuTimeNanos;
            this.memoryBytes = memoryBytes;
            this.networkBytes = networkBytes;
            this.assignments = assignments;
            this.shards = shards;
        }

        /**
         * Creates a state with host load taken from a ping response. Assignments are
         * kept as is. The shards load is replaced only when {@code shardsLoad} is not
         * the very same object as the protobuf default instance; otherwise the
         * previously known shards load is kept.
         *
         * @param hostLoad   load of the host reported in the ping response
         * @param shardsLoad load of shards reported in the ping response
         * @return new state snapshot
         */
        RemoteState updateLoad(THostLoad hostLoad, TShardsLoad shardsLoad) {
            final ShardsLoadMap reportedShards;
            if (shardsLoad == TShardsLoad.getDefaultInstance()) {
                reportedShards = this.shards;
            } else {
                reportedShards = ShardsLoadMap.fromPb(shardsLoad);
            }

            return new RemoteState(
                hostLoad.getUptimeMillis(),
                hostLoad.getCpuTimeNanos(),
                hostLoad.getMemoryBytes(),
                hostLoad.getNetworkBytes(),
                this.assignments,
                reportedShards);
        }

        /**
         * Creates a state with new assignments. Host load is kept as is, the shards
         * load is narrowed with {@code retainAll} to the newly assigned shards.
         *
         * @param assignments new assignments of the host
         * @return new state snapshot
         */
        RemoteState updateAssignments(ShardIds assignments) {
            final ShardsLoadMap retainedShards = this.shards.retainAll(assignments.getShards());
            return new RemoteState(
                this.uptimeMillis,
                this.cpuTimeNanos,
                this.memoryBytes,
                this.networkBytes,
                assignments,
                retainedShards);
        }

        @Override
        public long getUptimeMillis() {
            return this.uptimeMillis;
        }

        @Override
        public long getCpuTimeNanos() {
            return this.cpuTimeNanos;
        }

        @Override
        public long getMemoryBytes() {
            return this.memoryBytes;
        }

        @Override
        public long getNetworkBytes() {
            return this.networkBytes;
        }

        @Override
        public ShardIds getAssignments() {
            return this.assignments;
        }

        @Override
        public ShardsLoadMap getShards() {
            return this.shards;
        }

        /**
         * @return {@code true} when the assigned shards and the shards reported by the
         *         host are the same: hashes are compared first and only then the id sets
         */
        @Override
        public boolean isSynced() {
            if (assignments.getHash() != shards.getIdsHash()) {
                return false;
            }
            return assignments.getShards().equals(shards.getIds());
        }
    }
}
