package ru.yandex.stockpile.server.shard;

import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.solomon.memory.layout.MemMeasurable;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.stockpile.memState.MetricIdAndData;
import ru.yandex.stockpile.server.SnapshotLevel;
import ru.yandex.stockpile.server.data.chunk.ChunkWithNo;
import ru.yandex.stockpile.server.data.chunk.ChunkWriter;
import ru.yandex.stockpile.server.data.command.SnapshotCommandContent;
import ru.yandex.stockpile.server.data.index.SnapshotIndex;
import ru.yandex.stockpile.server.data.index.SnapshotIndexContent;
import ru.yandex.stockpile.server.shard.MergeProcessMetrics.MergeKindMetrics;
import ru.yandex.stockpile.server.shard.stat.LevelSizeAndCount;
import ru.yandex.stockpile.server.shard.stat.SizeAndCount;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static ru.yandex.misc.concurrent.CompletableFutures.whenComplete;

/**
 * @author Vladimir Gordiychuk
 */
public class MergeWriter implements Flow.Subscriber<MetricIdAndData>, MemMeasurable {
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(MergeWriter.class);
    private static final int MAX_WRITE_IN_FLIGHT = 1;
    private static final int PRE_FETCH_METRICS = 100;
    private final Queue<MetricIdAndData> queue = new ConcurrentLinkedQueue<>();
    private final ShardThread shardThread;
    public final SnapshotLevel level;
    public final long txn;

    private final AtomicLong bytesInQueue = new AtomicLong(SELF_SIZE);
    private final AtomicLong bytesInFlight = new AtomicLong(SELF_SIZE);

    private final AtomicInteger writeInFlight = new AtomicInteger();
    private final ChunkWriter chunkWriter;
    private final ActorRunner actor;
    private final AtomicReference<State> state = new AtomicReference<>(State.RUNNING);
    private final CompletableFuture<Optional<SnapshotIndexWithStats>> doneFuture = new CompletableFuture<>();
    private final MergeKindMetrics metrics;
    private SizeAndCount commandSize = SizeAndCount.zero;

    private Flow.Subscription subscription;

    public MergeWriter(ShardThread shardThread, long now, SnapshotLevel level, long decimatedAt, long txn, MergeKindMetrics metrics) {
        this.shardThread = shardThread;
        this.level = level;
        this.txn = txn;
        this.metrics = metrics;
        this.chunkWriter = new ChunkWriter(SnapshotReason.UNKNOWN, now, decimatedAt);
        this.actor = new ActorRunner(this::act, shardThread.shard.mergeExecutor);
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.subscription = subscription;
        this.subscription.request(PRE_FETCH_METRICS);
    }

    @Override
    public void onNext(MetricIdAndData item) {
        if (state.get() == State.RUNNING) {
            metrics.writeQueue.add(1);
            bytesInQueue.addAndGet(item.memorySizeIncludingSelf());
            queue.add(item);
            actor.schedule();
        }
    }

    @Override
    public void onError(Throwable e) {
        state.set(State.ERROR);
        doneFuture.completeExceptionally(e);
    }

    @Override
    public void onComplete() {
        state.compareAndSet(State.RUNNING, State.FINISHING);
        actor.schedule();
    }

    public void cancel() {
        state.set(State.CANCELED);
        var copy = subscription;
        if (copy != null) {
            copy.cancel();
        }
    }

    public CompletableFuture<Optional<SnapshotIndexWithStats>> getDoneFuture() {
        return doneFuture;
    }

    private void act() {
        switch (state.get()) {
            case RUNNING:
                actWrite();
                return;
            case FINISHING: {
                if (actWrite()) {
                    return;
                }

                state.set(State.DONE);
                whenComplete(actFinish(), doneFuture);
                return;
            }
            default: {
                MetricIdAndData item;
                while ((item = queue.poll()) != null) {
                    bytesInQueue.addAndGet(-item.memorySizeIncludingSelf());
                    metrics.writeQueue.add(-1);
                    item.close();
                }
            }
        }
    }

