package ru.yandex.solomon.alert.cluster.balancer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import com.google.common.net.HostAndPort;
import io.grpc.Context;
import io.grpc.Status;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.ActorWithFutureRunner;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.MetricConsumer;
import ru.yandex.monlib.metrics.MetricSupplier;
import ru.yandex.monlib.metrics.MetricType;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.solomon.alert.cluster.balancer.client.AlertingBalancerClient;
import ru.yandex.solomon.alert.cluster.broker.AlertingProjectShard;
import ru.yandex.solomon.alert.cluster.broker.AlertingProjectShardFactory;
import ru.yandex.solomon.alert.cluster.broker.ShardMetrics;
import ru.yandex.solomon.alert.cluster.project.AssignmentConverter;
import ru.yandex.solomon.alert.cluster.project.AssignmentSnapshot;
import ru.yandex.solomon.alert.cluster.project.ProjectAssignment;
import ru.yandex.solomon.alert.domain.StringInterner;
import ru.yandex.solomon.alert.protobuf.TAssignProjectRequest;
import ru.yandex.solomon.alert.protobuf.TAssignProjectResponse;
import ru.yandex.solomon.alert.protobuf.TProjectAssignmentRequest;
import ru.yandex.solomon.alert.protobuf.TUnassignProjectRequest;
import ru.yandex.solomon.alert.protobuf.TUnassignProjectResponse;
import ru.yandex.solomon.balancer.AssignmentSeqNo;
import ru.yandex.solomon.locks.DistributedLock;
import ru.yandex.solomon.util.async.InFlightLimiter;
import ru.yandex.solomon.util.host.HostUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.toList;

/**
 * Registry of the alerting project shards assigned to this node, keyed by project id.
 *
 * <p>Assign/unassign commands are accepted only when the leader seqNo they carry matches the
 * current holder of the {@link DistributedLock}, and the assignment seqNo they carry is checked
 * against the last assignment seqNo this node has accepted. When the local view looks out of
 * date, the full assignment snapshot is pulled from the leader and the local shard set is
 * reconciled with it.
 *
 * @author Vladimir Gordiychuk
 */
public class AlertingLocalShardsImpl implements AutoCloseable, AlertingLocalShards, MetricSupplier {
    private static final Logger logger = LoggerFactory.getLogger(AlertingLocalShardsImpl.class);

    /** Address of this node (the host FQDN by default); stored in assignments created here. */
    private final String localAddress;
    /** Live shards by project id. */
    private final ConcurrentMap<String, AlertingProjectShard> shards;
    private final AlertingProjectShardFactory shardFactory;
    private final AlertingBalancerClient client;
    private final DistributedLock leader;
    /** Used to delay resync retries and shard-initialisation retries. */
    private final ScheduledExecutorService timer;
    /** Runs {@link #actSyncAssignments()}; {@code schedule()} is called to request another sync pass. */
    private final ActorWithFutureRunner refreshAssignmentActor;
    // Bounds concurrent shard initialisation; presumably at most 900 in flight — confirm InFlightLimiter semantics.
    private final InFlightLimiter initLimiter = new InFlightLimiter(900);

    /** Latest assignment seqNo this node has accepted from a command or synced from a snapshot. */
    private final AtomicReference<AssignmentSeqNo> seqNo = new AtomicReference<>(AssignmentSeqNo.EMPTY);
    private volatile boolean closed;

    /**
     * Creates the registry using this host's FQDN as the local address.
     */
    public AlertingLocalShardsImpl(
        AlertingProjectShardFactory factory,
        AlertingBalancerClient client,
        DistributedLock leader,
        ScheduledExecutorService timer,
        ExecutorService executor)
    {
        this(HostUtils.getFqdn(), factory, client, leader, timer, executor);
    }

    /**
     * @param localAddress address recorded in assignments for this node and matched against
     *                     the addresses in leader snapshots
     * @param factory      creates a shard for a project assignment
     * @param client       used to fetch the assignment snapshot from the leader
     * @param leader       distributed lock whose holder is the current leader
     * @param timer        used for delayed retries
     * @param executor     runs the assignment-resync actor
     */
    public AlertingLocalShardsImpl(
        String localAddress,
        AlertingProjectShardFactory factory,
        AlertingBalancerClient client,
        DistributedLock leader,
        ScheduledExecutorService timer,
        ExecutorService executor)
    {
        this.shards = new ConcurrentHashMap<>();
        this.localAddress = localAddress;
        this.shardFactory = factory;
        this.client = client;
        this.leader = leader;
        this.timer = timer;
        this.refreshAssignmentActor = new ActorWithFutureRunner(this::actSyncAssignments, executor);
    }

