package ru.yandex.solomon.coremon;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;

import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.core.ShardIsNotLocalException;
import ru.yandex.solomon.core.conf.ShardConfDetailed;
import ru.yandex.solomon.core.conf.ShardConfMaybeWrong;
import ru.yandex.solomon.core.conf.SolomonConfWithContext;
import ru.yandex.solomon.coremon.CoremonStateEvent.ReloadShard;
import ru.yandex.solomon.coremon.CoremonStateEvent.RemoveShard;
import ru.yandex.solomon.coremon.meta.service.ShardLocator;
import ru.yandex.solomon.flags.FeatureFlag;
import ru.yandex.solomon.flags.FeatureFlagsHolder;
import ru.yandex.solomon.labels.shard.ShardKey;

/**
 * Snapshot of the shards currently held by Coremon, indexed both by numeric id and by
 * {@link ShardKey}.
 *
 * <p>The state-changing operations ({@link #update}, {@link #reloadShards}, {@link #removeShards})
 * never modify the maps of this instance; they build and return a new {@code CoremonStateMap}
 * (or return {@code this} when there is nothing to do). Note, however, that they do close shard
 * objects that this instance still refers to, so the previous snapshot should be discarded
 * once a new one has been obtained.
 *
 * @author Sergey Polovko
 */
@Immutable
final class CoremonStateMap {
    private static final Logger logger = LoggerFactory.getLogger(CoremonStateMap.class);

    /** Configuration this snapshot was built from; {@code null} for the initial empty state. */
    @Nullable
    private final SolomonConfWithContext conf;
    /** Shards keyed by numeric shard id. */
    private final Int2ObjectMap<CoremonShardWithConf> shardsByNumId;
    /** Same shards as in {@link #shardsByNumId}, keyed by the shard key taken from the stored conf. */
    private final ImmutableMap<ShardKey, CoremonShardWithConf> shardsByKey;
    /** Number of shard reloads that were postponed by the reload limiter during the last update. */
    private final int skipByReloadLimit;
    /** Errors collected while processing shard configs, keyed by numeric shard id. */
    private final Int2ObjectMap<Throwable> initErrors;
    /** Local shard count reported by the locator minus the number of shards held in this snapshot. */
    private final int inactiveLocalShardCount;

    /**
     * Creates a snapshot from already prepared maps.
     *
     * <p>The given maps are stored as is (no defensive copy); the callers in this class pass
     * unmodifiable views. The by-key index is derived from {@code shardsByNumId}.
     *
     * @param conf                    configuration the snapshot corresponds to, may be {@code null}
     * @param shardsByNumId           shards keyed by numeric id
     * @param initErrors              per-shard errors keyed by numeric id
     * @param skipByReloadLimit       number of reloads skipped because of the reload limit
     * @param inactiveLocalShardCount number of local shards that are not present in {@code shardsByNumId}
     */
    CoremonStateMap(
        @Nullable SolomonConfWithContext conf,
        Int2ObjectMap<CoremonShardWithConf> shardsByNumId,
        Int2ObjectMap<Throwable> initErrors,
        int skipByReloadLimit,
        int inactiveLocalShardCount)
    {
        this.conf = conf;
        this.shardsByNumId = shardsByNumId;
        this.initErrors = initErrors;
        // A plain HashMap is used as the intermediate container so that two shards sharing the
        // same key do not fail the construction: the last one simply wins.
        var shardsByKey = new HashMap<ShardKey, CoremonShardWithConf>(shardsByNumId.size());
        for (CoremonShardWithConf s : shardsByNumId.values()) {
            shardsByKey.put(s.conf.shardKey(), s);
        }
        this.skipByReloadLimit = skipByReloadLimit;
        this.shardsByKey = ImmutableMap.copyOf(shardsByKey);
        this.inactiveLocalShardCount = inactiveLocalShardCount;
    }

