package ru.yandex.stockpile.server.shard;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import javax.annotation.Nullable;

import com.google.protobuf.ByteString;
import io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.thread.WhatThreadDoes;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.solomon.util.protobuf.ByteStrings;
import ru.yandex.stockpile.server.Txn;
import ru.yandex.stockpile.server.data.log.StockpileLogEntryContent;
import ru.yandex.stockpile.server.data.log.StockpileLogEntrySerialized;
import ru.yandex.stockpile.server.shard.actor.InActor;
import ru.yandex.stockpile.server.shard.actor.StockpileShardActState;
import ru.yandex.stockpile.tool.Sampler;

/**
 * @author Stepan Koltsov
 */
class TxWrite implements TxTracker.Tx {
    private static final Logger logger = LoggerFactory.getLogger(TxWrite.class);
    private static final Sampler logSampler = new Sampler(0.5, 1000);

    private final LogProcess parent;
    private final long txn;
    private final List<StockpileWriteRequest> requests;
    private final long createdAtNanos = System.nanoTime();
    /**
     * Log entry built from {@link #requests}. It is set on the executor thread in {@link #start}.
     * It is released either by {@link #failed(Throwable)} or at the end of {@link #completeTx()}.
     */
    private volatile TxLogEntry deltaLogEntry;

    public TxWrite(LogProcess parent, long txn, List<StockpileWriteRequest> writeRequests) {
        Txn.validateTxn(txn);
        this.parent = parent;
        this.txn = txn;
        this.requests = writeRequests;
    }

    /**
     * Schedules the write on the shard's common executor.
     *
     * <p>The scheduled task builds a {@code TxLogEntry} from all requests, serializes it, and
     * starts an asynchronous storage write. On success the parent is notified via
     * {@code parent.written(this)}; on any failure {@link #failed(Throwable)} is invoked.
     *
     * <p>Resource ownership:
     * <ul>
     *   <li>the log entry is published to {@link #deltaLogEntry} before serialization, so
     *       {@link #failed(Throwable)} can release it even if {@code serialize()} throws;</li>
     *   <li>the serialized buffer is released by the write-completion callback once it is
     *       registered, or directly in the catch block if the write could not be started.</li>
     * </ul>
     *
     * @param a marker argument; not used in the body
     */
    void start(InActor a) {
        parent.shard.commonExecutor.execute(() -> {
            WhatThreadDoes.Handle h = WhatThreadDoes.push("Shard " + parent.shard.shardId + " write tx " + txn + ", " + requests.size() + " reqs");
            // Non-null while this task (not the completion callback) still owns the buffer.
            ByteBuf content = null;
            try {
                TxLogEntry logEntry = TxLogEntry.of(requests);
                // Publish before serialize(): if serialization throws, failed() must still release the entry.
                deltaLogEntry = logEntry;
                content = logEntry.serialize();
                parent.shard.globals.stockpileShardAggregatedStats.writeTxPrepareTimeHistogram.record(getSpendMillis());
                CompletableFuture<Void> future = writeLogEntry(ByteStrings.fromByteBuf(content));
                ByteBuf callbackContent = content;
                future.whenComplete((response, throwable) -> {
                    callbackContent.release();
                    if (throwable != null) {
                        failed(throwable);
                    } else {
                        // Read stats before notifying the parent. completeTx() releases the entry in its
                        // finally block, and written() may lead to completeTx() running before we get back here.
                        var stats = logEntry.getStats();
                        parent.written(this);
                        parent.shard.globals.usage.write(stats);
                    }
                });
                // The callback is registered (and may already have run), so it now owns the buffer.
                content = null;
            } catch (Throwable e) {
                if (content != null) {
                    // The write never started, so no callback will release the buffer.
                    content.release();
                }
                failed(e);
            } finally {
                h.popSafely();
            }
        });
    }

    /**
     * Handles a failed write.
     *
     * <p>Logs the error, fails every request future (asynchronously, see
     * {@link #completeWriteFutures}), releases the log entry if one was built, and notifies
     * the parent.
     */
    private void failed(Throwable e) {
        logger.error("WRITE for shard {} tx {}, {} reqs failed", parent.shard.shardId, txn, requests.size(), e);
        completeWriteFutures(e);
        var log = deltaLogEntry;
        if (log != null) {
            log.release();
        }
        parent.failed(this, e);
    }

