package ru.yandex.stockpile.server.shard;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.misc.dataSize.DataSize;
import ru.yandex.solomon.memory.layout.MemMeasurable;
import ru.yandex.solomon.memory.layout.MemMeasurableSubsystem;
import ru.yandex.solomon.memory.layout.MemoryBySubsystem;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.stockpile.server.SnapshotLevel;
import ru.yandex.stockpile.server.Txn;
import ru.yandex.stockpile.server.data.chunk.IndexEntry;
import ru.yandex.stockpile.server.data.chunk.IndexRangeResult;
import ru.yandex.stockpile.server.data.chunk.SnapshotAddress;
import ru.yandex.stockpile.server.data.index.SnapshotIndex;
import ru.yandex.stockpile.server.data.index.stats.IndexStatsLevel;
import ru.yandex.stockpile.server.shard.stat.LevelSizeAndCount;

/**
 * Snapshot indexes grouped by {@link SnapshotLevel}, together with the aggregated
 * {@link IndexStatsLevel} statistics and memory accounting for them.
 *
 * <p>Each level keeps its snapshots in a list that {@code addSnapshot} maintains in
 * ascending txn order; readers iterate levels in {@link SnapshotLevel#inOrderOfRead()} order.
 *
 * @author Stepan Koltsov
 */
@ParametersAreNonnullByDefault
public class AllIndexes implements MemMeasurableSubsystem {
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(AllIndexes.class);
    private Map<SnapshotLevel, Level> levels = new EnumMap<>(SnapshotLevel.class);
    private IndexStatsLevel stats = new IndexStatsLevel();

    // modify

    /**
     * Creates the index set from the already loaded snapshots of every level.
     *
     * @param debug          not used by this class; kept for source compatibility with callers
     * @param twoHourIndexes snapshots of the {@link SnapshotLevel#TWO_HOURS} level
     * @param dailyIndex     snapshots of the {@link SnapshotLevel#DAILY} level
     * @param eternityIndex  snapshots of the {@link SnapshotLevel#ETERNITY} level
     * @throws IllegalStateException if a snapshot is passed in the list of a level it does not belong to
     */
    public AllIndexes(
        String debug,
        List<SnapshotIndexWithStats> twoHourIndexes,
        List<SnapshotIndexWithStats> dailyIndex,
        List<SnapshotIndexWithStats> eternityIndex)
    {
        levels.put(SnapshotLevel.TWO_HOURS, new Level(SnapshotLevel.TWO_HOURS, twoHourIndexes));
        levels.put(SnapshotLevel.DAILY, new Level(SnapshotLevel.DAILY, dailyIndex));
        levels.put(SnapshotLevel.ETERNITY, new Level(SnapshotLevel.ETERNITY, eternityIndex));
        updateRecordCountStats();
    }

    /**
     * Removes the snapshot with the given address from its level; does nothing when the
     * level holds no snapshot with that txn.
     *
     * <p>NOTE(review): unlike {@link #addSnapshot}, this does not refresh the aggregated stats;
     * presumably callers invoke {@link #updateRecordCountStats()} themselves - confirm.
     */
    public void removeSnapshot(SnapshotAddress address) {
        levels.get(address.level()).removeSnapshot(address);
    }

    /**
     * Removes the given snapshot, identified by the address of its index.
     */
    public void removeSnapshot(SnapshotIndexWithStats snapshot) {
        removeSnapshot(snapshot.getIndex().snapshotAddress());
    }

    /**
     * Adds the snapshot to the level reported by its index and refreshes the aggregated stats.
     *
     * @throws IllegalStateException if the level already holds a snapshot with the same txn
     */
    public void addSnapshot(SnapshotIndexWithStats snapshot) {
        levels.get(snapshot.getIndex().getLevel()).addSnapshot(snapshot);
        updateRecordCountStats();
    }

    /**
     * Moves the latest snapshot time of the level forward.
     *
     * @throws IllegalStateException if {@code tsMillis} is earlier than the currently stored time
     */
    public void updateLatestSnapshotTime(SnapshotLevel level, long tsMillis) {
        levels.get(level).updateLatestSnapshotTime(tsMillis);
    }

    /**
     * Rebuilds the aggregated stats from all currently held indexes and publishes the result
     * by replacing the {@code stats} reference.
     */
    public void updateRecordCountStats() {
        var fresh = new IndexStatsLevel();
        stream().forEach(index -> fresh.add(index.getLevel(), index.getContent().getStats()));
        this.stats = fresh;
    }


    // no modification after this point (except refcounts)