    /**
     * Creates an empty snapshot without configuration, shards or errors.
     */
    CoremonStateMap() {
        this(null, Int2ObjectMaps.emptyMap(), Int2ObjectMaps.emptyMap(), 0, 0);
    }

    /**
     * Builds a new snapshot by applying the given configuration to the current set of shards.
     *
     * <p>Processing order:
     * <ol>
     *   <li>shards that are no longer local according to {@code shardLocator} are closed and dropped;</li>
     *   <li>for every shard in {@code conf}: non-local shards are skipped, broken configs are recorded
     *       as init errors, missing shards are created, existing shards are updated in place when
     *       {@code shardFactory.updateShard} accepts the new conf, otherwise they become reload
     *       candidates;</li>
     *   <li>reload candidates are processed in random order; each one is closed and re-created only
     *       while {@code reloadLimiter.isAllow()} returns {@code true}, the rest are counted as
     *       skipped and keep running with their previous shard object.</li>
     * </ol>
     *
     * <p>A shard that is still local but absent from {@code conf} is not removed by this method.
     *
     * @param conf                 new configuration
     * @param shardFactory         factory used to create shards and to update them in place
     * @param shardLocator         tells which shards are local
     * @param reloadLimiter        limits how many shards may be re-created
     * @param featureHolder        source of per-shard feature flags
     * @param onlyPartitionedShard when {@code true}, new shards without the
     *                             {@link FeatureFlag#PARTITIONED_SHARD} flag are rejected with an init error
     * @return a new snapshot; init errors and the skip counter are computed from scratch
     */
    CoremonStateMap update(
        SolomonConfWithContext conf,
        CoremonShardFactory shardFactory,
        ShardLocator shardLocator,
        CoremonStateReloadLimiter reloadLimiter,
        FeatureFlagsHolder featureHolder,
        boolean onlyPartitionedShard)
    {
        var shardsByNumId = new Int2ObjectOpenHashMap<>(this.shardsByNumId);
        var initErrors = new Int2ObjectOpenHashMap<Throwable>();
        int skipByReloadLimit = 0;

        // Step 1: drop shards that have moved away, regardless of whether they are still in conf.
        for (var it = shardsByNumId.int2ObjectEntrySet().fastIterator(); it.hasNext(); ) {
            var e = it.next();
            int numId = e.getIntKey();
            if (!shardLocator.isLocal(numId)) {
                var shard = e.getValue().shard;
                logger.info("shard {} become remote", Integer.toUnsignedString(numId));
                shard.close();
                it.remove();
            }
        }

        // Step 2: create new shards, update existing ones in place, collect those needing a reload.
        List<ShardConfDetailed> reloadCandidates = new ArrayList<>();
        for (ShardConfMaybeWrong shardConf : conf.getShards()) {
            final int numId = shardConf.getNumId();
            if (!shardLocator.isLocal(numId)) {
                // Normally already handled by step 1; kept in case the locator answer changed meanwhile.
                final CoremonShardWithConf prevShard = shardsByNumId.remove(numId);
                if (prevShard != null) {
                    prevShard.shard.close();
                    logger.info("shard {} become remote", Integer.toUnsignedString(numId));
                }
                continue;
            }

            if (!shardConf.isCorrect()) {
                logger.error("cannot load config of shard {}", Integer.toUnsignedString(numId), shardConf.getThrowable());
                initErrors.put(numId, shardConf.getThrowable());
                continue;
            }

            try {
                final ShardConfDetailed newShardConf = shardConf.getConfOrThrow();
                final CoremonShardWithConf prevShard = shardsByNumId.get(numId);
                if (prevShard == null) {
                    if (onlyPartitionedShard && !featureHolder.hasFlag(FeatureFlag.PARTITIONED_SHARD, numId)) {
                        throw new IllegalArgumentException("Shard isn't partitioned (only such shards are expected)");
                    }
                    var newShard = shardFactory.createShard(newShardConf);
                    shardsByNumId.put(numId, new CoremonShardWithConf(newShard, newShardConf));
                } else if (shardFactory.updateShard(prevShard.shard, newShardConf)) {
                    // Updated in place: remember the conf the shard is running with now, so that the
                    // by-key index is not built from an outdated conf (same as reloadShards does).
                    shardsByNumId.put(numId, new CoremonShardWithConf(prevShard.shard, newShardConf));
                } else {
                    reloadCandidates.add(newShardConf);
                }
            } catch (Throwable t) {
                logger.error("cannot update shard {}", Integer.toUnsignedString(numId), t);
                initErrors.put(numId, t);
            }
        }

        // Step 3: re-create shards that could not be updated in place. The order is randomized so
        // that the reload limit does not always favour the same shards.
        Collections.shuffle(reloadCandidates);
        for (var newShardConf : reloadCandidates) {
            try {
                if (!reloadLimiter.isAllow()) {
                    skipByReloadLimit++;
                    continue;
                }

                final CoremonShardWithConf prevShard = shardsByNumId.get(newShardConf.getNumId());
                logger.info("shard {} was changed", prevShard.shard.getId());
                prevShard.shard.close();
                var newShard = shardFactory.createShard(newShardConf);
                shardsByNumId.put(newShardConf.getNumId(), new CoremonShardWithConf(newShard, newShardConf));
                reloadLimiter.add(newShard.getLoadFuture());
            } catch (Throwable t) {
                logger.error("cannot update shard {}", Integer.toUnsignedString(newShardConf.getNumId()), t);
                initErrors.put(newShardConf.getNumId(), t);
            }
        }

        return new CoremonStateMap(
            conf,
            Int2ObjectMaps.unmodifiable(shardsByNumId),
            Int2ObjectMaps.unmodifiable(initErrors),
            skipByReloadLimit,
            shardLocator.getLocalShardCount() - shardsByNumId.size());
    }

