package ru.yandex.solomon.dumper;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import javax.annotation.Nullable;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.dumper.storage.shortterm.DumperFile;
import ru.yandex.solomon.dumper.storage.shortterm.DumperTx;
import ru.yandex.solomon.dumper.storage.shortterm.ProducerKey;
import ru.yandex.solomon.dumper.storage.shortterm.ShortTermStorageReader;
import ru.yandex.solomon.memory.layout.MemMeasurableSubsystem;
import ru.yandex.solomon.memory.layout.MemoryBySubsystem;
import ru.yandex.solomon.slog.Log;
import ru.yandex.solomon.slog.LogsIndexSerializer;
import ru.yandex.solomon.slog.ResolvedLogMetaHeader;
import ru.yandex.solomon.util.ExceptionUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;

/**
 * @author Vladimir Gordiychuk
 */
public class DumperShard implements MemMeasurableSubsystem {
    private static final Logger logger = LoggerFactory.getLogger(DumperShard.class);
    private static final int MAX_WRITE_INFLIGHT = 100;
    private static final int MAX_SHARD_MEMORY_USAGE = 32 << 20; // 32 MiB
    private static final int MAX_PARSING_QUEUE_SIZE = 10;

    private final int shardId;
    private final Executor executor;
    private final Executor parseExecutor;
    private final ScheduledExecutorService timer;
    private final DumperShardMetrics metrics;
    private final ShortTermStorageReader shortTermStorage;
    private final LongTermStorage longTermStorage;
    private final SolomonShardOptsProvider optsProvider;
    private long latestOptsSnapshotHash;
    private final Int2ObjectOpenHashMap<SolomonShardProcess> shardProcesses;
    private final ActorRunner actor;

    @Nullable
    private CompletableFuture<DumperFile> activeFetch;
    @Nullable
    private CompletableFuture<Void> activeDelete;
    private final ArrayDeque<CompletableFuture<ParsedTx>> parsing = new ArrayDeque<>(MAX_PARSING_QUEUE_SIZE);
    private final ArrayDeque<CompletableFuture<DumperTx>> writes = new ArrayDeque<>(MAX_WRITE_INFLIGHT);
    private final AtomicInteger writesInFlight = new AtomicInteger();
    private final AtomicLong memoryUsageBytes = new AtomicLong();

    private volatile boolean stopped;
    private final CompletableFuture<Void> stopFuture = new CompletableFuture<>();

    public DumperShard(
            int shardId,
            Executor executor,
            Executor parseExecutor,
            ScheduledExecutorService timer,
            DumperShardMetrics metrics,
            SolomonShardOptsProvider optsProvider,
            ShortTermStorageReader shortTermStorage,
            LongTermStorage longTermStorage)
    {
        this.shardId = shardId;
        this.executor = executor;
        this.parseExecutor = parseExecutor;
        this.timer = timer;
        this.metrics = metrics;
        this.optsProvider = optsProvider;
        this.shortTermStorage = shortTermStorage;
        this.longTermStorage = longTermStorage;
        this.shardProcesses = new Int2ObjectOpenHashMap<>();
        this.actor = new ActorRunner(this::act, executor);
    }

    public DumperShardMetrics metrics() {
        return metrics;
    }

    public Int2ObjectMap<SolomonShardProcess> processes() {
        return new Int2ObjectOpenHashMap<>(shardProcesses);
    }

    private void act() {
        try {
            if (stopped) {
                onStop();
                return;
            }

            actActualizeOpts();
            actFetch();
            actWrite();
            actDelete();
        } catch (Throwable e) {
            error(e);
            actor.schedule();
        }
    }

    private void error(@Nullable Throwable e) {
        if (e == null) {
            return;
        }

        logger.error("ShardId:{} exception occurs", shardId, e);
        if (shortTermStorage.isStop()) {
            stopped = true;
            actor.schedule();
        }
    }

    private void actActualizeOpts() {
        var snapshot = optsProvider.snapshot();
        if (snapshot.optionsHash() == latestOptsSnapshotHash) {
            return;
        }

        var it = shardProcesses.int2ObjectEntrySet().fastIterator();
        while (it.hasNext()) {
            var entry = it.next();
            int numId = entry.getIntKey();
            var opts = snapshot.resolve(numId);
            var process = entry.getValue();
            if (opts == null) {
                process.close();
                logger.info("ShardId:{} stop processing numId:{}", shardId, Integer.toUnsignedString(numId));
                it.remove();
                continue;
            }
            if (process.updateOpts(opts)) {
                logger.info("ShardId:{} opts changed for numId:{}", shardId, Integer.toUnsignedLong(numId));
            }
        }
        this.latestOptsSnapshotHash = snapshot.optionsHash();
    }

