package ru.yandex.stockpile.server.data.dao;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

import ru.yandex.kikimr.client.kv.KikimrKvClient;
import ru.yandex.kikimr.client.kv.KvReadRangeResult;
import ru.yandex.kikimr.util.NameRange;
import ru.yandex.solomon.codec.archive.MetricArchiveImmutable;
import ru.yandex.solomon.memory.layout.MemMeasurableSubsystem;
import ru.yandex.solomon.memory.layout.MemoryBySubsystem;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.solomon.util.collection.enums.EnumMapToAtomicLong;
import ru.yandex.stockpile.kikimrKv.counting.ReadClass;
import ru.yandex.stockpile.server.ListOfOptional;
import ru.yandex.stockpile.server.SnapshotLevel;
import ru.yandex.stockpile.server.data.chunk.ChunkAddressGlobal;
import ru.yandex.stockpile.server.data.chunk.ChunkWithNo;
import ru.yandex.stockpile.server.data.chunk.DataRangeGlobal;
import ru.yandex.stockpile.server.data.chunk.SnapshotAddress;
import ru.yandex.stockpile.server.data.command.SnapshotCommandPartsSerialized;
import ru.yandex.stockpile.server.data.index.SnapshotIndexPartsSerialized;
import ru.yandex.stockpile.server.data.log.StockpileLogEntrySerialized;
import ru.yandex.stockpile.server.data.names.FileNamePrefix;


/**
 * {@link StockpileShardStorage} decorator that forwards every call to a delegate and, for a subset of
 * read/write operations, accounts the estimated size of the data held while the operation is in flight.
 *
 * <p>The size is added to a per-{@link InFlightDataType} counter before the delegate is invoked and
 * subtracted once the delegate's future completes (successfully or exceptionally). The current totals
 * are reported via {@link #addMemoryBySubsystem(MemoryBySubsystem)} under the keys
 * {@code stockpile.shard.inFlight.<TYPE>}.
 *
 * @author Sergey Polovko
 */
public class StockpileShardStorageMeasured implements StockpileShardStorage, MemMeasurableSubsystem {

    private final StockpileShardStorage delegate;

    /**
     * Estimated size of data currently held by in-flight operations, one counter per data type.
     */
    private final EnumMapToAtomicLong<InFlightDataType> usedMemory;

    /**
     * @param delegate storage that performs the actual work; all calls are forwarded to it
     */
    public StockpileShardStorageMeasured(StockpileShardStorage delegate) {
        this.delegate = delegate;
        this.usedMemory = new EnumMapToAtomicLong<>(InFlightDataType.class);
    }

    @Override
    public long getGeneration() {
        return delegate.getGeneration();
    }

    @Override
    public CompletableFuture<Void> lock() {
        return delegate.lock();
    }

    /** Forwards to the delegate, accounting the entry's self-reported size under {@code LOG_WRITE}. */
    @Override
    public CompletableFuture<Void> writeLogEntry(StockpileLogEntrySerialized serialized) {
        long size = serialized.memorySizeIncludingSelf();
        return measureMemoryUsage(InFlightDataType.LOG_WRITE, size, () -> {
            return delegate.writeLogEntry(serialized);
        });
    }

    /** Forwards to the delegate, accounting the snapshot's self-reported size under {@code LOG_WRITE}. */
    @Override
    public CompletableFuture<Void> writeLogSnapshot(long firstLogTxn, StockpileLogEntrySerialized serialized) {
        long size = serialized.memorySizeIncludingSelf();
        return measureMemoryUsage(InFlightDataType.LOG_WRITE, size, () -> {
            return delegate.writeLogSnapshot(firstLogTxn, serialized);
        });
    }

    /** Forwards to the delegate, accounting the chunk content length under {@code CHUNK_WRITE}. */
    @Override
    public CompletableFuture<Void> writeSnapshotChunkToTemp(SnapshotLevel level, long txn, ChunkWithNo chunkWithNo) {
        long size = chunkWithNo.getContent().length;
        return measureMemoryUsage(InFlightDataType.CHUNK_WRITE, size, () -> {
            return delegate.writeSnapshotChunkToTemp(level, txn, chunkWithNo);
        });
    }

    /**
     * Accounts {@code size} under {@code dataType} for as long as the operation produced by {@code op}
     * is running.
     *
     * <p>The counter is incremented before {@code op} is invoked and decremented when the future it
     * returned completes. If {@code op} throws synchronously or returns {@code null}, the failure is
     * reported through a failed future instead, so the counter is always released. Cancelling the
     * returned (dependent) future does not cancel the delegate's future; the counter is released when
     * the delegate's future itself completes.
     *
     * @param dataType counter to charge
     * @param size     amount to add while the operation is in flight
     * @param op       starts the delegate operation
     * @return a future that completes with the same outcome as the delegate's future
     */
    private <V> CompletableFuture<V> measureMemoryUsage(
        InFlightDataType dataType, long size, Supplier<CompletableFuture<V>> op)
    {
        usedMemory.addAndGet(dataType, size);
        CompletableFuture<V> future;
        try {
            future = op.get();
            if (future == null) {
                // Without this, whenComplete() below would throw NPE and the counter would never be released.
                future = CompletableFuture.failedFuture(
                    new NullPointerException("storage operation returned null future"));
            }
        } catch (Throwable t) {
            // Surface synchronous failures through the future so the release below still runs.
            future = CompletableFuture.failedFuture(t);
        }
        return future.whenComplete((result, error) -> usedMemory.addAndGet(dataType, -size));
    }