    /**
     * Applies explicit reload requests for individual shards.
     *
     * <p>Events for shards that are not local (or, with {@code onlyPartitionedShard}, not
     * partitioned) are aborted and <b>removed from the passed {@code events} list</b>. For every
     * remaining event the shard is created, updated in place (only if the event allows it and the
     * factory accepts the conf) or closed and re-created. Events with {@code awaitLoad} are
     * completed when the load future of the shard completes; a failure while processing an event
     * completes that event exceptionally.
     *
     * <p>NOTE(review): events without {@code awaitLoad} are not completed here on success;
     * presumably the caller does that for the events left in the list - confirm.
     *
     * @param shardFactory         factory used to create shards and to update them in place
     * @param shardLocator         tells which shards are local
     * @param events               reload requests; must support {@code Iterator.remove()}
     * @param featureHolder        source of per-shard feature flags
     * @param onlyPartitionedShard when {@code true}, only shards with the
     *                             {@link FeatureFlag#PARTITIONED_SHARD} flag may be reloaded
     * @return {@code this} if every event was rejected, otherwise a new snapshot that keeps the
     *         current conf, init errors and skip counter
     */
    CoremonStateMap reloadShards(
        CoremonShardFactory shardFactory,
        ShardLocator shardLocator,
        List<ReloadShard> events,
        FeatureFlagsHolder featureHolder,
        boolean onlyPartitionedShard)
    {
        // Reject requests that cannot be served before touching any shard.
        for (var it = events.iterator(); it.hasNext(); ) {
            ReloadShard event = it.next();
            String error = null;
            if (!shardLocator.isLocal(event.conf.getNumId())) {
                error = String.format("cannot reload shard %s, because it's not local", event.conf.getId());
            } else if (onlyPartitionedShard && !featureHolder.hasFlag(FeatureFlag.PARTITIONED_SHARD, event.conf.getNumId())) {
                error = String.format("cannot reload shard %s, because it's not partitioned", event.conf.getId());
            }

            if (error != null) {
                logger.warn(error);
                event.abort(error);
                it.remove();
            }
        }

        if (events.isEmpty()) {
            return this;
        }

        var shardsByNumId = new Int2ObjectOpenHashMap<>(this.shardsByNumId);
        for (ReloadShard event : events) {
            int numId = event.conf.getNumId();
            String shardId = event.conf.getId();
            try {
                final ShardConfDetailed newShardConf = event.conf;
                final CoremonShardWithConf prevShard = shardsByNumId.get(numId);
                if (prevShard == null) {
                    var newShard = shardFactory.createShard(newShardConf);
                    shardsByNumId.put(numId, new CoremonShardWithConf(newShard, newShardConf));
                } else if (!event.allowUpdate || !shardFactory.updateShard(prevShard.shard, newShardConf)) {
                    // In-place update is either forbidden by the event or refused by the factory.
                    prevShard.shard.close();
                    var newShard = shardFactory.createShard(newShardConf);
                    shardsByNumId.put(numId, new CoremonShardWithConf(newShard, newShardConf));
                } else {
                    // Updated in place: keep the shard object, remember the new conf.
                    shardsByNumId.put(numId, new CoremonShardWithConf(prevShard.shard, newShardConf));
                }

                if (event.awaitLoad) {
                    shardsByNumId.get(numId).shard.getLoadFuture().whenComplete((aVoid, t) -> {
                        if (t != null) {
                            event.completeExceptionally(new IllegalStateException("cannot load shard " + shardId, t));
                        } else {
                            event.done();
                        }
                    });
                }
            } catch (Throwable t) {
                event.completeExceptionally(new IllegalStateException("cannot reload shard " + shardId, t));
            }
        }

        return new CoremonStateMap(
                conf,
                Int2ObjectMaps.unmodifiable(shardsByNumId),
                initErrors,
                skipByReloadLimit,
                shardLocator.getLocalShardCount() - shardsByNumId.size());
    }