    /**
     * Handles an assign command from the leader.
     *
     * <p>The request is rejected with {@code ABORTED} if the leader seqNo no longer matches the
     * lock holder, if the assignment seqNo is stale (see {@link #ensureValidCommand}), or if the
     * shard could not be registered. It is rejected with {@code DEADLINE_EXCEEDED} if the request
     * deadline has (almost) passed. Shard initialisation is started but not awaited.
     */
    @Override
    public CompletableFuture<TAssignProjectResponse> assignShard(TAssignProjectRequest request) {
        return ensureLeaderOwnership(request.getSeqNo().getLeaderSeqNo())
            .thenApply(node -> {
                ensureDeadlineNotExpired(request.getExpiredAt());
                logger.info("Receive shard assign {}", request);
                var projectId = StringInterner.I.intern(request.getProjectId());
                var seqNo = AssignmentConverter.fromProto(request.getSeqNo());
                ensureValidCommand(projectId, seqNo);
                var assignment = new ProjectAssignment(projectId, localAddress, seqNo);
                if (addShard(assignment)) {
                    return TAssignProjectResponse.getDefaultInstance();
                } else {
                    throw Status.ABORTED
                        .withDescription("Already assigned shard: " + shards.get(assignment.getProjectId()))
                        .asRuntimeException();
                }
            });
    }

    /**
     * Handles an unassign command from the leader.
     *
     * <p>It goes through the same leader, deadline and seqNo checks as {@link #assignShard}. The
     * returned future completes after the shard's graceful shutdown when
     * {@code request.getGracefull()} is set. Otherwise the shard is force-stopped and the future
     * completes immediately.
     */
    @Override
    public CompletableFuture<TUnassignProjectResponse> unassignShard(TUnassignProjectRequest request) {
        return ensureLeaderOwnership(request.getSeqNo().getLeaderSeqNo())
            .thenCompose(node -> {
                ensureDeadlineNotExpired(request.getExpiredAt());
                logger.info("Receive shard unassign {}", request);
                var projectId = request.getProjectId();
                var seqNo = AssignmentConverter.fromProto(request.getSeqNo());
                ensureValidCommand(projectId, seqNo);
                return removeShard(projectId, seqNo, request.getGracefull());
            });
    }

    /**
     * Checks a command's assignment seqNo against the locally accepted one.
     *
     * <ul>
     *   <li>Local seqNo is newer: the command is accepted only if the shard for
     *       {@code projectId} already carries exactly this seqNo (presumably a retried command).
     *       Otherwise it fails with {@code ABORTED}.</li>
     *   <li>Equal: accepted.</li>
     *   <li>Local seqNo is older: accepted. If the command is exactly the next assign seqNo
     *       under the same leader, the local seqNo is advanced by CAS. On a gap, a leader change
     *       or a lost CAS, a full resync with the leader is scheduled instead.</li>
     * </ul>
     */
    private void ensureValidCommand(String projectId, AssignmentSeqNo seqNo) {
        var actual = this.seqNo.get();
        int compare = actual.compareTo(seqNo);
        if (compare > 0) {
            var shard = shards.get(projectId);
            if (shard != null && shard.getAssignment().getSeqNo().equals(seqNo)) {
                return;
            }

            throw Status.ABORTED
                .withDescription("SeqNo mismatch, actual " + actual + " request " + seqNo)
                .asRuntimeException();
        } else if (compare == 0) {
            return;
        }

        var next = new AssignmentSeqNo(actual.getLeaderSeqNo(), actual.getAssignSeqNo() + 1);
        if (!next.equals(seqNo)) {
            // Some commands were missed (or the leader changed): rebuild state from a snapshot.
            refreshAssignmentActor.schedule();
            return;
        }

        if (!this.seqNo.compareAndSet(actual, next)) {
            refreshAssignmentActor.schedule();
        }
    }