    /**
     * Collects the index entries of a metric that have to be read for a request starting at
     * {@code fromMillis}.
     *
     * <p>Snapshots are visited in read order. Entries are skipped until the first one whose last
     * timestamp is {@code >= fromMillis}; that entry and every entry found after it are returned.
     * For the skipped entries the largest {@code lastTsMillis + 1} is reported as the second
     * component of the result ({@code 0} when nothing was skipped).
     *
     * @param localId    local id of the metric to look up in every index
     * @param fromMillis lower bound of the requested time range, in epoch millis
     * @return matching entries in read order plus the timestamp computed from the skipped entries
     */
    public IndexRangeResult rangeRequestsForMetric(long localId, long fromMillis) {
        List<IndexEntry> result = new ArrayList<>();
        long indexFromMillis = 0;
        boolean startFound = false;
        var it = streamWithRef().iterator();
        while (it.hasNext()) {
            var entry = it.next().getIndex().findMetric(localId);
            if (entry == null) {
                // the metric is absent from this snapshot
                continue;
            }

            if (startFound || entry.getLastTsMillis() >= fromMillis) {
                startFound = true;
                result.add(entry);
            } else {
                indexFromMillis = Math.max(indexFromMillis, entry.getLastTsMillis() + 1);
            }
        }
        return new IndexRangeResult(result, indexFromMillis);
    }

    // read


    /**
     * Streams the bare indexes of all levels in read order.
     */
    public Stream<SnapshotIndex> stream() {
        return streamWithStats().map(SnapshotIndexWithStats::getIndex);
    }

    /**
     * Streams all snapshots in read order; currently identical to {@link #streamWithRef()}.
     */
    public Stream<SnapshotIndexWithStats> streamWithStats() {
        return streamWithRef();
    }

    /**
     * Streams all snapshots, level by level, in {@link SnapshotLevel#inOrderOfRead()} order.
     */
    public Stream<SnapshotIndexWithStats> streamWithRef() {
        return Arrays.stream(SnapshotLevel.inOrderOfRead()).flatMap(this::streamLevel);
    }

    /**
     * Returns the latest snapshot time recorded for the level, in epoch millis.
     */
    public long latestSnapshotTime(SnapshotLevel level) {
        return levels.get(level).indexLastTs;
    }

    /**
     * Returns how many snapshots the level currently holds.
     */
    public int snapshotCount(SnapshotLevel level) {
        return levels.get(level).indexes.size();
    }

    /**
     * Streams the snapshots of a single level in their stored order.
     */
    public Stream<SnapshotIndexWithStats> streamLevel(SnapshotLevel level) {
        return levels.get(level).indexes.stream();
    }

    /**
     * Returns the total record count according to the last computed stats.
     */
    public long recordCount() {
        return stats.getTotalByLevel().getTotalByProjects().getTotalByKinds().records;
    }

    /**
     * Returns the total metric count according to the last computed stats.
     */
    public long metricCount() {
        return stats.getTotalByLevel().getTotalByProjects().getTotalByKinds().metrics;
    }

    /**
     * Returns the stats computed by the last {@link #updateRecordCountStats()} call.
     */
    public IndexStatsLevel getStats() {
        return stats;
    }

    /**
     * Returns the accumulated disk size of the snapshots held by the level.
     */
    public LevelSizeAndCount diskSize(SnapshotLevel level) {
        return levels.get(level).diskSize;
    }

    /**
     * Drops all snapshots by replacing every level with an empty one.
     *
     * <p>NOTE(review): the aggregated stats are not reset here, so {@link #recordCount()} and
     * {@link #getStats()} keep reporting the values computed before the call - confirm that
     * this is intended.
     */
    public void destroy() {
        for (SnapshotLevel level : SnapshotLevel.values()) {
            levels.put(level, new Level(level, List.of()));
        }
    }

    @Override
    public void addMemoryBySubsystem(MemoryBySubsystem memory) {
        memory.addMemory("stockpile.shard.index.2h", levels.get(SnapshotLevel.TWO_HOURS).memorySizeIncludingSelf());
        memory.addMemory("stockpile.shard.index.daily", levels.get(SnapshotLevel.DAILY).memorySizeIncludingSelf());
        memory.addMemory("stockpile.shard.index.eternity", levels.get(SnapshotLevel.ETERNITY).memorySizeIncludingSelf());
        memory.addMemory("stockpile.shard.index.other", SELF_SIZE + stats.memorySizeIncludingSelf());
    }

    @Override
    public long memorySizeIncludingSelf() {
        return SELF_SIZE
                + levels.get(SnapshotLevel.TWO_HOURS).memorySizeIncludingSelf()
                + levels.get(SnapshotLevel.DAILY).memorySizeIncludingSelf()
                + levels.get(SnapshotLevel.ETERNITY).memorySizeIncludingSelf()
                + stats.memorySizeIncludingSelf();
    }

    /**
     * Snapshots of one {@link SnapshotLevel} plus the counters derived from them
     * (memory usage, disk size, latest snapshot time).
     */
    private static class Level implements MemMeasurable {
        // Shallow size of a Level instance (previously measured AllIndexes by mistake).
        private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(Level.class);
        private final SnapshotLevel level;
        private final List<SnapshotIndexWithStats> indexes;
        private final AtomicLong memoryUsage = new AtomicLong();
        private LevelSizeAndCount diskSize = LevelSizeAndCount.zero;
        private volatile long indexLastTs = SnapshotTs.SpecialTs.NEVER.value;