    private boolean actWrite() {
        if (writeInFlight.get() >= MAX_WRITE_IN_FLIGHT) {
            return true;
        }

        MetricIdAndData merged;
        while ((merged = queue.poll()) != null) {
            try {
                subscription.request(1);
                metrics.writeQueue.add(-1);
                bytesInQueue.addAndGet(-merged.memorySizeIncludingSelf());
                ChunkWithNo chunk = chunkWriter.writeMetric(merged.localId(), merged.lastTsMillis(), merged.archive());
                if (chunk == null) {
                    continue;
                }

                writeChunk(chunk);
                if (writeInFlight.get() >= MAX_WRITE_IN_FLIGHT) {
                    return true;
                }
            } finally {
                merged.close();
            }
        }

        // no more tasks in queue, ask producer load more data
        subscription.request(1);
        return !queue.isEmpty() || writeInFlight.get() > 0;
    }

    private CompletableFuture<Optional<SnapshotIndexWithStats>> actFinish() {
        ChunkWriter.Finish result = chunkWriter.finish();
        var index = result.index;
        var lastChunk = result.chunkWithNo;
        if (lastChunk != null) {
            return writeChunk(result.chunkWithNo)
                .thenCompose(ignore -> writeIndex(index));
        } else {
            return writeIndex(index);
        }
    }

    private CompletableFuture<?> writeChunk(ChunkWithNo chunk) {
        writeInFlight.incrementAndGet();
        long chunkSize = chunk.memorySizeIncludingSelf();
        bytesInFlight.addAndGet(chunkSize);
        long startNanos = System.nanoTime();
        return shardThread.loopUntilSuccessFuture("writeSnapshotChunkTmp",
            () -> shardThread.shard.storage.writeSnapshotChunkToTemp(level, txn, chunk))
            .whenComplete((r, e) -> {
                metrics.addWriteTime(System.nanoTime() - startNanos);
                if (e != null) {
                    onError(e);
                }

                bytesInFlight.addAndGet(-chunkSize);
                writeInFlight.decrementAndGet();
                actor.schedule();
            });
    }

    private CompletableFuture<Optional<SnapshotIndexWithStats>> writeIndex(SnapshotIndexContent content) {
        try {
            if (!SnapshotIndexWriter.isWritable(content) && commandSize.isEmpty()) {
                return completedFuture(Optional.empty());
            }

            var writer = new SnapshotIndexWriter(shardThread);
            long startNanos = System.nanoTime();
            return writer.write(level, txn, content)
                    .thenApply(indexSize -> {
                        var chunkSize = content.diskSize();
                        var levelSize = new LevelSizeAndCount(indexSize, chunkSize, commandSize);
                        var index = new SnapshotIndex(level, txn, content);
                        var stats = new SnapshotIndexWithStats(index, levelSize);
                        return Optional.of(stats);
                    })
                    .whenComplete((r, e) -> {
                        metrics.addWriteTime(System.nanoTime() - startNanos);
                    });
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }

    public CompletableFuture<Void> writeCommand(SnapshotCommandContent content, boolean last) {
        if (!SnapshotCommandWriter.isWritable(last, content)) {
            return completedFuture(null);
        }

        var writer = new SnapshotCommandWriter(shardThread);
        return writer.write(level, txn, content)
                .thenAccept(written -> commandSize = written);
    }

    @Override
    public long memorySizeIncludingSelf() {
        long size = SELF_SIZE;
        size += chunkWriter.memorySizeIncludingSelf();
        size += bytesInQueue.get();
        size += MemoryCounter.CompletableFuture_SELF_SIZE;

        return size;
    }

    private enum State {
        RUNNING,
        FINISHING,
        CANCELED,
        ERROR,
        DONE
    }
}