    private void actFetch() {
        if (memoryUsageBytes.get() > MAX_SHARD_MEMORY_USAGE) {
            return;
        }

        if (activeFetch != null) {
            if (!activeFetch.isDone()) {
                return;
            }
            var future = activeFetch;
            activeFetch = null;
            parse(future.getNow(null));
        }

        for (int index = parsing.size(); index < MAX_PARSING_QUEUE_SIZE; index++) {
            if (memoryUsageBytes.get() > MAX_SHARD_MEMORY_USAGE) {
                break;
            }

            var future = shortTermStorage.next();
            if (!future.isDone()) {
                activeFetch = future;
                future.whenComplete((file, e) -> actor.schedule());
                return;
            }

            parse(future.getNow(null));
        }
    }

    private void parse(DumperFile file) {
        long fileBytes = file.memorySizeIncludingSelf();
        memoryUsageBytes.addAndGet(fileBytes);
        var tx = file.tx;
        var future = toLogs(file)
                .stream()
                .map(log -> enqueueLog(tx, log))
                .collect(collectingAndThen(toList(), CompletableFutures::allOf))
                .thenApply(list -> new ParsedTx(tx, list));
        parsing.add(future);
        metrics.aggregated.txParse.forFuture(future);
        future.whenComplete((ignore, e) -> {
            memoryUsageBytes.addAndGet(-fileBytes);
            actor.schedule();
        });
    }

    private void actWrite() {
        if (writesInFlight.get() >= MAX_WRITE_INFLIGHT) {
            return;
        }

        CompletableFuture<ParsedTx> future;
        while ((future = parsing.peek()) != null && future.isDone()) {
            if (future != parsing.poll()) {
                throw new ConcurrentModificationException();
            }

            writeTx(future.join());

            if (writesInFlight.incrementAndGet() >= MAX_WRITE_INFLIGHT) {
                return;
            }
        }
    }

    private void writeTx(ParsedTx parsed) {
        Int2ObjectMap<SolomonShardMetrics> solomonShardMetrics = new Int2ObjectOpenHashMap<>();
        var logsByShardId = new Int2ObjectOpenHashMap<List<Log>>();

        var txn = parsed.tx;
        var expectedProducerId = ProducerKey.makeProducerId(txn);
        for (var logs : parsed.logs) {
            for (var entry : logs.int2ObjectEntrySet()) {
                int shardId = entry.getIntKey();
                Log log = entry.getValue();
                if (!solomonShardMetrics.containsKey(log.numId)) {
                    var process = shardProcesses.get(log.numId);
                    if (process != null) {
                        solomonShardMetrics.put(log.numId, process.getMetrics());
                    }
                }

                var list = logsByShardId.get(shardId);
                if (list == null) {
                    list = new ArrayList<>();
                    logsByShardId.put(shardId, list);
                }
                ensureProducerIdValid(expectedProducerId, log, txn);
                ensureProducerSeqNoValid(log, txn);
                list.add(log);
            }
        }

        var futures = new ArrayList<CompletableFuture<?>>(logsByShardId.size());
        var it = logsByShardId.int2ObjectEntrySet().fastIterator();
        while (it.hasNext()) {
            var entry = it.next();
            int shardId = entry.getIntKey();
            var logs = entry.getValue();
            long memorySize = logs.stream().mapToLong(Log::memorySizeIncludingSelf).sum();
            var future = CompletableFutures.safeCall(() -> longTermStorage.write(shardId, logs))
                    .whenComplete((ignore, e) -> memoryUsageBytes.addAndGet(-memorySize));
            futures.add(future);
        }

        var writeFuture = CompletableFutures.allOfVoid(futures).thenApply(ignore -> txn);
        writes.add(writeFuture);
        metrics.aggregated.txWrite.forFuture(writeFuture);
        writeFuture.whenComplete((ignore, e) -> {
            if (e != null) {
                solomonShardMetrics.values().forEach(m -> m.writeError.inc());
            } else {
                long nowSec = System.currentTimeMillis() / 1000;
                long lag = nowSec - txn.getCreatedAtSec();
                for (var m : solomonShardMetrics.values()) {
                    m.writeSuccess.inc();
                    m.txLagSec.set(lag);
                }
            }
            logger.info("ShardId:{}, Tx:{}, ProducerId:{} written", shardId, txn, ProducerKey.makeProducerId(txn));
            writesInFlight.decrementAndGet();
            actor.schedule();
        });
    }

    private void ensureProducerIdValid(int expected, Log log, DumperTx tx) {
        int producerId = ResolvedLogMetaHeader.producerId(log.meta);
        if (producerId != expected) {
            throw new IllegalStateException("Invalid producerId " + producerId + " expected " + expected + " for txn " + tx);
        }
    }

    private void ensureProducerSeqNoValid(Log log, DumperTx tx) {
        long producerSeqNo = ResolvedLogMetaHeader.producerSeqNo(log.meta);
        if (producerSeqNo != tx.txn) {
            throw new IllegalStateException("Invalid producerSeqNo " + producerSeqNo + " expected " + tx.txn + " for txn " + tx);
        }
    }

    private void actDelete() {
        if (activeDelete != null) {
            if (!activeDelete.isDone()) {
                return;
            }
            var future = activeDelete;
            activeDelete = null;
            future.getNow(null);
        }

        List<DumperTx> listTxn;
        while ((listTxn = dequeueCompletedWrite()) != null) {
            var future = shortTermStorage.commit(listTxn);
            metrics.aggregated.txCommit.forFuture(future);
            if (future.isDone()) {
                future.getNow(null);
                continue;
            }

            activeDelete = future;
            future.whenComplete((ignore, e) -> actor.schedule());
            return;
        }
    }