    /**
     * Closes and removes the shards named by the given events.
     *
     * <p>Events for shards that are not present in this snapshot are completed right away and
     * <b>removed from the passed {@code events} list</b>. A failure while closing a shard completes
     * the corresponding event exceptionally; the shard is removed from the new snapshot anyway.
     *
     * <p>NOTE(review): successfully processed events are not completed here; presumably the
     * caller does that for the events left in the list - confirm.
     *
     * @param shardLocator used to compute the inactive local shard count of the new snapshot
     * @param events       remove requests; must support {@code Iterator.remove()}
     * @return {@code this} if no event refers to a known shard, otherwise a new snapshot that
     *         keeps the current conf, init errors and skip counter
     */
    CoremonStateMap removeShards(ShardLocator shardLocator, List<RemoveShard> events) {
        for (var it = events.iterator(); it.hasNext(); ) {
            RemoveShard event = it.next();
            if (!shardsByNumId.containsKey(event.numId)) {
                logger.warn("removing non local or inactive shard {}", Integer.toUnsignedString(event.numId));
                event.done();
                it.remove();
            }
        }

        if (events.isEmpty()) {
            return this;
        }

        var shardsByNumId = new Int2ObjectOpenHashMap<>(this.shardsByNumId);
        for (RemoveShard event : events) {
            try {
                var shardWithConf = shardsByNumId.remove(event.numId);
                // null means an earlier event of the same batch has already removed this shard;
                // there is nothing left to close and the request is treated like the first one.
                if (shardWithConf != null) {
                    shardWithConf.shard.close();
                }
            } catch (Throwable t) {
                event.completeExceptionally(new IllegalStateException(
                    "cannot remove shard " + Integer.toUnsignedString(event.numId), t));
            }
        }

        return new CoremonStateMap(
                conf,
                Int2ObjectMaps.unmodifiable(shardsByNumId),
                initErrors,
                skipByReloadLimit,
                shardLocator.getLocalShardCount() - shardsByNumId.size());
    }

