package ru.yandex.solomon.coremon;

import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import net.jpountz.xxhash.XXHashFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.solomon.config.protobuf.coremon.TCoremonEngineConfig;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;
import ru.yandex.solomon.core.conf.SolomonConfManager;
import ru.yandex.solomon.core.conf.SolomonConfWithContext;
import ru.yandex.solomon.core.conf.SolomonRawConf;
import ru.yandex.solomon.core.conf.watch.SolomonConfListener;
import ru.yandex.solomon.coremon.CoremonStateEvent.ReloadShard;
import ru.yandex.solomon.coremon.CoremonStateEvent.RemoveShard;
import ru.yandex.solomon.coremon.CoremonStateEvent.UpdateConf;
import ru.yandex.solomon.coremon.balancer.ShardLocatorImpl;
import ru.yandex.solomon.coremon.balancer.cluster.ShardInfoProvider;
import ru.yandex.solomon.coremon.balancer.state.ShardIds;
import ru.yandex.solomon.coremon.balancer.state.ShardLoad;
import ru.yandex.solomon.coremon.balancer.state.ShardsLoadMap;
import ru.yandex.solomon.coremon.meta.db.MetabaseShardStorageImpl;
import ru.yandex.solomon.coremon.meta.service.MetabaseNotInitialized;
import ru.yandex.solomon.coremon.meta.service.MetabaseShard;
import ru.yandex.solomon.coremon.meta.service.MetabaseShardResolver;
import ru.yandex.solomon.coremon.meta.service.MetabaseTotalShardCounter;
import ru.yandex.solomon.flags.FeatureFlagsHolder;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.labels.query.ShardSelectors;
import ru.yandex.solomon.labels.shard.ShardKey;
import ru.yandex.solomon.memory.layout.MemInfoProvider;
import ru.yandex.solomon.memory.layout.MemoryBySubsystem;
import ru.yandex.solomon.staffOnly.annotations.LinkedOnRootPage;
import ru.yandex.solomon.staffOnly.manager.ExtraContentParam;
import ru.yandex.solomon.staffOnly.manager.find.annotation.NamedObjectFinderAnnotation;
import ru.yandex.solomon.staffOnly.manager.special.ExtraContent;
import ru.yandex.solomon.util.collection.Nullables;
import ru.yandex.solomon.util.collection.queue.ArrayListLockQueue;


/**
 * Holds the set of shards served by this coremon/metabase process and keeps it in sync with
 * configuration and balancer assignments.
 *
 * <p>All mutations of the shard set go through a single actor ({@link #act()} run by
 * {@link ActorRunner}); readers access an immutable-by-reference {@link CoremonStateMap}
 * snapshot through {@link #stateMap}.
 *
 * @author Stepan Koltsov
 */