    @Nullable
    private List<DumperTx> dequeueCompletedWrite() {
        List<DumperTx> listTxn = null;
        CompletableFuture<DumperTx> writeFuture;
        while ((writeFuture = writes.peek()) != null && writeFuture.isDone()) {
            if (listTxn == null) {
                listTxn = new ArrayList<>();
            }

            if (writeFuture != writes.poll()) {
                throw new ConcurrentModificationException();
            }

            try {
                listTxn.add(writeFuture.getNow(null));
            } catch (Throwable e) {
                ExceptionUtils.uncaughtException(e);
            }
        }
        return listTxn;
    }

    private void onStop() {
        shortTermStorage.stop();
        stopFuture.complete(null);
        metrics.aggregated.processCount.add(-shardProcesses.size());
        shardProcesses.values().forEach(SolomonShardProcess::close);
        shardProcesses.clear();
        if (activeFetch != null) {
            var future = activeFetch;
            activeFetch = null;
            future.thenAccept(DumperFile::release);
        }
    }

    private CompletableFuture<Int2ObjectMap<Log>> enqueueLog(DumperTx tx, Log log) {
        var processor = shardProcessById(log.numId);
        if (processor == null) {
            log.close();
            metrics.aggregated.processingErrors.inc();
            // TODO: triggering reload config's and try again (gordiychuk@)
            logger.error("ShardId:{} Tx:{} NumId:{} unknown shard", shardId, tx, Integer.toUnsignedString(log.numId));
            return completedFuture(Int2ObjectMaps.emptyMap());
        }

        return processor.enqueue(tx, log)
            .handle((result, e) -> {
                if (e != null) {
                    metrics.aggregated.processingErrors.inc();
                    logger.error("ShardId:{} Tx:{} NumId:{} processing failed", shardId, tx, Integer.toUnsignedString(log.numId), e);
                    return Int2ObjectMaps.emptyMap();
                }

                memoryUsageBytes.addAndGet(result.values().stream().mapToLong(Log::memorySizeIncludingSelf).sum());
                return result;
            });
    }

    @Nullable
    private SolomonShardProcess shardProcessById(int numId) {
        SolomonShardProcess process = shardProcesses.get(numId);
        if (process == null) {
            var opts = optsProvider.resolve(numId);
            if (opts == null) {
                logger.info("ShardId:{} NumId:{} unknown", shardId, Integer.toUnsignedLong(numId));
                return null;
            }
            process = new SolomonShardProcess(opts, parseExecutor, timer, longTermStorage, metrics.aggregated.processMetrics);
            logger.info("ShardId:{} start processing numId:{}", shardId, Integer.toUnsignedLong(numId));
            shardProcesses.put(numId, process);
            metrics.aggregated.processCount.add(1);
        }
        return process;
    }

    private List<Log> toLogs(DumperFile file) {
        var content = file.content;
        List<Log> result = new ArrayList<>();
        try {
            var index = LogsIndexSerializer.deserialize(content.retain());
            for (int i = 0; i < index.getSize(); i++) {
                var meta = content.readSlice(index.getMetaSize(i)).retain();
                var data = content.readSlice(index.getDataSize(i)).retain();
                var log = new Log(index.getNumId(i), meta, data);
                result.add(log);
            }
            return result;
        } catch (Throwable e) {
            result.forEach(Log::close);
            logger.error("ShardId:{} {} failed to parse logs", shardId, file, e);
            metrics.aggregated.processingErrors.add(result.size() + 1);
            return List.of();
        } finally {
            file.release();
        }
    }

    public int getId() {
        return shardId;
    }

    public long getMemoryUsage() {
        return memoryUsageBytes.get();
    }

    public void start() {
        actor.schedule();
    }

    public CompletableFuture<Void> stop() {
        stopped = true;
        actor.schedule();
        return stopFuture;
    }

    public boolean isStop() {
        return stopped;
    }

    public void scheduleAct() {
        actor.schedule();
    }

    public CompletableFuture<Void> stopFuture() {
        return stopFuture;
    }

    @Override
    public String toString() {
        return "DumperShard{" +
            "shardId=" + shardId +
            '}';
    }

    @Override
    public void addMemoryBySubsystem(MemoryBySubsystem memory) {
        shortTermStorage.addMemoryBySubsystem(memory);
        for (SolomonShardProcess process : shardProcesses.values()) {
            process.addMemoryBySubsystem(memory);
        }
    }

    private static class ParsedTx implements AutoCloseable {
        private final DumperTx tx;
        private final List<Int2ObjectMap<Log>> logs;

        public ParsedTx(DumperTx tx, List<Int2ObjectMap<Log>> logs) {
            this.tx = tx;
            this.logs = logs;
        }

        @Override
        public void close() {
            logs.stream().flatMap(item -> item.values().stream()).forEach(Log::close);
        }
    }
}
