package ru.yandex.solomon.coremon.balancer;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.balancer.ShardsHolder;
import ru.yandex.solomon.core.conf.SolomonConfWithContext;
import ru.yandex.solomon.core.conf.watch.SolomonConfListener;
import ru.yandex.solomon.core.db.model.StringInterner;
import ru.yandex.solomon.coremon.balancer.db.BalancerShard;
import ru.yandex.solomon.coremon.balancer.db.BalancerShardsDao;
import ru.yandex.solomon.locks.DistributedLock;
import ru.yandex.solomon.util.async.InFlightLimiter;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static ru.yandex.misc.concurrent.CompletableFutures.safeCall;

/**
 * Balancer-side registry of coremon shard ids.
 *
 * <p>Shard ids are persisted through {@link BalancerShardsDao} and mirrored in memory. The set reported by
 * {@link #getShards()} is filtered against the latest Solomon configuration delivered through
 * {@link #onConfigurationLoad(SolomonConfWithContext)}; shards added within the last five minutes are always
 * reported and are never deleted.
 *
 * @author Stanislav Kashirin
 */
@ParametersAreNonnullByDefault
public final class CoremonShardsHolder implements ShardsHolder, SolomonConfListener, AutoCloseable {

    /**
     * Bit mask applied to {@link #extractCount}: reconciliation runs whenever the masked bits of the counter are
     * all zero, i.e. on every (mask + 1)-th configuration extraction since the last {@link #reload()}.
     */
    // must be (2^n)-1 so that "& mask" behaves like "% (mask + 1)"
    @VisibleForTesting
    static final int RECONCILIATION_THRESHOLD = 0x1F;

    /** Shard ids known to the balancer (loaded from or written to the DAO). */
    private final Set<String> shardIds = ConcurrentHashMap.newKeySet();
    /** In-progress deletions keyed by shard id; de-duplicates concurrent deletes and orders add-after-delete. */
    private final ConcurrentMap<String, CompletableFuture<Void>> deletingShards = new ConcurrentHashMap<>();

    /**
     * Either the most recently loaded {@link SolomonConfWithContext} that has not been converted yet, or the
     * {@code Set<String>} of shard ids already extracted from it (see {@link #confIds()}).
     */
    private final AtomicReference<Object> lastState = new AtomicReference<>(Set.of());
    /** Number of configuration-to-id-set extractions since the last {@link #reload()}. */
    private final AtomicInteger extractCount = new AtomicInteger(0);

    // Throttles DAO deletions issued from delete(); 10 is presumably the in-flight limit — confirm in InFlightLimiter.
    private final InFlightLimiter deleteLimiter = new InFlightLimiter(10);

    private final BalancerShardsDao dao;
    private final DistributedLock lock;
    /** Completed by the first {@link #onConfigurationLoad}; {@link #reload()} waits for it. */
    private final CompletableFuture<Void> initFuture;
    /** Ids of shards added recently; entries expire five minutes after being written. */
    private final Cache<String, Boolean> recentlyAddedShards;

    /** Set by {@link #close()}; checked before issuing DAO deletions. */
    private volatile boolean closed;

    CoremonShardsHolder(BalancerShardsDao dao, DistributedLock lock) {
        this.dao = dao;
        this.lock = lock;
        this.initFuture = new CompletableFuture<>();
        this.recentlyAddedShards = CacheBuilder.newBuilder()
            .expireAfterWrite(5, TimeUnit.MINUTES)
            .build();
    }

    /**
     * Persists every shard present in {@code conf} but missing from storage, using the creation time recorded in
     * the configuration. Does not modify the in-memory shard set.
     *
     * @param conf configuration to seed storage from
     * @return future completed when the bulk upsert finishes
     */
    CompletableFuture<Void> init(SolomonConfWithContext conf) {
        return dao.findAll()
            .thenCompose(shards -> {
                // Named differently from the shardIds field to avoid shadowing it inside the lambda.
                var storedIds = new ObjectOpenHashSet<String>(shards.size());
                for (var shard : shards) {
                    storedIds.add(shard.id());
                }

                var shardsToAdd = new ArrayList<BalancerShard>();
                for (var rawShard : conf.getAllRawShards()) {
                    var shardId = toShardId(rawShard.getNumId());
                    if (!storedIds.contains(shardId)) {
                        shardsToAdd.add(new BalancerShard(shardId, rawShard.getCreatedAtMillis()));
                    }
                }

                return dao.bulkUpsert(shardsToAdd);
            });
    }

