package ru.yandex.stockpile.server.shard;

import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.codec.serializer.NettyStockpileSerializer;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.solomon.util.protobuf.ByteStrings;
import ru.yandex.stockpile.memState.LogEntriesContent;
import ru.yandex.stockpile.server.Txn;
import ru.yandex.stockpile.server.data.log.LogReason;
import ru.yandex.stockpile.server.data.log.StockpileLogEntryContent;
import ru.yandex.stockpile.server.data.log.StockpileLogEntryContentSerializer;
import ru.yandex.stockpile.server.data.log.StockpileLogEntrySerialized;
import ru.yandex.stockpile.server.shard.actor.ActorRunnableType;
import ru.yandex.stockpile.server.shard.actor.InActor;
import ru.yandex.stockpile.server.shard.actor.StockpileShardActState;
import ru.yandex.stockpile.server.shard.stat.StockpileShardAggregatedStats;

/**
 * Shard process that writes a log snapshot for a single transaction.
 *
 * <p>Flow, as implemented below:
 * <ol>
 *   <li>{@link #start(InActor)} registers a {@code TxTracker.Tx} for {@code txn} with the shard's tx tracker;</li>
 *   <li>when that tx is completed, the snapshot is taken from {@code shard.logState} and handed over to
 *       {@code shard.commonExecutor};</li>
 *   <li>the snapshot is serialized into a pooled Netty buffer and written via {@code shard.storage};</li>
 *   <li>the process is finalized inside the shard actor and {@link #getFuture()} is completed with the txn.</li>
 * </ol>
 *
 * @author Vladimir Gordiychuk
 */
public class LogSnapshotProcess extends ShardProcess {
    /** Shallow size of this process object; the snapshot payload is tracked separately in {@link #snapshotMemorySize}. */
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(LogSnapshotProcess.class);
    private static final Logger logger = LoggerFactory.getLogger(LogSnapshotProcess.class);

    /** Transaction number this snapshot is written under. */
    private final long txn;
    /** First txn since the previous log snapshot, or {@link #txn} when the log state reports none. */
    private long firstTxn;
    /** Value of {@code shard.logState.anyReasonToLog()} captured when the snapshot was taken; not read in this class. */
    private LogReason reason;
    /** Memory size reported by the snapshot content at the moment it was taken. */
    private long snapshotMemorySize;
    /** Completed with {@code new Txn(txn)} once the process is finalized inside the shard actor. */
    private final CompletableFuture<Txn> future = new CompletableFuture<>();
    /** Serialized snapshot size in bytes; written by {@link #serialize} and read later in {@link #txCompletedInActor}. */
    private volatile long diskSize;

    /**
     * @param shard shard that owns this process
     * @param txn   transaction number assigned to the snapshot
     */
    protected LogSnapshotProcess(StockpileShard shard, long txn) {
        super(shard, ProcessType.LOG_SNAPSHOT, "tx " + txn + ", " + shard.logState.estimateLogSnapshotSize() + " bytes");
        this.txn = txn;
    }

    /**
     * Starts the snapshot process.
     *
     * @throws IllegalStateException if this instance is not the shard's current {@code logSnapshotProcess}
     */
    @Override
    void start(InActor a) {
        if (shard.logSnapshotProcess != this) {
            throw new IllegalStateException("Multiple snapshot process not allowed");
        }

        // Wait until in-flight writes are finished before taking the snapshot,
        // otherwise transactions can be reordered:
        // 1 [commit]
        // 2 [commit]
        // 3   [in-flight]
        // 4 [snapshot]
        // and after a restart it would look like:
        // 3. [commit]
        // 4. [commit]
        shard.txTracker.written(new TxTracker.Tx() {
            @Override
            public long txn() {
                return txn;
            }

            @Override
            public void completeTx() {
                shard.metrics.act.switchToState(StockpileShardActState.LOG_SNAPSHOT);
                shard.stats.writeLogSnapshots += 1;
                firstTxn = shard.logState.firstTxnSinceLastLogSnapshot().orElse(txn);
                reason = shard.logState.anyReasonToLog();

                LogEntriesContent content = shard.logState.takeLogSnapshot();
                snapshotMemorySize = content.memorySizeIncludingSelf();
                var snapshot = content.asLogEntryContent();
                // Serialization and the storage write are performed on the common executor.
                shard.commonExecutor.execute(() -> writeSnapshot(snapshot));
            }
        });
    }

    /**
     * @return future completed with this process' txn after the process has been finalized in the shard actor
     */
    public CompletableFuture<Txn> getFuture() {
        return future;
    }

    /**
     * Finalizes the process inside the shard actor: clears {@code shard.logSnapshotProcess}, reports the written
     * size to the log state, records completion stats, lets the shard re-check its size and completes
     * {@link #getFuture()}.
     */
    private void txCompletedInActor(InActor a) {
        Objects.requireNonNull(shard.logSnapshotProcess);
        shard.logSnapshotProcess = null;
        shard.logState.completeWriteLogSnapshot(diskSize, txn);
        completedSuccessfullyWriteStats();
        shard.checkSizeStartSnapshot(a);
        future.complete(new Txn(txn));
    }