    /** Milliseconds elapsed since this transaction object was created. */
    private long getSpendMillis() {
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - createdAtNanos);
    }

    /**
     * Completes and releases every request on the shard's common executor.
     *
     * @param e failure to complete the futures with, or {@code null} to complete them
     *          successfully with a {@link Txn} wrapping this transaction id
     */
    private void completeWriteFutures(@Nullable Throwable e) {
        if (e != null) {
            parent.shard.commonExecutor.execute(() -> {
                for (StockpileWriteRequest r : requests) {
                    r.getFuture().completeExceptionally(e);
                    r.release();
                }
            });
        } else {
            parent.shard.commonExecutor.execute(() -> {
                Txn writeTx = new Txn(txn);
                for (StockpileWriteRequest r : requests) {
                    r.getFuture().complete(writeTx);
                    r.release();
                }
            });
        }
    }

    /**
     * Writes the serialized log entry to shard storage through
     * {@code parent.loopUntilSuccessFuture}.
     *
     * <p>Each attempt updates write metrics. The start counter and byte size are recorded
     * always; latency and completion metrics are recorded only on success.
     * {@code avgLogWriteMillis} is recorded for every attempt.
     *
     * @param content serialized log entry bytes; must stay valid until the returned future completes
     * @return future completed when {@code parent.loopUntilSuccessFuture} completes
     */
    private CompletableFuture<Void> writeLogEntry(ByteString content) {
        StockpileLogEntrySerialized serialized = new StockpileLogEntrySerialized(txn, content);
        return parent.loopUntilSuccessFuture("writeLogEntryNotSnapshot(" + serialized + ")", () -> {
            var metrics = parent.shard.globals.stockpileShardAggregatedStats;
            metrics.writeLogOpsStarted.inc();
            metrics.writeLogByteSize.record(serialized.size());

            long startNanos = System.nanoTime();
            return parent.shard.storage.writeLogEntry(serialized).whenComplete((response, throwable) -> {
                long durationMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
                if (throwable == null) {
                    metrics.writeElapsedTimeByLogType.get(LogOrSnapshot.SHORT).record(durationMillis);
                    metrics.writeLogTimeHistogramByRecords.record(durationMillis, deltaLogEntry.getRecords());
                    metrics.writeLogOpsCompleted.inc();
                }
                parent.shard.metrics.write.avgLogWriteMillis.mark(durationMillis);
            });
        });
    }

    @Override
    public long txn() {
        return txn;
    }

    /**
     * Applies a successfully written transaction.
     *
     * <p>Completes the request futures, updates write metrics, and applies the entry to the
     * shard's log state and cache, timing each of those two steps. The log entry is always
     * released at the end.
     *
     * @throws RuntimeException wrapping any failure, with shard id, txn and request count in the message
     */
    @Override
    public void completeTx() {
        try {
            parent.shard.metrics.act.switchToState(StockpileShardActState.RUNNABLE_COMPLETE_WRITE);
            completeWriteFutures(null);
            StockpileLogEntryContent entry = deltaLogEntry.getLogEntryContent();
            TxWriteSummary summary = deltaLogEntry.getSummary();
            parent.shard.metrics.write.update(summary);
            timeMeasure(
                () -> parent.shard.logState.completeWriteLogTx(entry, summary.bytes, this.txn),
                parent.shard.globals.stockpileShardAggregatedStats.writeTxUpdateMemStateTimeHistogram);
            timeMeasure(
                () -> parent.shard.cache.updateWithOnWriteCompleted(entry),
                parent.shard.globals.stockpileShardAggregatedStats.writeTxUpdateCacheTimeHistogram);

            if (logSampler.acquire()) {
                logger.info("WRITE for shard {}, {} reqs completed in {} ms", parent.shard.shardId, requests.size(), getSpendMillis());
            }
        } catch (Throwable e) {
            logger.error("COMPLETE_WRITE for shard {} tx {}, {} reqs failed", parent.shard.shardId, txn, requests.size(), e);
            throw new RuntimeException("Shard " + parent.shard.shardId + ", txn " + txn + ", " + requests.size() + " reqs failed", e);
        } finally {
            deltaLogEntry.release();
        }
    }

    /** Runs {@code task} and records its wall-clock duration, in milliseconds, into {@code histogram}. */
    private void timeMeasure(Runnable task, Histogram histogram) {
        long started = System.nanoTime();
        task.run();
        histogram.record(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - started));
    }
}