    /** Numeric-id overload of {@link #add(String)}. */
    CompletableFuture<Void> add(int numId) {
        return add(toShardId(numId));
    }

    /** Returns {@code true} if the shard was added (or loaded with a fresh creation time) in the last five minutes. */
    boolean recentlyAdded(int numId) {
        return recentlyAddedShards.getIfPresent(toShardId(numId)) != null;
    }

    /**
     * Waits for the first configuration, then merges every shard stored in the DAO into the in-memory set and
     * resets the reconciliation counter. Ids that are absent from storage are not removed from memory.
     */
    @Override
    public CompletableFuture<Void> reload() {
        return initFuture.thenCompose(i -> dao.findAll())
            .thenAccept(shards -> {
                processShards(shards);
                extractCount.set(0);
            });
    }

    /**
     * Returns the shards the balancer should manage.
     *
     * <p>Shards currently being deleted are excluded. If the latest configuration yields no shard ids, all other
     * known shards are returned; otherwise only those also present in the configuration, plus every recently added
     * shard. The result is an unmodifiable Guava view that still reflects later changes to the in-memory shard set.
     * Calling this method may trigger a reconciliation as a side effect (see {@link #confIds()}).
     */
    @Override
    public Set<String> getShards() {
        var deletingIds = Set.copyOf(deletingShards.keySet());
        var balancerIds = Sets.difference(shardIds, deletingIds);
        var recentIds = Set.copyOf(recentlyAddedShards.asMap().keySet());
        var confIds = confIds();
        if (confIds.isEmpty()) {
            return balancerIds;
        }
        return Sets.union(Sets.intersection(balancerIds, confIds), recentIds);
    }

    /**
     * Registers a shard. If a deletion of the same id is in progress, the add is chained after it (and fails if the
     * deletion fails). Already-known ids complete immediately; otherwise the shard is upserted with the current time
     * and marked as recently added.
     */
    @Override
    public CompletableFuture<Void> add(String shardId) {
        return safeCall(() -> {
            var deleting = deletingShards.get(shardId);
            if (deleting != null) {
                return deleting.thenCompose(i -> add(shardId));
            }

            if (shardIds.contains(shardId)) {
                return completedFuture(null);
            }

            var shard = new BalancerShard(shardId, System.currentTimeMillis());
            return dao.upsert(shard)
                .thenAccept(i -> {
                    rememberRecentlyAdded(shardId);
                    shardIds.add(shardId);
                });
        });
    }

    /**
     * Deletes a shard from storage and, on success, from the in-memory set.
     *
     * <p>Unknown and recently added shards are ignored. Concurrent deletes of the same id share one future. The
     * deletion fails with {@link IllegalStateException} if this instance no longer holds the lock or is closed.
     */
    @Override
    public CompletableFuture<Void> delete(String shardId) {
        if (!shardIds.contains(shardId) || recentlyAddedShards.asMap().containsKey(shardId)) {
            return completedFuture(null);
        }

        var deleteFuture = new CompletableFuture<Void>();
        var doneFuture = deleteFuture.whenComplete((i, t) -> {
            if (t == null) {
                shardIds.remove(shardId);
            }
            // Safe to remove unconditionally: while our entry is present, putIfAbsent prevents anyone else
            // from registering a deletion for this id.
            deletingShards.remove(shardId);
        });

        var prevFuture = deletingShards.putIfAbsent(shardId, doneFuture);
        if (prevFuture != null) {
            return prevFuture;
        }

        try {
            deleteLimiter.run(() -> startDelete(shardId, deleteFuture));
        } catch (Throwable t) {
            // Never leave deleteFuture pending: that would pin shardId in deletingShards and make every later
            // add(shardId) wait forever.
            deleteFuture.completeExceptionally(t);
        }

        return doneFuture;
    }

    /** Remembers the latest configuration; its shard ids are extracted lazily by {@link #confIds()}. */
    @Override
    public void onConfigurationLoad(SolomonConfWithContext conf) {
        this.lastState.set(conf);
        this.initFuture.complete(null);
    }

    /** Marks the holder closed; subsequent deletions fail. In-flight operations are not cancelled. */
    @Override
    public void close() throws Exception {
        closed = true;
    }