    /**
     * Serializes the snapshot into a freshly allocated buffer and records its size in {@link #diskSize}.
     * The snapshot is always released, whether serialization succeeds or fails.
     *
     * @param snapshot snapshot content to serialize; released by this method
     * @return buffer with the serialized snapshot; the caller is responsible for releasing it
     * @throws ArithmeticException if the estimated serialized size does not fit into an {@code int}
     * @throws RuntimeException    wrapping any checked failure; unchecked failures are rethrown as is
     */
    private ByteBuf serialize(StockpileLogEntryContent snapshot) {
        ByteBuf buffer = null;
        try {
            var estimatedSize = Math.toIntExact(snapshot.estimateSerializedSize());
            buffer = allocateBuffer(estimatedSize);
            var serializer = new NettyStockpileSerializer(buffer);
            StockpileLogEntryContentSerializer.S.serializeToEof(snapshot, serializer);
            diskSize = buffer.readableBytes();
            return buffer;
        } catch (Throwable e) {
            // Do not leak the buffer when serialization fails half way.
            if (buffer != null) {
                buffer.release();
            }
            Throwables.throwIfUnchecked(e);
            throw new RuntimeException(e);
        } finally {
            snapshot.release();
        }
    }

    /**
     * Allocates a pooled direct buffer with the given initial capacity, falling back to a pooled heap buffer
     * when the direct allocation fails with {@link OutOfMemoryError}.
     */
    private ByteBuf allocateBuffer(int estimatedSize) {
        var allocator = PooledByteBufAllocator.DEFAULT;
        try {
            return allocator.directBuffer(estimatedSize);
        } catch (OutOfMemoryError e) {
            return allocator.heapBuffer(estimatedSize);
        }
    }

    /**
     * Serializes the snapshot, writes it to storage and schedules finalization in the shard actor.
     * The serialized buffer is released once the write future completes, or immediately if the write
     * could not be started at all.
     */
    private void writeSnapshot(StockpileLogEntryContent snapshot) {
        var buffer = serialize(snapshot);
        try {
            var serialized = new StockpileLogEntrySerialized(txn, ByteStrings.fromByteBuf(buffer));
            var writeFuture = loopUntilSuccessFuture(
                "writeSnapshotLogEntry(" + serialized + ", delete from tx " + firstTxn + ")",
                () -> measureWrite(() -> shard.storage.writeLogSnapshot(firstTxn, serialized)));

            writeFuture.whenComplete((response, throwable) -> {
                if (throwable != null) {
                    logger.warn("failed write logs snapshot in shard: {}, {} bytes",
                        shard.shardId, buffer.readableBytes(), throwable);
                }

                buffer.release();
                // NOTE(review): a failed write still goes through the regular completion path below;
                // presumably loopUntilSuccessFuture only fails when the process is being stopped - confirm.
                shard.run(ActorRunnableType.COMPLETE_LOG_SNAPSHOT, LogSnapshotProcess.this::txCompletedInActor);
            });
        } catch (Throwable e) {
            // The completion callback owns the buffer only once it has been registered. If wrapping the
            // buffer or starting the write throws synchronously, release it here so that the pooled
            // memory is not leaked. Assumes whenComplete does not itself throw after invoking the callback.
            buffer.release();
            throw e;
        }
    }

    /**
     * Runs the write operation and records log-write metrics around it. The started counter is incremented
     * before the operation is invoked; elapsed time and the completed counter are recorded only when the
     * operation finishes without an error.
     *
     * @param op supplier that starts the write
     * @return the stage returned by {@code whenComplete} on the operation's future
     */
    private <A> CompletableFuture<A> measureWrite(Supplier<CompletableFuture<A>> op) {
        StockpileShardAggregatedStats stockpileShardAggregatedStats = shard.globals.stockpileShardAggregatedStats;
        stockpileShardAggregatedStats.writeLogOpsStarted.inc();

        long startNanos = System.nanoTime();
        return op.get().whenComplete((response, throwable) -> {
            if (throwable == null) {
                long durationMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);

                stockpileShardAggregatedStats.writeElapsedTimeByLogType.get(LogOrSnapshot.SNAPSHOT).record(durationMillis);
                stockpileShardAggregatedStats.writeLogOpsCompleted.inc();
                shard.metrics.write.avgLogSnapshotWriteMillis.mark(durationMillis);
            }
        });
    }

    /**
     * @return shallow size of this process plus the memory size of the snapshot content it has taken
     */
    @Override
    public long memorySizeIncludingSelf() {
        return SELF_SIZE + snapshotMemorySize;
    }

    /** Nothing to release here: the serialized buffer is released by {@link #writeSnapshot}. */
    @Override
    protected void stoppedReleaseResources() {

    }
}