        /**
         * Creates a level holding at most one snapshot: {@code null} and snapshots for which
         * {@code isReal()} is false yield an empty level.
         */
        public Level(SnapshotLevel level, @Nullable SnapshotIndexWithStats snapshot) {
            this(level, snapshot != null && snapshot.isReal() ? List.of(snapshot) : List.of());
        }

        /**
         * Creates a level from the given snapshots, copying the list as is.
         *
         * <p>The latest snapshot time is taken from the last element of the list.
         * NOTE(review): the list is not sorted here, while {@code removeSnapshot} relies on
         * ascending txn order - presumably callers pass snapshots already ordered; confirm.
         *
         * @throws IllegalStateException if a snapshot belongs to a different level
         */
        public Level(SnapshotLevel level, List<SnapshotIndexWithStats> snapshots) {
            this.level = level;
            for (var snapshot : snapshots) {
                var snapshotLevel = snapshot.getIndex().getLevel();
                if (level != snapshotLevel) {
                    throw new IllegalStateException("Snapshot level invalid "+ level + " != " + snapshotLevel);
                }
                this.indexLastTs = snapshot.getIndex().getContent().getTsMillis();
                this.memoryUsage.addAndGet(snapshot.memorySizeIncludingSelf());
                Txn.validateTxn(snapshot.getIndex().getTxn());
                this.diskSize = LevelSizeAndCount.plus(this.diskSize, snapshot.diskSize());
            }
            this.indexes = new ArrayList<>(snapshots);
        }

        public List<SnapshotIndexWithStats> getIndexes() {
            return indexes;
        }

        /**
         * Removes the snapshot whose txn equals {@code address.txn()} and subtracts its disk and
         * memory footprint from the counters. The scan stops at the first snapshot with a larger
         * txn, so a missing txn is a silent no-op.
         */
        private void removeSnapshot(SnapshotAddress address) {
            var it = indexes.iterator();
            while (it.hasNext()) {
                var candidate = it.next();
                long txn = candidate.getIndex().getTxn();
                if (txn > address.txn()) {
                    // past the place where the snapshot would have been
                    return;
                }
                if (txn == address.txn()) {
                    it.remove();
                    diskSize = LevelSizeAndCount.minus(diskSize, candidate.diskSize());
                    memoryUsage.addAndGet(-candidate.memorySizeIncludingSelf());
                    return;
                }
            }
        }

        /**
         * Inserts the snapshot so that the list stays ordered by ascending txn and adds its
         * footprint to the counters.
         *
         * <p>The insertion point is resolved before any state is touched, so a rejected
         * snapshot leaves the memory and disk counters unchanged.
         *
         * @throws IllegalStateException if the snapshot belongs to a different level or a snapshot
         *                               with the same txn is already present
         */
        private void addSnapshot(SnapshotIndexWithStats snapshot) {
            long newTxn = snapshot.getIndex().getTxn();
            Txn.validateTxn(newTxn);
            if (level != snapshot.getIndex().getLevel()) {
                throw new IllegalStateException("Snapshot level invalid "+ level + " != " + snapshot.getIndex().getLevel());
            }

            // Default to appending; an earlier position is chosen when a larger txn is found.
            int position = indexes.size();
            for (int i = 0; i < indexes.size(); i++) {
                long txn = indexes.get(i).getIndex().getTxn();
                if (txn > newTxn) {
                    position = i;
                    break;
                }
                if (txn == newTxn) {
                    throw new IllegalStateException("not able add snapshot with same txn " + snapshot.getIndex().snapshotAddress());
                }
            }

            indexes.add(position, snapshot);
            memoryUsage.addAndGet(snapshot.memorySizeIncludingSelf());
            diskSize = LevelSizeAndCount.plus(diskSize, snapshot.diskSize());
        }

        /**
         * Stores the new latest snapshot time.
         *
         * @throws IllegalStateException if {@code tsMillis} is earlier than the stored time
         */
        private void updateLatestSnapshotTime(long tsMillis) {
            if (indexLastTs > tsMillis) {
                throw new IllegalStateException(Instant.ofEpochMilli(indexLastTs) + " > " + Instant.ofEpochMilli(tsMillis));
            }

            indexLastTs = tsMillis;
        }

        @Override
        public String toString() {
            return "Level{" +
                    level +
                    ", snapshots=" + indexes.size() +
                    ", memory=" + DataSize.prettyString(memoryUsage.get()) +
                    ", lastMerge=" + Instant.ofEpochMilli(indexLastTs) +
                    '}';
        }

        @Override
        public long memorySizeIncludingSelf() {
            // Widen before multiplying so the product is computed in long arithmetic.
            return SELF_SIZE + memoryUsage.get() + (long) indexes.size() * MemoryCounter.OBJECT_POINTER_SIZE;
        }
    }
}
