package ru.yandex.solomon.name.resolver.balancer;

import java.util.concurrent.CompletableFuture;

import io.grpc.Status;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.balancer.BalancerProto;
import ru.yandex.solomon.balancer.TAssignShardRequest;
import ru.yandex.solomon.balancer.TAssignShardResponse;
import ru.yandex.solomon.balancer.TUnassignShardRequest;
import ru.yandex.solomon.balancer.TUnassignShardResponse;
import ru.yandex.solomon.locks.DistributedLock;
import ru.yandex.solomon.name.resolver.NameResolverLocalShards;
import ru.yandex.solomon.name.resolver.NameResolverShardFactory;

import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * Gate in front of the local shards for assign/unassign requests.
 * <p>
 * Each request is first checked against the leader lock (the leader seqNo carried by the request
 * has to match the seqNo reported by the lock) and against the request deadline. Only after both
 * checks pass are the local shards modified.
 *
 * @author Vladimir Gordiychuk
 */
public class AssignmentGate {
    private static final Logger logger = LoggerFactory.getLogger(AssignmentGate.class);

    /**
     * Milliseconds added to the current time before comparing it with a request deadline, so a
     * request whose deadline is this close is rejected as already expired.
     */
    private static final long DEADLINE_MARGIN_MILLIS = 200L;

    private final NameResolverLocalShards shards;
    private final DistributedLock leader;
    private final NameResolverShardFactory factory;

    public AssignmentGate(NameResolverLocalShards shards, DistributedLock leader, NameResolverShardFactory factory) {
        this.shards = shards;
        this.leader = leader;
        this.factory = factory;
    }

    /**
     * Assigns the shard named in the request to this node, replacing a shard already registered
     * under the same id.
     *
     * @param request assignment request; carries the shard id, the assignment seqNo and the deadline
     * @return future with the default response; it completes exceptionally when the leader check or
     *     the deadline check fails, or when the new shard cannot be registered
     */
    public CompletableFuture<TAssignShardResponse> assignShard(TAssignShardRequest request) {
        long leaderSeqNo = request.getAssignmentSeqNo().getLeaderSeqNo();
        return ensureLeaderOwnership(leaderSeqNo).thenApply(leaderNode -> {
            ensureDeadlineNotExpired(request.getExpiredAt());

            var shardId = request.getShardId();
            var seqNo = requireNonNull(BalancerProto.fromProto(request.getAssignmentSeqNo()));
            logger.info("Receive shard assignment {} {} from {}", shardId, seqNo, leaderNode);

            // Drop the shard currently registered under this id; the result of stop() is not awaited.
            var existing = shards.getShardById(shardId);
            if (existing != null) {
                existing.stop();
                shards.remove(existing);
            }

            var created = factory.create(shardId, seqNo);
            if (shards.addShard(created)) {
                created.start();
                return TAssignShardResponse.getDefaultInstance();
            }

            // Registration was refused: the new shard is never started.
            throw Status.FAILED_PRECONDITION
                .withDescription("Shard " + shardId + " already assigned " + created.seqNo + " by another thread")
                .asRuntimeException();
        });
    }

    /**
     * Removes the shard named in the request from the local shards and stops it.
     *
     * @param request unassign request; carries the shard id, the assignment seqNo, the deadline and
     *     the graceful flag
     * @return future with the response; for a graceful request it completes only after the shard's
     *     stop future completes, otherwise it is already completed. It completes exceptionally when
     *     the leader check or the deadline check fails
     */
    public CompletableFuture<TUnassignShardResponse> unassignShard(TUnassignShardRequest request) {
        long leaderSeqNo = request.getAssignmentSeqNo().getLeaderSeqNo();
        return ensureLeaderOwnership(leaderSeqNo).thenCompose(leaderNode -> {
            ensureDeadlineNotExpired(request.getExpiredAt());

            var shardId = request.getShardId();
            logger.info("Receive shard unassign {}", shardId);

            var target = shards.getShardById(shardId);
            if (target == null) {
                // Nothing registered under this id, so there is nothing to stop.
                return completedFuture(TUnassignShardResponse.getDefaultInstance());
            }

            shards.remove(target);
            if (request.getGraceful()) {
                // Graceful: reply only once the stop has finished.
                return target.stop().thenApply(ignore -> TUnassignShardResponse.newBuilder().build());
            }

            // Not graceful: trigger the stop and reply immediately.
            target.stop();
            return completedFuture(TUnassignShardResponse.getDefaultInstance());
        });
    }

    /**
     * Checks the given leader seqNo against the detail reported by the leader lock.
     *
     * @param seqNo leader seqNo taken from the incoming request
     * @return future with the lock owner; completes exceptionally with {@code ABORTED} when the lock
     *     reports no detail or a different seqNo
     */
    private CompletableFuture<String> ensureLeaderOwnership(long seqNo) {
        return leader.getLockDetail(seqNo).thenApply(detail -> {
            if (detail.isEmpty()) {
                throw Status.ABORTED
                    .withDescription("Reject because leader ownership expired")
                    .asRuntimeException();
            }

            var lock = detail.get();
            if (Long.compareUnsigned(seqNo, lock.seqNo()) == 0) {
                return lock.owner();
            }

            var description = "Rejected, seqNo mismatch(" + seqNo + " != " + lock.seqNo() + "), leader now " + lock.owner();
            throw Status.ABORTED.withDescription(description).asRuntimeException();
        });
    }

    /**
     * Rejects a request whose deadline has passed or is within the margin of the current time.
     *
     * @param expiredAt deadline compared against {@link System#currentTimeMillis()}; {@code 0}
     *     disables the check
     */
    private void ensureDeadlineNotExpired(long expiredAt) {
        boolean hasDeadline = expiredAt != 0;
        if (hasDeadline && System.currentTimeMillis() + DEADLINE_MARGIN_MILLIS >= expiredAt) {
            throw Status.DEADLINE_EXCEEDED.asRuntimeException();
        }
    }
}