    /**
     * One resync pass: fetches the assignment snapshot from the current leader and reconciles
     * local shards with it.
     *
     * <p>This is a no-op after close. If no leader is known, or any step fails, another pass is
     * scheduled one second later. The returned future never completes exceptionally.
     */
    private CompletableFuture<?> actSyncAssignments() {
        if (closed) {
            return completedFuture(null);
        }

        return leader.getLockDetail(seqNo.get().getLeaderSeqNo())
            .thenCompose(opt -> {
                if (opt.isEmpty()) {
                    // leader unknown, try refresh assignments later
                    timer.schedule(refreshAssignmentActor::schedule, 1, TimeUnit.SECONDS);
                    return completedFuture(null);
                }

                String address = HostAndPort.fromString(opt.get().owner()).getHost();
                return client.listAssignments(address, TProjectAssignmentRequest.newBuilder()
                    .setExpiredAt(System.currentTimeMillis() + 10_000)
                    .build())
                    .thenAccept(response -> updateAssignment(AssignmentConverter.fromProto(response)));
            })
            .exceptionally(e -> {
                logger.error("failed sync assignment snapshot, try again, later", e);
                timer.schedule(refreshAssignmentActor::schedule, 1, TimeUnit.SECONDS);
                return null;
            });
    }

    /**
     * Reconciles local shards with {@code snapshot}: it first drops shards the snapshot does not
     * confirm, then registers shards the snapshot assigns to this node. If either step stops
     * early, another resync is scheduled. The snapshot's seqNo is adopted in every case.
     */
    private void updateAssignment(AssignmentSnapshot snapshot) {
        if (!syncUnassignWithSnapshot(snapshot)) {
            refreshAssignmentActor.schedule();
        }

        if (!syncAssignWithSnapshot(snapshot)) {
            refreshAssignmentActor.schedule();
        }

        // NOTE(review): this overwrites seqNo even when reconciliation stopped because a local
        // shard is newer than the snapshot, so seqNo can move backwards until the scheduled
        // resync runs — confirm this is intended.
        logger.info("global assignment state synced with {}", snapshot.getAssignmentSeqNo());
        seqNo.set(snapshot.getAssignmentSeqNo());
    }

    /**
     * Removes and force-stops every local shard whose assignment differs from the snapshot's
     * entry for its project, including the case where the project is absent from the snapshot.
     *
     * @return {@code false} if the pass stopped early. That happens when a local shard carries a
     *         seqNo newer than the snapshot (the snapshot is stale), or when a shard was replaced
     *         or removed concurrently.
     */
    private boolean syncUnassignWithSnapshot(AssignmentSnapshot snapshot) {
        for (var shard : shards.values()) {
            var localAssignment = shard.getAssignment();
            if (localAssignment.getSeqNo().compareTo(snapshot.getAssignmentSeqNo()) > 0) {
                return false;
            }

            var assignment = snapshot.getAssignment(shard.getProjectId());
            if (localAssignment.equals(assignment)) {
                continue;
            }

            if (assignment == null) {
                logger.info("Shard unassign {}, caused by absence in snapshot", shard.getProjectId());
            } else if (localAssignment.compareTo(assignment) < 0) {
                logger.info("Shard unassign {}, caused by obsolete seqNo", shard.getProjectId());
            } else {
                // Previously silent: the snapshot assigns this project differently (e.g. to another host).
                logger.info("Shard unassign {}, caused by assignment mismatch with snapshot {}",
                    shard.getProjectId(), assignment);
            }

            boolean success = shards.remove(shard.getProjectId(), shard);
            shard.forceShutdown();
            if (!success) {
                return false;
            }
        }
        return true;
    }

    /**
     * Registers a shard for every snapshot assignment that targets {@link #localAddress}.
     *
     * @return {@code false} as soon as {@link #addShard} reports a mismatch
     */
    private boolean syncAssignWithSnapshot(AssignmentSnapshot snapshot) {
        for (ProjectAssignment assignment : snapshot.getAssignmentByProjectId().values()) {
            if (!Objects.equals(assignment.getAddress(), localAddress)) {
                continue;
            }

            if (!addShard(assignment)) {
                return false;
            }
        }

        return true;
    }