    /** Forwards to the delegate, accounting the parts' self-reported size under {@code INDEX_WRITE}. */
    @Override
    public CompletableFuture<Void> writeSnapshotIndexToTemp(SnapshotIndexPartsSerialized parts) {
        long size = parts.memorySizeIncludingSelf();
        return measureMemoryUsage(InFlightDataType.INDEX_WRITE, size, () -> {
            return delegate.writeSnapshotIndexToTemp(parts);
        });
    }

    /** Forwards to the delegate, accounting the parts' self-reported size under {@code COMMAND_WRITE}. */
    @Override
    public CompletableFuture<Void> writeSnapshotCommandToTemp(SnapshotCommandPartsSerialized parts) {
        long size = parts.memorySizeIncludingSelf();
        return measureMemoryUsage(InFlightDataType.COMMAND_WRITE, size, () -> {
            return delegate.writeSnapshotCommandToTemp(parts);
        });
    }

    /**
     * Forwards to the delegate, accounting the array's object size (as computed by
     * {@link MemoryCounter#arrayObjectSize}) under {@code CHUNK_WRITE}.
     */
    @Override
    public CompletableFuture<Void> writeProducerSeqNoSnapshot(byte[] snapshot) {
        long size = MemoryCounter.arrayObjectSize(snapshot);
        return measureMemoryUsage(InFlightDataType.CHUNK_WRITE, size, () -> {
            return delegate.writeProducerSeqNoSnapshot(snapshot);
        });
    }

    @Override
    public CompletableFuture<Void> deleteTempFiles() {
        return delegate.deleteTempFiles();
    }

    @Override
    public CompletableFuture<Void> renameSnapshotDeleteLogs(SnapshotAddress address) {
        return delegate.renameSnapshotDeleteLogs(address);
    }

    @Override
    public CompletableFuture<Void> renameSnapshotDeleteOld(SnapshotAddress[] renameSnapshots, SnapshotAddress[] deleteSnapshots) {
        return delegate.renameSnapshotDeleteOld(renameSnapshots, deleteSnapshots);
    }

    /** Pass-through; not measured. */
    @Override
    public CompletableFuture<byte[]> readChunk(
        ReadClass readClass, ChunkAddressGlobal chunkAddress, FileNamePrefix prefix)
    {
        return delegate.readChunk(readClass, chunkAddress, prefix);
    }

    /** Pass-through; not measured. */
    @Override
    public CompletableFuture<List<KikimrKvClient.KvEntryStats>> readRangeNames(ReadClass readClass, NameRange nameRange) {
        return delegate.readRangeNames(readClass, nameRange);
    }

    /** Pass-through; not measured. */
    @Override
    public CompletableFuture<Optional<byte[]>> readData(ReadClass readClass, String name) {
        return delegate.readData(readClass, name);
    }

    /**
     * Forwards to the delegate, accounting the sum of the ranges' lengths under {@code CHUNK_READ}.
     * Negative lengths are skipped. NOTE(review): presumably a negative length means "unknown size";
     * confirm against {@link DataRangeGlobal}.
     */
    @Override
    public CompletableFuture<ListOfOptional<MetricArchiveImmutable>> readSnapshotRanges(DataRangeGlobal[] ranges) {
        long size = Arrays.stream(ranges)
            .mapToLong(DataRangeGlobal::getLength)
            .filter(l -> l >= 0)
            .sum();

        return measureMemoryUsage(InFlightDataType.CHUNK_READ, size, () -> {
            return delegate.readSnapshotRanges(ranges);
        });
    }

    /**
     * Forwards to the delegate, accounting the range length (clamped to 0 if negative) under
     * {@code CHUNK_READ}.
     */
    @Override
    public CompletableFuture<Optional<MetricArchiveImmutable>> readSnapshotRange(DataRangeGlobal range) {
        long size = Math.max(range.getLength(), 0);
        return measureMemoryUsage(InFlightDataType.CHUNK_READ, size, () -> {
            return delegate.readSnapshotRange(range);
        });
    }

    @Override
    public CompletableFuture<Void> flushReadQueue() {
        return delegate.flushReadQueue();
    }

    /** Pass-through; not measured. */
    @Override
    public CompletableFuture<KvReadRangeResult> readRange(ReadClass readClass, NameRange nameRange) {
        return delegate.readRange(readClass, nameRange);
    }

    /**
     * Reports the current value of every in-flight counter (including types this class never charges)
     * under the key {@code stockpile.shard.inFlight.<TYPE>}.
     */
    @Override
    public void addMemoryBySubsystem(MemoryBySubsystem memory) {
        for (InFlightDataType fileType : InFlightDataType.values()) {
            String key = "stockpile.shard.inFlight." + fileType.name();
            memory.addMemory(key, usedMemory.get(fileType));
        }
    }

    /**
     * Category under which in-flight data is accounted. {@code LOG_READ} and {@code INDEX_READ} are not
     * charged by any method of this class, but they are still reported by
     * {@link #addMemoryBySubsystem(MemoryBySubsystem)}.
     */
    private enum InFlightDataType {
        LOG_READ,
        LOG_WRITE,
        INDEX_READ,
        INDEX_WRITE,
        COMMAND_WRITE,
        CHUNK_READ,
        CHUNK_WRITE,
    }
}