    /**
     * @return the value computed when this snapshot was built: local shard count reported by the
     *         locator minus the number of shards held here
     */
    int getInactiveLocalShardCount() {
        return inactiveLocalShardCount;
    }

    /**
     * Resolves a string shard id to its numeric id using the stored configuration.
     *
     * <p>NOTE(review): {@code 0} doubles as the "unknown" marker; presumably no real shard has
     * numeric id 0 - confirm.
     *
     * @return numeric id, or {@code 0} when there is no configuration or the id is not in it
     */
    private int numIdByShardId(String shardId) {
        if (conf == null) {
            return 0;
        }

        var shard = conf.getShardByIdOrNull(shardId);
        if (shard == null) {
            return 0;
        }
        return shard.getNumId();
    }

    /**
     * Calls {@code stop()} on every shard of this snapshot.
     */
    void close() {
        for (CoremonShardWithConf s: shardsByNumId.values()) {
            s.shard.stop();
        }
    }

    /**
     * @return configuration this snapshot was built from, {@code null} for the initial empty state
     */
    @Nullable
    SolomonConfWithContext getConf() {
        return conf;
    }

    /**
     * @return number of shard reloads skipped because of the reload limit
     */
    int getSkipByReloadLimit() {
        return skipByReloadLimit;
    }

    /**
     * @return errors collected while processing shard configs, keyed by numeric shard id
     */
    Int2ObjectMap<Throwable> getInitErrors() {
        return initErrors;
    }

    /**
     * @return shard with the given string id, or {@code null} if it is not in this snapshot
     */
    @Nullable
    CoremonShard getShardByIdOrNull(String id) {
        return getShardByNumIdOrNull(numIdByShardId(id));
    }

    /**
     * @return shard with the given string id
     * @throws ShardIsNotLocalException if the shard is not in this snapshot
     */
    CoremonShard getShardById(String id) {
        CoremonShardWithConf shard = shardsByNumId.get(numIdByShardId(id));
        if (shard == null) {
            throw new ShardIsNotLocalException(id);
        }
        return shard.shard;
    }

    /**
     * @return shard with the given numeric id, or {@code null} if it is not in this snapshot
     */
    @Nullable
    CoremonShard getShardByNumIdOrNull(int numId) {
        var shard = shardsByNumId.get(numId);
        return shard == null ? null : shard.shard;
    }

    /**
     * @return shard with the given numeric id
     * @throws ShardIsNotLocalException if the shard is not in this snapshot
     */
    CoremonShard getShardByNumId(int numId) {
        var shard = shardsByNumId.get(numId);
        if (shard == null) {
            throw new ShardIsNotLocalException(Integer.toUnsignedString(numId));
        }
        return shard.shard;
    }

    /**
     * @return shard with the given key, or {@code null} if it is not in this snapshot
     */
    @Nullable
    CoremonShard getShardByKeyOrNull(ShardKey shardKey) {
        var shard = shardsByKey.get(shardKey);
        return shard == null ? null : shard.shard;
    }

    /**
     * @return shard with the given key
     * @throws ShardIsNotLocalException if the shard is not in this snapshot
     */
    CoremonShard getShardByKey(ShardKey shardKey) {
        var shard = shardsByKey.get(shardKey);
        if (shard == null) {
            throw new ShardIsNotLocalException("" + shardKey);
        }
        return shard.shard;
    }

    /**
     * @return stream over all shards of this snapshot
     */
    Stream<CoremonShard> getShardsStream() {
        return shardsByNumId.values()
            .stream()
            .map(s -> s.shard);
    }

    /**
     * @return number of shards in this snapshot
     */
    public int size() {
        return shardsByNumId.size();
    }

    /**
     * Pair of a shard and the configuration it was created from or last updated with.
     */
    private static final class CoremonShardWithConf {
        final CoremonShard shard;
        final ShardConfDetailed conf;

        CoremonShardWithConf(CoremonShard shard, ShardConfDetailed conf) {
            this.shard = shard;
            this.conf = conf;
        }
    }
}