    /**
     * Runs one throttled deletion: checks leadership and the closed flag, then forwards the DAO result into
     * {@code deleteFuture}. Any synchronous failure also completes {@code deleteFuture} exceptionally.
     */
    private CompletableFuture<Void> startDelete(String shardId, CompletableFuture<Void> deleteFuture) {
        try {
            if (!lock.isLockedByMe()) {
                deleteFuture.completeExceptionally(notLeaderAnymore());
            } else if (closed) {
                deleteFuture.completeExceptionally(alreadyClosed());
            } else {
                CompletableFutures.whenComplete(dao.delete(shardId), deleteFuture);
            }
        } catch (Throwable t) {
            deleteFuture.completeExceptionally(t);
        }
        return deleteFuture;
    }

    /**
     * Returns the shard ids of the latest configuration.
     *
     * <p>On the first call after a configuration load, the ids are extracted and cached in {@link #lastState}; each
     * such extraction bumps {@link #extractCount} and, every {@code RECONCILIATION_THRESHOLD + 1} extractions,
     * triggers {@link #reconcile}.
     */
    @SuppressWarnings("unchecked") // lastState only ever holds a Set<String> or a SolomonConfWithContext
    private Set<String> confIds() {
        var lastState = this.lastState.get();
        if (lastState instanceof Set) {
            return (Set<String>) lastState;
        }

        var lastConf = (SolomonConfWithContext) lastState;
        var confIds = extractShardIds(lastConf);
        // CAS so a configuration loaded concurrently is not overwritten by ids of the older one.
        this.lastState.compareAndSet(lastState, confIds);

        if ((extractCount.incrementAndGet() & RECONCILIATION_THRESHOLD) == 0) {
            reconcile(confIds, lastConf);
        }

        return confIds;
    }

    /**
     * Aligns storage with the configuration: upserts configured shards unknown to the balancer and deletes known
     * shards that are neither configured nor recently added. Both actions are fire-and-forget.
     */
    private void reconcile(Set<String> confIds, SolomonConfWithContext lastConf) {
        // An empty configuration would otherwise schedule deletion of every known shard.
        if (confIds.isEmpty()) {
            return;
        }

        var now = System.currentTimeMillis();
        var missing = new ArrayList<BalancerShard>();
        for (var shardId : Sets.difference(confIds, shardIds)) {
            var numId = Integer.parseUnsignedInt(shardId);
            var shardConf = lastConf.getShardByNumIdOrNull(numId);
            var createdAt = shardConf != null
                ? shardConf.getRaw().getCreatedAtMillis()
                : now;

            missing.add(new BalancerShard(shardId, createdAt));
        }
        if (!missing.isEmpty()) {
            // Best effort: a failed upsert is not reported here; a later reconciliation, if one runs,
            // will recompute the missing shards and try again.
            dao.bulkUpsert(missing)
                .thenAccept(i -> processShards(missing));
        }

        var recentIds = Set.copyOf(recentlyAddedShards.asMap().keySet());
        for (var shardId : Sets.difference(shardIds, Sets.union(confIds, recentIds))) {
            delete(shardId);
        }
    }

    /** Adds the shards to the in-memory set, marking those created within the last five minutes as recent. */
    private void processShards(List<BalancerShard> shards) {
        var recencyThreshold = System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(5);
        for (var shard : shards) {
            shardIds.add(shard.id());
            if (shard.createdAt() >= recencyThreshold) {
                rememberRecentlyAdded(shard.id());
            }
        }
    }

    private void rememberRecentlyAdded(String shardId) {
        recentlyAddedShards.put(shardId, Boolean.TRUE);
    }

    /** Collects the shard ids of every raw shard in {@code conf}. */
    private static ObjectOpenHashSet<String> extractShardIds(SolomonConfWithContext conf) {
        var shardIds = new ObjectOpenHashSet<String>(conf.getAllShardsCount());
        for (var rawShard : conf.getAllRawShards()) {
            shardIds.add(toShardId(rawShard.getNumId()));
        }
        return shardIds;
    }

    /** Formats a numeric shard id as an interned unsigned decimal string. */
    private static String toShardId(int numId) {
        return StringInterner.intern(Integer.toUnsignedString(numId));
    }

    private static IllegalStateException notLeaderAnymore() {
        return new IllegalStateException("Not leader anymore");
    }

    private static IllegalStateException alreadyClosed() {
        return new IllegalStateException("Already closed");
    }

}