    /**
     * Ensures a shard with exactly {@code assignment} is registered.
     *
     * <p>A registered shard with a different assignment is force-stopped and replaced. A newly
     * installed shard starts initialising in the background.
     *
     * @return whether the shard registered afterwards carries {@code assignment}
     */
    private boolean addShard(ProjectAssignment assignment) {
        var assigned = shards.get(assignment.getProjectId());
        if (assignment.equals(assigned != null ? assigned.getAssignment() : null)) {
            return true;
        }

        // NOTE(review): if compute() below keeps prev, newShard is discarded without any
        // shutdown call; presumably safe only if create() acquires nothing before run() — confirm.
        var newShard = shardFactory.create(assignment);
        var shard = shards.compute(assignment.getProjectId(), (projectId, prev) -> {
            if (prev == null) {
                logger.info("Shard assigned {}, with seqNo {}", assignment.getProjectId(), assignment.getSeqNo());
                return newShard;
            }

            if (prev.getAssignment().equals(newShard.getAssignment())) {
                return prev;
            }

            prev.forceShutdown();
            logger.info("Shard assigned {}, with seqNo {}", assignment.getProjectId(), assignment.getSeqNo());
            return newShard;
        });

        if (shard == newShard) {
            // We don't wait for the shard to finish loading, which is why we fork the
            // current grpc context: this keeps the parent context from cancelling it.
            // see https://st.yandex-team.ru/YDBREQUESTS-335
            var ctx = Context.current().fork();
            var prev = ctx.attach();
            try {
                initShardWithRetry(newShard);
            } finally {
                ctx.detach(prev);
            }
        }

        return shard.getAssignment().equals(assignment);
    }

    /**
     * Removes the shard for {@code projectId} unless it is absent or carries a seqNo newer than
     * {@code seqNo}. The loop retries if the shard is replaced between the lookup and removal.
     *
     * @param graceful whether to wait for a graceful shutdown instead of force-stopping
     */
    private CompletableFuture<TUnassignProjectResponse> removeShard(String projectId, AssignmentSeqNo seqNo, boolean graceful) {
        AlertingProjectShard shard;
        do {
            shard = shards.get(projectId);
            if (shard == null || shard.getAssignment().getSeqNo().compareTo(seqNo) > 0) {
                return completedFuture(TUnassignProjectResponse.newBuilder().build());
            }
        } while (!shards.remove(projectId, shard));

        if (graceful) {
            return shard.gracefulShutdown()
                .thenApply(ignore -> TUnassignProjectResponse.getDefaultInstance());
        }

        shard.forceShutdown();
        return completedFuture(TUnassignProjectResponse.getDefaultInstance());
    }

    /**
     * Verifies that {@code seqNo} is the seqNo of the current lock holder.
     *
     * @return future with the leader's host
     * @throws io.grpc.StatusRuntimeException {@code ABORTED} (through the future) if no lock
     *         detail is available or the seqNo does not match
     */
    private CompletableFuture<String> ensureLeaderOwnership(long seqNo) {
        return leader.getLockDetail(seqNo)
            .thenApply(detail -> {
                if (detail.isEmpty()) {
                    throw Status.ABORTED
                        .withDescription("Reject because leader ownership expired")
                        .asRuntimeException();
                }

                if (Long.compareUnsigned(seqNo, detail.get().seqNo()) != 0) {
                    throw Status.ABORTED
                        .withDescription("Rejected, seqNo mismatch("
                            + seqNo
                            + " != "
                            + detail.get().seqNo()
                            + "), leader now "
                            + detail.get().owner())
                        .asRuntimeException();
                }

                return HostAndPort.fromString(detail.get().owner()).getHost();
            });
    }

    /**
     * Starts {@code shard.run()} through {@link #initLimiter}.
     *
     * <p>On failure, a shard that is no longer registered is force-stopped. A shard that is still
     * registered is retried after a random delay in [5 s, 30 s).
     */
    private void initShardWithRetry(AlertingProjectShard shard) {
        initLimiter.run(() -> shard.run()
                .whenComplete((r, e) -> {
                    if (e != null) {
                        logger.error("Failed init shard: {}", shard.getAssignment(), e);
                        if (shards.get(shard.getProjectId()) != shard) {
                            shard.forceShutdown();
                            return;
                        }

                        long retryDelayMillis = ThreadLocalRandom.current().nextInt(5_000, 30_000);
                        timer.schedule(() -> initShardWithRetry(shard), retryDelayMillis, TimeUnit.MILLISECONDS);
                    }
                }));
    }

    /**
     * Rejects a request with {@code DEADLINE_EXCEEDED} when fewer than 200 ms remain before
     * {@code expiredAt} (epoch millis). A value of {@code 0} means no deadline.
     */
    private void ensureDeadlineNotExpired(long expiredAt) {
        if (expiredAt == 0) {
            return;
        }

        if (System.currentTimeMillis() + 200L >= expiredAt) {
            throw Status.DEADLINE_EXCEEDED.asRuntimeException();
        }
    }