@LinkedOnRootPage("Coremon State")
@ParametersAreNonnullByDefault
@Component
public class CoremonState implements MetabaseShardResolver<MetabaseShard>, MemInfoProvider, SolomonConfListener, ShardInfoProvider, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(CoremonState.class);

    private final SolomonConfManager confManager;
    private final CoremonShardFactory shardFactory;
    private final ShardLocatorImpl shardLocator;
    private final CoremonStateReloadLimiter reloadLimiter;
    private final boolean onlyPartitionedShard;
    private final AtomicInteger totalShardCount = new AtomicInteger(MetabaseTotalShardCounter.SHARD_COUNT_UNKNOWN);

    private final ActorRunner actorRunner;
    // events queued for the actor; drained in bulk by act()
    private final ArrayListLockQueue<CoremonStateEvent> actorEvents = new ArrayListLockQueue<>(1);
    // latest configuration delivered by onConfigurationLoad(); null until the first load
    private final AtomicReference<SolomonConfWithContext> conf = new AtomicReference<>();
    // current published state; null before the first act() run and after close()
    private final AtomicReference<CoremonStateMap> stateMap = new AtomicReference<>();
    // cached result of getShardKeysHash(); reset to an expired snapshot whenever stateMap changes
    private final AtomicReference<HashSnapshot> stateSnapshotHash = new AtomicReference<>(HashSnapshot.expired());
    private final FeatureFlagsHolder featureHolder;

    @Autowired
    public CoremonState(
        SolomonConfManager confManager,
        CoremonShardFactory shardFactory,
        ShardLocatorImpl shardLocator,
        ThreadPoolProvider threadPoolProvider,
        TCoremonEngineConfig config,
        FeatureFlagsHolder featureHolder)
    {
        this.confManager = confManager;
        this.shardFactory = shardFactory;
        this.shardLocator = shardLocator;
        this.actorRunner = new ActorRunner(this::act, threadPoolProvider.getExecutorService(
            config.getMiscThreadPool(),
            "CoremonEngineConfig.MiscThreadPool"));
        this.reloadLimiter = new CoremonStateReloadLimiter(actorRunner::schedule);
        this.featureHolder = featureHolder;
        onlyPartitionedShard = config.getLaunchAsMetabase();
    }

    /**
     * Actor body. Drains all pending events and applies shard reloads and removals. Re-applies the
     * configuration when needed, publishes the resulting state map and finally completes the
     * drained events.
     */
    private void act() {
        final CoremonStateMap oldState = stateMap.get();
        CoremonStateMap newState = (oldState == null) ? new CoremonStateMap() : oldState;

        List<CoremonStateEvent> events = actorEvents.dequeueAll();
        if (!events.isEmpty()) {
            newState = newState.reloadShards(shardFactory, shardLocator,
                    filterEvents(events, ReloadShard.class), featureHolder, onlyPartitionedShard);
            newState = newState.removeShards(shardLocator, filterEvents(events, RemoveShard.class));
        }

        SolomonConfWithContext currConf = Nullables.orDefault(conf.get(), newState.getConf());
        if (currConf == null) {
            currConf = SolomonConfWithContext.create(SolomonRawConf.EMPTY);
        }
        // re-run the update when a previous pass was throttled by the reload limiter,
        // when the configuration object changed, or when an explicit UpdateConf was requested
        if (newState.getSkipByReloadLimit() > 0
            || currConf != newState.getConf()
            || events.stream().anyMatch(UpdateConf.class::isInstance))
        {
            newState = newState.update(currConf, shardFactory, shardLocator,
                    reloadLimiter, featureHolder, onlyPartitionedShard);
        }

        if (newState != oldState) {
            stateMap.set(newState);
            stateSnapshotHash.set(HashSnapshot.expired());
        }

        // Events must be finished after updating the reference to the state map, because on the
        // next read clients must see the updated state. They are finished even if the state map
        // instance did not change; otherwise the futures handed out by removeShard()/reloadShard()
        // would never complete.
        for (CoremonStateEvent event : events) {
            event.done();
        }
    }

    /** Returns the events that are instances of {@code clazz}, cast to that type, in queue order. */
    private static <T> List<T> filterEvents(List<CoremonStateEvent> events, Class<T> clazz) {
        return events.stream()
            .filter(clazz::isInstance)
            .map(clazz::cast)
            .collect(Collectors.toList());
    }

    /** Detaches and closes the current state map, if any. */
    @Override
    public void close() {
        CoremonStateMap oldState = stateMap.getAndSet(null);
        if (oldState != null) {
            oldState.close();
        }
    }

    /** Sums the memory reported by every shard in the current state. */
    @Override
    public MemoryBySubsystem memoryBySystem() {
        MemoryBySubsystem r = new MemoryBySubsystem();
        getInitializedShards().forEach(shard -> {
            r.addAllMemory(shard.memoryBySystem());
        });
        return r;
    }

    @Override
    public MetabaseShard resolveShard(ShardKey shardKey) {
        return getShardByKey(shardKey).getMetabaseShard();
    }

    @Nullable
    @Override
    public MetabaseShard resolveShardOrNull(int numId) {
        var shard = getShardByNumIdOrNull(numId);
        return shard == null
            ? null
            : shard.getMetabaseShard();
    }

    @Override
    public MetabaseShard resolveShard(int numId) {
        return getShardByNumId(numId).getMetabaseShard();
    }

    /**
     * Resolves the shards matching {@code shardSelector} within {@code folderId}.
     *
     * @param folderId folder to restrict the search to; empty string means "any folder"
     * @param shardSelector selector matched against each shard's key
     * @return matching shards sorted by id
     * @throws MetabaseNotInitialized if the state is not loaded yet
     */
    @Override
    public Stream<MetabaseShard> resolveShard(String folderId, Selectors shardSelector) {
        CoremonStateMap state = initializedState();
        var shardKey = ShardSelectors.getShardKeyOrNull(shardSelector);
        if (shardKey != null) {
            var shard = state.getShardByKeyOrNull(shardKey);
            if (shard != null) {
                var metabaseShard = shard.getMetabaseShard();
                // the fast path must respect the folder restriction exactly like the scan below;
                // on mismatch fall through so the result is whatever the full scan yields
                if (isInFolder(folderId, metabaseShard)) {
                    return Stream.of(metabaseShard);
                }
            }
        }
        return state.getShardsStream()
            .map(CoremonShard::getMetabaseShard)
            .filter(shard -> isInFolder(folderId, shard))
            .filter(shard -> shardSelector.match(shard.getShardKey()))
            .sorted(Comparator.comparing(MetabaseShard::getId));
    }

    /** True when {@code folderId} is empty (no restriction) or equals the shard's folder id. */
    private static boolean isInFolder(String folderId, MetabaseShard shard) {
        return folderId.isEmpty() || folderId.equals(shard.getFolderId());
    }

    /**
     * Returns a hash over the shard keys, numIds, readiness and "file metrics limit reached" flags
     * of all shards in the current state. Returns 0 when there is no state.
     *
     * <p>The value is cached for roughly {@link HashSnapshot#MAX_LIVE_TIME} (plus up to one second
     * of random jitter). The cache is invalidated whenever {@link #act()} publishes a new state map.
     */
    @Override
    public long getShardKeysHash() {
        // Read the cached snapshot BEFORE the state map. act() publishes the new state map first and
        // resets the snapshot afterwards, so with this order a hash computed from an already replaced
        // state map can never be installed over that reset: the compareAndSet below fails instead.
        var snapshot = stateSnapshotHash.get();
        CoremonStateMap state = stateMap.get();
        if (state == null) {
            return 0;
        }

        long now = System.currentTimeMillis();
        if (!snapshot.isExpired(now)) {
            return snapshot.calculatedHash;
        }

        var hasher = XXHashFactory.fastestInstance().newStreamingHash64(0);
        ByteBuf buf = PooledByteBufAllocator.DEFAULT.heapBuffer(255);

        try {
            state.getShardsStream().forEach(shard -> {
                var processingShard = shard.getProcessingShard();

                ShardKey shardKey = processingShard.getShardKey();
                buf.writeCharSequence(shardKey.getProject(), StandardCharsets.UTF_8);
                buf.writeCharSequence(shardKey.getService(), StandardCharsets.UTF_8);
                buf.writeCharSequence(shardKey.getCluster(), StandardCharsets.UTF_8);
                buf.writeBoolean(shard.isReady());
                buf.writeIntLE(shard.getNumId());

                var metabaseShard = shard.getMetabaseShard();
                buf.writeBoolean(metabaseShard.fileMetricsCount() >= metabaseShard.maxFileMetrics());

                // heap buffer: hash the backing array directly, then reuse the buffer for the next shard
                hasher.update(buf.array(), buf.arrayOffset() + buf.readerIndex(), buf.readableBytes());
                buf.clear();
            });
        } finally {
            buf.release();
        }

        var update = new HashSnapshot(hasher.getValue(), now);
        // only cache if nobody (e.g. act()) replaced the snapshot while we were hashing
        stateSnapshotHash.compareAndSet(snapshot, update);
        return update.calculatedHash;
    }

    /** Returns the metabase shards of the current state, or an empty stream when there is no state. */
    @Override
    public Stream<MetabaseShard> getShards() {
        CoremonStateMap state = stateMap.get();
        if (state == null) {
            return Stream.empty();
        }
        return state.getShardsStream()
            .map(CoremonShard::getMetabaseShard);
    }

    /** Returns the inactive local shard count of the current state, or 0 when there is no state. */
    @Override
    public int getInactiveShardCount() {
        CoremonStateMap state = stateMap.get();
        return state == null ? 0 : state.getInactiveLocalShardCount();
    }

    /** Returns the number of shards in the current state, or 0 when there is no state. */
    public int getShardsCount() {
        var state = stateMap.get();
        if (state == null) {
            return 0;
        }
        return state.size();
    }

    /**
     * Returns the current state map, failing when it is not ready yet.
     *
     * <p>The reference is read exactly once, so a concurrent {@link #close()} cannot null it out
     * between the readiness check and the subsequent lookup.
     *
     * @throws MetabaseNotInitialized if there is no state map or the shard locator is not initialized
     */
    private CoremonStateMap initializedState() {
        CoremonStateMap state = stateMap.get();
        if (state == null || !shardLocator.isInitialized()) {
            throw new MetabaseNotInitialized();
        }
        return state;
    }

    /** Returns all shards of the current state sorted by id (empty when there is no state). */
    public List<CoremonShard> findShards() {
        return getInitializedShards()
            .sorted(Comparator.comparing(CoremonShard::getId))
            .collect(Collectors.toList());
    }

    /**
     * Only for manager interface usage.
     *
     * <p>Looks the shard up by string id first. Otherwise {@code id} is parsed as an unsigned
     * numeric id, and {@link NumberFormatException} propagates if it is not numeric.
     */
    @NamedObjectFinderAnnotation
    public CoremonShard getShardByIdOrNumId(String id) {
        var state = initializedState();
        var shard = state.getShardByIdOrNull(id);
        if (shard != null) {
            return shard;
        }

        int numId = Integer.parseUnsignedInt(id);
        return state.getShardByNumId(numId);
    }

    @Nullable
    public CoremonShard getShardByIdOrNull(String id) {
        return initializedState().getShardByIdOrNull(id);
    }

    @Nullable
    public CoremonShard getShardByNumIdOrNull(int numId) {
        return initializedState().getShardByNumIdOrNull(numId);
    }

    public CoremonShard getShardByNumId(int numId) {
        return initializedState().getShardByNumId(numId);
    }

    @Nullable
    public CoremonShard getShardByKeyOrNull(ShardKey shardKey) {
        return initializedState().getShardByKeyOrNull(shardKey);
    }

    public CoremonShard getShardByKey(ShardKey shardKey) {
        return initializedState().getShardByKey(shardKey);
    }

    /** True until the first state map is published and the shard locator is initialized. */
    @Override
    public boolean isLoading() {
        return stateMap.get() == null || !shardLocator.isInitialized();
    }

    /** Returns the shards of the current state without readiness checks (empty when there is no state). */
    public Stream<CoremonShard> getInitializedShards() {
        CoremonStateMap state = stateMap.get();
        return (state == null) ? Stream.empty() : state.getShardsStream();
    }

    @ExtraContent("Shards")
    private void shardsPage(ExtraContentParam p) {
        // use a single snapshot so shards and init errors are consistent with each other
        CoremonStateMap state = stateMap.get();
        List<CoremonShard> shards = (state == null)
            ? List.of()
            : state.getShardsStream().collect(Collectors.toList());
        var initErrors = Optional.ofNullable(state)
                .map(CoremonStateMap::getInitErrors)
                .orElseGet(Int2ObjectOpenHashMap::new);
        CoremonStateWww.shardsPage(p, MetabaseShardStorageImpl.IN_FLIGHT_LIMITER, shards, initErrors);
    }

    /**
     * Schedules removal of the shard with the given numeric id.
     *
     * @return future completed by the actor once the event has been processed
     */
    public CompletableFuture<Void> removeShard(int numId) {
        RemoveShard event = new RemoveShard(numId);
        actorEvents.enqueue(event);
        actorRunner.schedule();
        return event;
    }

    /**
     * Loads the shard configuration and then schedules a reload of that shard on the actor.
     *
     * @return future completed by the actor once the reload event has been processed
     */
    public CompletableFuture<Void> reloadShard(String projectId, String shardId, boolean allowUpdate, boolean awaitLoad) {
        logger.info("reloading shard {}", shardId);
        return confManager.loadShardConf(projectId, shardId)
            .thenCompose(shardConf -> {
                var event = new ReloadShard(shardConf, allowUpdate, awaitLoad);
                actorEvents.enqueue(event);
                actorRunner.schedule();
                return event;
            });
    }

    /** Stores the newly loaded configuration and wakes the actor to apply it. */
    @Override
    public void onConfigurationLoad(SolomonConfWithContext conf) {
        this.conf.set(conf);
        this.actorRunner.schedule();
    }

    /**
     * Returns the load of every shard in the current state. Once a configuration has been
     * received, local shards without a state entry are also reported, as {@link ShardLoad#inactive}.
     */
    @Override
    public ShardsLoadMap getShardsLoad() {
        var coremonStateMap = stateMap.get();
        if (coremonStateMap == null) {
            return ShardsLoadMap.EMPTY;
        }
        var map = new Int2ObjectOpenHashMap<ShardLoad>(coremonStateMap.size());
        coremonStateMap.getShardsStream().forEach(shard -> {
            int shardId = shard.getProcessingShard().getNumId();
            ShardLoad load = shard.getProcessingShard().getLoad();
            map.put(shardId, load);
        });

        SolomonConfWithContext lastConf = this.conf.get();
        if (lastConf != null) {
            // fake loads for deactivated shards
            var it = shardLocator.getLocalShards().iterator();
            while (it.hasNext()) {
                var numId = it.nextInt();
                if (!map.containsKey(numId)) {
                    map.put(numId, ShardLoad.inactive(numId));
                }
            }
        }
        return ShardsLoadMap.ownOf(map);
    }

    /** Updates the local shard assignment and asks the actor to re-apply the configuration. */
    @Override
    public void setLocalShardIds(ShardIds ids) {
        shardLocator.setLocalShardIds(ids);
        actorEvents.enqueue(UpdateConf.INSTANCE);
        actorRunner.schedule();
    }

    @Override
    public void setTotalShardCount(int totalShardCount) {
        this.totalShardCount.set(totalShardCount);
    }

    /** Returns the total shard count, or empty while it is unknown (any negative stored value). */
    @Override
    public OptionalInt getTotalShardCount() {
        int total = totalShardCount.get();
        if (total < 0) {
            return OptionalInt.empty();
        }
        return OptionalInt.of(total);
    }

    /** Cached shard-keys hash with a jittered expiration time. */
    private static class HashSnapshot {
        private static final long MAX_LIVE_TIME = TimeUnit.SECONDS.toMillis(15);

        final long calculatedHash;
        final long expiredAt;

        public HashSnapshot(long calculatedHash, long createdAt) {
            this.calculatedHash = calculatedHash;
            // random jitter (< 1s) is added to MAX_LIVE_TIME
            this.expiredAt = createdAt + MAX_LIVE_TIME + ThreadLocalRandom.current().nextLong(1_000);
        }

        /** A snapshot that is already expired; each call yields a distinct instance (CAS identity). */
        private static HashSnapshot expired() {
            return new HashSnapshot(0, 0);
        }

        public boolean isExpired(long now) {
            return expiredAt <= now;
        }
    }
}