    /**
     * @return {@code true} if {@code seqNo} equals the local assignment seqNo. Otherwise a resync
     *         is scheduled and {@code false} is returned.
     */
    @Override
    public boolean isAssignmentActual(AssignmentSeqNo seqNo) {
        if (this.seqNo.get().compareTo(seqNo) == 0) {
            return true;
        }

        refreshAssignmentActor.schedule();
        return false;
    }

    /** @return the local shard for the given project id, or {@code null} */
    @Nullable
    @Override
    public AlertingProjectShard getShardById(String shardId) {
        return shards.get(shardId);
    }

    /** @return whether a shard for {@code projectId} is registered and reports ready */
    @Override
    public boolean isReady(String projectId) {
        AlertingProjectShard shard = getShardById(projectId);
        if (shard == null) {
            return false;
        }
        return shard.isReady();
    }

    /** @return a copy of the currently registered project ids */
    @Override
    public List<String> assignedShards() {
        return new ArrayList<>(shards.keySet());
    }

    /** Iterates over a snapshot copy of the registered shards. */
    @Nonnull
    @Override
    public Iterator<AlertingProjectShard> iterator() {
        return new ArrayList<>(shards.values()).iterator();
    }

    /** Marks the registry closed and blocks until all shards have shut down gracefully. */
    @Override
    public void close() {
        closed = true;
        gracefulShutdown().join();
    }

    /**
     * Marks the registry closed, removes every registered shard and shuts each one down
     * gracefully.
     *
     * <p>Shards are removed key by key instead of snapshotting {@code values()} and then calling
     * {@code clear()}. With the snapshot approach, a shard installed or replaced between the
     * snapshot and the clear would be dropped from the map without ever being shut down. With
     * {@code remove}, every shard that leaves the map here is also the one that gets shut down.
     *
     * @return future completing when all removed shards have finished their graceful shutdown
     */
    @Override
    public CompletableFuture<Void> gracefulShutdown() {
        closed = true;
        List<AlertingProjectShard> detached = new ArrayList<>(shards.size());
        for (String projectId : shards.keySet()) {
            AlertingProjectShard shard = shards.remove(projectId);
            if (shard != null) {
                detached.add(shard);
            }
        }
        return detached.parallelStream()
            .map(AlertingProjectShard::gracefulShutdown)
            .collect(Collectors.collectingAndThen(toList(), CompletableFutures::allOfVoid));
    }

    /** @return a stream backed by the live shard map */
    @Override
    public Stream<AlertingProjectShard> stream() {
        return shards.values().stream();
    }

    /** @return sum of {@code estimateCount()} across the metrics of all registered shards */
    @Override
    public int estimateCount() {
        return shards.values()
            .stream()
            .mapToInt(shard -> shard.getMetrics().estimateCount())
            .sum();
    }

    /**
     * Emits per-shard metrics, their combined total (labelled {@code projectId=total}), and
     * gauges with the number of loading, ready and total shards.
     */
    @Override
    public void append(long tsMillis, Labels commonLabels, MetricConsumer consumer) {
        long loading = 0;
        long ready = 0;
        ShardMetrics total = new ShardMetrics(Labels.of("projectId", "total"));
        for (AlertingProjectShard shard : shards.values()) {
            ShardMetrics metrics = shard.getMetrics();
            total.combine(metrics);
            metrics.append(tsMillis, commonLabels, consumer);
            if (shard.isReady()) {
                ready++;
            } else {
                loading++;
            }
        }
        total.append(tsMillis, commonLabels, consumer);
        append("alerting.shard.state", commonLabels.add("state", "LOADING"), loading, consumer);
        append("alerting.shard.state", commonLabels.add("state", "READY"), ready, consumer);
        append("alerting.shard.count", commonLabels, ready + loading, consumer);
    }

    /** Writes one IGAUGE point with {@code labels} plus {@code sensor=metric}. */
    private void append(String metric, Labels labels, long value, MetricConsumer consumer) {
        consumer.onMetricBegin(MetricType.IGAUGE);
        consumer.onLabelsBegin(labels.size() + 1);
        labels.forEach(consumer::onLabel);
        consumer.onLabel("sensor", metric);
        consumer.onLabelsEnd();
        // NOTE(review): timestamp 0 rather than tsMillis; presumably the consumer substitutes a
        // common timestamp — confirm.
        consumer.onLong(0, value);
        consumer.onMetricEnd();
    }
}
