#pragma once

#include "storage.h"

#include <solomon/agent/protos/storage_config.pb.h>

#include <solomon/agent/misc/logger.h>
#include <solomon/agent/misc/timer_dispatcher.h>

#include <solomon/libs/cpp/sync/rw_lock.h>

#include <library/cpp/monlib/metrics/metric.h>
#include <library/cpp/monlib/metrics/metric_value.h>

#include <util/generic/deque.h>
#include <util/system/spinlock.h>

namespace NSolomon::NAgent {
namespace {

// Time series flavour that is able to report its own memory footprint
using TMetricTimeSeries = NWithMemoryInfo::TMetricTimeSeries;

/**
 * Payload of a single metric: its points and the kind of the metric.
 */
struct TMetricData {
    // Points of the metric
    TMetricTimeSeries Series;
    // Stays UNKNOWN until a writer sets it explicitly
    NMonitoring::EMetricType Type = NMonitoring::EMetricType::UNKNOWN;
};

// A metric is stored as its labels together with its data
using TMetric = std::pair<TLabels, TMetricData>;
using TMetricList = TVector<TMetric>;

/**
 * A ref-counted batch of metrics together with the labels shared by all of them,
 * the sequence number assigned to the batch and a running estimate of its memory footprint.
 */
class TChunk: TMoveOnly, public TAtomicRefCount<TChunk> {
public:
    TChunk() = default;

    /**
     * Appends a metric to the chunk and updates the memory estimate.
     *
     * @param labels  labels of the metric; moved into the chunk
     * @param data    type and points of the metric; moved into the chunk
     */
    void AddMetricData(TLabels&& labels, TMetricData&& data) {
        // Note: labels are sorted and should stay that way after the write!

        // Sizes are collected before the arguments are moved from
        MemorySizeBytes_ += static_cast<NWithMemoryInfo::TLabels&>(labels).MemorySizeBytes();
        MemorySizeBytes_ += sizeof(data.Type);
        MemorySizeBytes_ += data.Series.MemorySizeBytes();

        const ui64 capacityBefore = MetricList_.capacity();
        if (capacityBefore > MetricList_.size()) {
            // The slot that is about to be occupied was already counted when the list has grown
            MemorySizeBytes_ -= sizeof(TMetric);
        }

        MetricList_.emplace_back(std::move(labels), std::move(data));

        // One of the freshly allocated slots belongs to the new element and is covered by the sizes above,
        // the remaining ones are counted as spare slots
        const ui64 grownBy = MetricList_.capacity() - capacityBefore;
        if (grownBy > 1) {
            MemorySizeBytes_ += (grownBy - 1) * sizeof(TMetric);
        }
    }

    /**
     * Fills the common labels of the chunk. May be called only while they are still empty,
     * otherwise an exception is thrown.
     */
    void CopyCommonLabels(const TLabels& commonLabels) {
        Y_ENSURE(CommonLabels_.empty(), "CommonLabels in a chunk are not empty");

        // Note: labels are sorted and should stay that way after the write!
        for (const auto& commonLabel: commonLabels) {
            CommonLabels_.Add(commonLabel);
        }

        MemorySizeBytes_ += static_cast<NWithMemoryInfo::TLabels&>(CommonLabels_).MemorySizeBytes();
    }

    const TLabels& CommonLabels() const {
        return CommonLabels_;
    }

    // Unchecked access to the i-th metric
    TMetric& Get(ui32 i) {
        return MetricList_[i];
    }

    // Unchecked read-only access to the i-th metric
    const TMetric& Get(ui32 i) const {
        return MetricList_[i];
    }

    // Number of metrics stored in the chunk
    size_t Size() const {
        return MetricList_.size();
    }

    // Current estimate of the memory occupied by the chunk
    TBytes MemorySizeBytes() const {
        return MemorySizeBytes_;
    }

    void SetSeqNo(TSeqNo seqNo) {
        SeqNo_ = seqNo;
    }

    TSeqNo SeqNo() const {
        return SeqNo_;
    }

private:
    TSeqNo SeqNo_;
    TMetricList MetricList_;
    TLabels CommonLabels_;
    // The estimate starts from the size of the chunk object itself
    TBytes MemorySizeBytes_{sizeof(TChunk)};
};

using TChunkPtr = TIntrusivePtr<TChunk>;
using TChunkPointers = TVector<TChunkPtr>;
using TDataChunks = TDeque<TChunkPtr>;


// Sequence number assigned to newly registered consumers and used as the initial end-of-data marker.
// Declared const: it lives in a header and is only ever read, so it must not be reassignable.
const TSeqNo INITIAL_OFFSET = TSeqNo{1, 0};

/**
 * Registry of per-consumer read positions.
 *
 * For every consumer id it stores the committed offset, a "dirty" offset and the time of the last
 * update. All public member functions take the internal lock.
 */
class TConsumerOffsets {
private:
    /**
     * Read position of a single consumer.
     */
    class TOffsetInfo {
    public:
        explicit TOffsetInfo(TSeqNo seqNo)
            : SeqNo_{seqNo}
            , LastUpdate_{TInstant::Now()}
        {}

        TOffsetInfo(const TOffsetInfo& other) = default;

        TOffsetInfo& operator=(const TOffsetInfo& other) = default;

        // Marks the offset as touched right now
        void Refresh() {
            LastUpdate_ = TInstant::Now();
        }

        // Changes the committed offset without touching the last update time.
        // The dirty offset is pulled up so that it is never behind the committed one after this call.
        void SetSeqNoSilently(TSeqNo newSeqNo) {
            SeqNo_ = newSeqNo;
            DirtySeqNo_ = Max(SeqNo_, DirtySeqNo_);
        }

        // Changes the committed offset and refreshes the last update time
        void SetSeqNo(TSeqNo newSeqNo) {
            SetSeqNoSilently(newSeqNo);
            Refresh();
        }

        void SetDirtySeqNo(TSeqNo dirtySeqNo) {
            DirtySeqNo_ = dirtySeqNo;
        }

        TSeqNo SeqNo() const {
            return SeqNo_;
        }

        TSeqNo DirtySeqNo() const {
            return DirtySeqNo_;
        }

        TInstant LastUpdate() const {
            return LastUpdate_;
        }

    private:
        TSeqNo SeqNo_;
        // Value-initialized on purpose: the constructor does not set it, while SetSeqNoSilently() reads it
        TSeqNo DirtySeqNo_{};
        TInstant LastUpdate_;
    };

    // consumerId --> TSeqNo
    THashMap<TString, TOffsetInfo> Offsets_;
    TAdaptiveLock Lock_;
    TDuration SoftTTL_;
    TDuration HardTTL_;

public:
    /**
     * @param softTTL  offsets not updated for longer than this are moved forward by CleanupByTTL()
     * @param hardTTL  offsets not updated for longer than this are erased by CleanupByTTL()
     */
    TConsumerOffsets(TDuration softTTL, TDuration hardTTL)
        : SoftTTL_{softTTL}
        , HardTTL_{hardTTL}
    {}

    /**
     * Erases offsets which are older than the hard TTL. Offsets which are older than the soft TTL only
     * are moved forward to the smallest offset among the consumers that are within the soft TTL
     * (if there are such consumers and their smallest offset is ahead).
     */
    void CleanupByTTL() {
        auto g = Guard(Lock_);

        TVector<decltype(Offsets_)::iterator> obsoleteOffsets;
        std::optional<TSeqNo> minActiveSeqNo;
        TInstant now = TInstant::Now();

        for (auto it = Offsets_.begin(); it != Offsets_.end();) {
            // Advance first, so that erasing the current element does not break the iteration
            auto currIt = it++;
            const TInstant ts = currIt->second.LastUpdate();
            const TSeqNo seqNo = currIt->second.SeqNo();

            if (ts + HardTTL_ < now) {
                SA_LOG(INFO) << "erasing an offset value for " << currIt->first
                             << " because of a hard TTL(" << HardTTL_.ToString() << ")";
                Offsets_.erase(currIt);
            } else if (ts + SoftTTL_ < now) {
                obsoleteOffsets.emplace_back(currIt);
            } else {
                minActiveSeqNo = minActiveSeqNo.has_value() ? Min(*minActiveSeqNo, seqNo): seqNo;
            }
        }

        if (!obsoleteOffsets.empty() && minActiveSeqNo.has_value()) {
            for (auto& it: obsoleteOffsets) {
                const TSeqNo seqNo = it->second.SeqNo();

                if (*minActiveSeqNo > seqNo) {
                    SA_LOG(INFO) << "changing an offset value for " << it->first
                                 << " from " << seqNo << " to " << *minActiveSeqNo
                                 << ", because of a soft TTL(" << SoftTTL_.ToString() << ")";
                    // Silently: the last update time must stay as is, so that the hard TTL can still fire
                    it->second.SetSeqNoSilently(*minActiveSeqNo);

                    // TODO: write skipped data to a persistent storage
                }
            }
        }
    }

    /**
     * Returns the committed offset of the consumer, registering it with INITIAL_OFFSET if it is unknown.
     */
    TSeqNo GetOrInitConsumerOffset(const TString& consumerId) {
        auto g = Guard(Lock_);

        auto [it, isInserted] = Offsets_.emplace(consumerId, TOffsetInfo{INITIAL_OFFSET});
        if (isInserted) {
            SA_LOG(INFO) << "Registering new consumer: " << consumerId
                         << " with offset " << INITIAL_OFFSET;
        }

        return it->second.SeqNo();
    }

    /**
     * Returns the committed offset with the smallest chunk offset among all consumers,
     * or INITIAL_OFFSET when no consumer is registered.
     */
    TSeqNo FirstUnreadChunk() const {
        auto g = Guard(Lock_);

        if (Offsets_.empty()) {
            return INITIAL_OFFSET;
        }

        TSeqNo result = TSeqNo::Max();

        // Iterate by reference: copying every (id, info) pair under the lock is pure overhead
        for (const auto& consumer: Offsets_) {
            const TSeqNo seqNo = consumer.second.SeqNo();
            if (seqNo.ChunkOffset() < result.ChunkOffset()) {
                result = seqNo;
            }
        }

        return result;
    }

    bool Has(const TString& consumerId) const {
        auto g = Guard(Lock_);

        return Offsets_.contains(consumerId);
    }

    /**
     * Tries to commit seqNo as the new offset of the consumer.
     *
     * - an unknown consumer is registered with INITIAL_OFFSET, whatever seqNo is;
     * - TSeqNo{0} moves the offset to the dirty offset of the consumer;
     * - the current offset only refreshes the last update time;
     * - an offset behind the current one or beyond endSeqNo is rejected.
     *
     * @return false if the offset was rejected, true otherwise
     */
    bool Commit(const TString& consumerId, TSeqNo seqNo, TSeqNo endSeqNo) {
        auto g = Guard(Lock_);
        auto it = Offsets_.find(consumerId);

        if (it == Offsets_.end()) {
            // New consumer. There are two options:
            // 1. Consumer has restarted and hence will start from 1
            // 2. Agent has restarted and hence a consumer should get all new data, starting from 1
            seqNo = INITIAL_OFFSET;
            Offsets_.emplace(consumerId, seqNo);

            SA_LOG(INFO) << "Registering new consumer: " << consumerId << " with offset " << seqNo;

            return true;
        }

        if (seqNo == TSeqNo{0}) {
            // NOTE(review): if CommitDirty() was never called for this consumer, the dirty offset is whatever
            // SetSeqNoSilently() has left there -- confirm that this is the intended fallback
            const TSeqNo dirtySeqNo = it->second.DirtySeqNo();

            SA_LOG(INFO) << "shifting the offset value for consumer " << consumerId << " from "
                         << it->second.SeqNo() << " to the dirty offset " << dirtySeqNo;

            it->second.SetSeqNo(dirtySeqNo);

            return true;
        }

        if (seqNo == it->second.SeqNo()) {
            it->second.Refresh();

            SA_LOG(INFO) << consumerId << " is trying to commit the same offset " << seqNo
                         << ". Updating the ts value";
            return true;
        }

        if (seqNo < it->second.SeqNo()) {
            SA_LOG(INFO) << consumerId << " is trying to commit offset " << seqNo
                         << ", but its previous offset was " << it->second.SeqNo();
            return false;
        }

        if (seqNo > endSeqNo) {
            SA_LOG(INFO) << consumerId << " is trying to commit offset " << seqNo
                         << ", but the latest chunk is " << endSeqNo;
            return false;
        }

        it->second.SetSeqNo(seqNo);
        SA_LOG(DEBUG) << "Committing offset " << seqNo << " for " << consumerId;

        return true;
    }

    /**
     * Overwrites the dirty offset of a known consumer. Unknown consumers are ignored.
     */
    void CommitDirty(const TString& consumerId, TSeqNo seqNo) {
        auto g = Guard(Lock_);

        if (auto it = Offsets_.find(consumerId); it != Offsets_.end()) {
            if (it->second.DirtySeqNo() != seqNo) {
                SA_LOG(DEBUG) << "Changing a dirty offset value to " << seqNo << " for " << consumerId;
            }

            it->second.SetDirtySeqNo(seqNo);
        }
    }
};

/**
 * A non-owning view of a metric paired with the common labels of the chunk it was read from.
 * Both members are references, so the referenced objects must outlive the view.
 */
struct TMetricWithCommonLabels {
    const TMetric& Metric;
    const TLabels& CommonLabels;
};

// Type set for TMetricIteratorImpl which walks the chunks with a mutable iterator
struct TMetricIteratorTraits {
    using TContainer = TChunkPointers;
    using TContainerRef = TChunkPointers&;
    using TValue = TMetricWithCommonLabels;
    using TChunkIterator = TChunkPointers::iterator;
};

struct TMetricConstIteratorTraits {
    using TContainer = TChunkPointers;
    using TContainerRef = const TContainer&;
    using TValue = TMetricWithCommonLabels;
    using TChunkIterator = TContainer::const_iterator;
};

/**
 * Forward cursor over the metrics of a sequence of chunks.
 *
 * The cursor owns the list of chunk pointers it walks over and keeps an iterator into that list,
 * hence the class is neither copyable nor movable.
 */
template <typename TTraits>
class TMetricIteratorImpl: TNonCopyable {
    using TContainer = typename TTraits::TContainer;
    using TContainerRef = typename TTraits::TContainerRef;
    using TValue = typename TTraits::TValue;
    using TChunkIterator = typename TTraits::TChunkIterator;

public:
    // An exhausted cursor: the position equals the end, so HasNext() is false right away
    explicit TMetricIteratorImpl()
        : SeqNo_{INITIAL_OFFSET}
        , End_{SeqNo_}
    {
    }

    /**
     * @param data   chunks to walk over; the last chunk defines the end position.
     *               The list is accessed through back() without a check, so it has to be non-empty
     * @param begin  position to start from
     */
    explicit TMetricIteratorImpl(TContainer&& data, TSeqNo begin)
        : Data_{std::move(data)}
        , CurrentChunk_{std::begin(Data_)}
        , SeqNo_{begin}
        , End_{Data_.back()->SeqNo().ChunkOffset(), Data_.back()->Size()}
    {
    }

    // True while the current position is before the end position
    bool HasNext() const {
        return SeqNo_ < End_;
    }

    // Current position of the cursor
    TSeqNo ToSeqNo() const {
        return SeqNo_;
    }

    /**
     * Returns the metric at the current position and advances the cursor: to the next metric of the
     * same chunk or, after the last metric of a chunk, to the next chunk.
     * Throws if the metric offset is out of range for the current chunk.
     */
    TValue Next() {
        Y_ENSURE(SeqNo_.MetricOffset() < (*CurrentChunk_)->Size());

        auto& chunk = **CurrentChunk_;
        const auto metricIdx = SeqNo_.MetricOffset();
        const bool isLastInChunk = metricIdx == chunk.Size() - 1;

        // Both references point into the chunk, which is kept alive by Data_
        TValue result{chunk.Get(metricIdx), chunk.CommonLabels()};

        if (isLastInChunk) {
            ++CurrentChunk_;
            SeqNo_ = SeqNo_.NextChunkNo();
        } else {
            SeqNo_ = SeqNo_.NextMetricNo();
        }

        return result;
    }

private:
    TContainer Data_;
    TChunkIterator CurrentChunk_;

    TSeqNo SeqNo_;
    const TSeqNo End_;
};

using TMetricIterator = TMetricIteratorImpl<TMetricIteratorTraits>;
using TMetricConstIterator = TMetricIteratorImpl<TMetricConstIteratorTraits>;
} // namespace


/**
 * Looks up the chunk whose chunk offset equals the one of seqNo.
 * The lookup is a binary search, so the container is expected to be ordered by chunk offset.
 *
 * @return an iterator to the chunk, or std::end(c) if there is no such chunk
 */
template <typename TContainer>
auto FindChunk(TContainer& c, TSeqNo seqNo) -> decltype(std::begin(c)) {
    const auto notFound = std::end(c);

    const auto isBefore = [] (const auto& val, TSeqNo sn) {
        return val->SeqNo().ChunkOffset() < sn.ChunkOffset();
    };
    const auto candidate = LowerBound(std::begin(c), notFound, seqNo, isBefore);

    if (candidate == notFound) {
        return notFound;
    }

    const bool isExactMatch = (*candidate)->SeqNo().ChunkOffset() == seqNo.ChunkOffset();
    return isExactMatch ? candidate : notFound;
}

/**
 * In-memory storage of a single shard.
 *
 * Data is kept as a queue of chunks ordered by their sequence numbers. The amount of stored data is
 * bounded by the per shard memory limit, by the memory usage info shared through
 * TotalMemoryUsageInfo_ and by the maximum number of chunks. When a limit is hit, the oldest chunks
 * are evicted.
 */
class TMemStorage: public IStorage {
public:
    TMemStorage(
            const TStorageShardId& shardId,
            const TBytes perShardMemoryLimit,
            const ui32 maxChunks,
            const TOffsetsSettings& offsetsSettings,
            TMemoryUsageInfoPtr totalMemoryUsageInfo,
            IStorageUpdateListener* listener);

    IStorageMetricsConsumerPtr CreateConsumer(TInstant defaultTs) override;

protected:
    /**
     * Drops the chunks located before the first chunk that is still unread by some consumer
     * and updates the memory accounting and the listener accordingly.
     */
    void RemoveReadData(TDataChunks& data) {
        // is called under the write lock

        if (data.empty()) {
            return;
        }

        auto firstUnreadSeqNo = Offsets_.FirstUnreadChunk();
        auto firstUnreadIt = LowerBound(std::begin(data), std::end(data), firstUnreadSeqNo,
            [](const auto& elem, TSeqNo seqNo) {
                return elem->SeqNo().ChunkOffset() < seqNo.ChunkOffset();
            }
        );

        bool readDataExists = firstUnreadIt != std::begin(data);
        if (readDataExists) {
            ui64 chunksRemoved = 0;
            ui64 pointsRemoved = 0;
            TBytes bytesFreed = 0;

            for (auto it = std::begin(data); it != firstUnreadIt; ++it) {
                ++chunksRemoved;
                pointsRemoved += (*it)->Size();
                bytesFreed += (*it)->MemorySizeBytes();
            }

            data.erase(std::begin(data), firstUnreadIt);
            UpdateListener_->OnPointsRemoved(pointsRemoved);

            CurrentShardMemoryUsage_ -= bytesFreed;

            TotalMemoryUsageInfo_->DecreaseSize(bytesFreed);
            TotalMemoryUsageInfo_->BytesRead(bytesFreed);

            UpdateListener_->OnSizeChanged(CurrentShardMemoryUsage_);
            UpdateListener_->OnBytesRead(bytesFreed);
            UpdateListener_->OnRemovedReadChunks(chunksRemoved);
        }
    }

    // Memory left until the per shard limit.
    // NOTE(review): assumes CurrentShardMemoryUsage_ never exceeds MemoryLimit_ -- confirm
    TBytes CalculateFreeMemory() const {
        return MemoryLimit_ - CurrentShardMemoryUsage_;
    }

    /**
     * Calculates how many bytes the oldest chunks occupy, taking just enough of them to make
     * freeMemory + result >= memoryRequested. Nothing is removed.
     *
     * If all the chunks together are not enough, the size of all of them is returned.
     */
    TBytes CalculateMemorySizeToFree(TBytes freeMemory, TBytes memoryRequested, TDataChunks& data) {
        // is called under the write lock

        TBytes willBeFreed = 0;

        // The end() check guards against walking past the last chunk if the accounting ever drifts
        for (auto it = data.begin(); it != data.end() && freeMemory + willBeFreed < memoryRequested; ++it) {
            willBeFreed += (*it)->MemorySizeBytes();
        }

        return willBeFreed;
    }

    /**
     * Evicts the oldest chunks, taking just enough of them to make freeMemory + freed >= memoryRequested
     * (or all of them, if that is still not enough), and reports the eviction to the listener.
     *
     * Note that only the eviction counter of TotalMemoryUsageInfo_ is updated here; adjusting its size is
     * left to the caller.
     *
     * @return the number of bytes evicted
     */
    TBytes RemoveDataForSize(TBytes freeMemory, TBytes memoryRequested, TDataChunks& data) {
        // is called under the write lock

        TBytes memoryFreed = 0;
        ui64 pointsRemoved = 0;
        ui64 chunksEvicted = 0;

        auto it = data.begin();

        // The end() check guards against walking past the last chunk if the accounting ever drifts
        while (it != data.end() && freeMemory + memoryFreed < memoryRequested) {
            memoryFreed += (*it)->MemorySizeBytes();
            pointsRemoved += (*it)->Size();
            ++it;
            ++chunksEvicted;
        }

        data.erase(std::begin(data), it);
        UpdateListener_->OnOverflow(chunksEvicted);
        UpdateListener_->OnPointsRemoved(pointsRemoved);

        CurrentShardMemoryUsage_ -= memoryFreed;
        UpdateListener_->OnSizeChanged(CurrentShardMemoryUsage_);
        UpdateListener_->OnBytesEvicted(memoryFreed);
        TotalMemoryUsageInfo_->BytesEvicted(memoryFreed);

        SA_LOG(WARN) << "not enough free memory while writing to " << ShardId_ << "."
                     << " Evicted " << memoryFreed << " bytes of oldest data";

        return memoryFreed;
    }

    /**
     * Appends the chunk to the storage, evicting the oldest chunks if there is not enough free memory.
     *
     * Takes the write lock on the data. Throws yexception if the chunk does not fit even into
     * an empty shard.
     */
    void TryWriteChunk(TChunkPtr chunk) {
        auto data = Data_.Write();

        const auto sn = EndSeqNo();
        chunk->SetSeqNo(sn);

        TBytes freeMemory = 0;
        TBytes toIncrease = 0; // how much memory should be added to Total memory usage
        TBytes toDecrease = 0; // how much memory should be subtracted from Total memory usage
        bool needToFreeAdditionalMemory = false;

        /*
         * Let's assume that
         *     - the per shard limit is 3MiB
         *     - memoryRequested (e.g. the size of a chunk to be written) is 2MiB
         *     - and current memory usage is like this:
         *
         *  Chunk1    Chunk2      freeMemory
         *  ____________________________________
         * | 0.7MiB | 0.6MiB |     1.7MiB       |
         *  ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
         * Then the storage usage before the write is 0.7 + 0.6 == 1.3MiB. Free memory is not enough, so we'll need
         * to remove the first chunk to get additional space -- needToFreeAdditionalMemory is true
         * After that the usage will be:
         *
         *  Chunk2           freeMemory
         *  ____________________________________
         * | 0.6MiB |          2.4MiB           |
         *  ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
         * Now free memory is more than enough. After writing a new chunk:
         *
         *  Chunk2         Chunk3       freeMemory
         *  ____________________________________
         * | 0.6MiB |      2.0MiB      | 0.4MiB |
         *  ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
         * Now storage usage is 2.6MiB, so:
         *   - toIncrease is 2.6 - 1.3 == 1.3MiB
         *   - toDecrease is 0
         *   - needToFreeAdditionalMemory is true
         *
         * Sometimes toDecrease can be > 0 if after deleting old chunks we have more memory than requested
         */

        TBytes memoryRequested = chunk->MemorySizeBytes();
        TBytes availableMemory = 0;
        while (true) {
            // Every attempt starts from scratch: the decision made by a failed attempt was based on
            // an outdated amount of free memory and must not leak into the next one
            toIncrease = 0;
            toDecrease = 0;
            needToFreeAdditionalMemory = false;

            freeMemory = Min(CalculateFreeMemory(), TotalMemoryUsageInfo_->CalculateFreeMemory());
            availableMemory = CurrentShardMemoryUsage_ + freeMemory;

            if (memoryRequested > availableMemory) {
                ythrow yexception() << "skipping the write for " << ShardId_ << " due to the memory constraint."
                                    << " Tried to write " << memoryRequested << " bytes,"
                                    << " while the amount of available memory is " << availableMemory << " bytes";
            }

            if (freeMemory >= memoryRequested) {
                toIncrease = memoryRequested;
            } else {
                needToFreeAdditionalMemory = true;

                TBytes toFree = CalculateMemorySizeToFree(freeMemory, memoryRequested, *data);
                if (toFree > memoryRequested) {
                    toDecrease = toFree - memoryRequested;
                } else {
                    toIncrease = memoryRequested - toFree;
                }
            }

            if (toIncrease > 0 && !TotalMemoryUsageInfo_->TryIncreaseSize(toIncrease)) {
                // There's less available memory than before, because of other shards. Repeat
                continue;
            }

            break;
        }

        if (needToFreeAdditionalMemory) {
            RemoveDataForSize(freeMemory, memoryRequested, *data);
        }

        DoWriteChunk(std::move(chunk), sn, *data);

        // TODO: a place for an optimization
        // There will be a moment when a shard actually has less data than is counted in TotalMemoryUsageInfo.
        // Use a lock shared between all shards?
        if (toDecrease > 0) {
            TotalMemoryUsageInfo_->DecreaseSize(toDecrease);
        }
    }

    /**
     * Appends the chunk to the queue, updates the accounting and moves the end sequence number forward.
     * If the queue becomes longer than MaxChunks_, the oldest chunk is evicted.
     *
     * The caller is responsible for reserving the memory for the chunk in TotalMemoryUsageInfo_.
     */
    void DoWriteChunk(TChunkPtr chunk, TSeqNo sn, TDataChunks& data) {
        auto chunkSize = chunk->MemorySizeBytes();
        CurrentShardMemoryUsage_ += chunkSize;

        data.push_back(std::move(chunk));
        UpdateListener_->OnAdded(1);

        UpdateListener_->OnPointsAdded(data.back()->Size());
        UpdateListener_->OnBytesWritten(chunkSize);
        TotalMemoryUsageInfo_->BytesWritten(chunkSize);

        if (data.size() > MaxChunks_) {
            ui64 pointsRemoved = data.front()->Size();
            TBytes memoryFreed = data.front()->MemorySizeBytes();

            data.pop_front();
            CurrentShardMemoryUsage_ -= memoryFreed;

            UpdateListener_->OnOverflow(1);
            UpdateListener_->OnPointsRemoved(pointsRemoved);
            UpdateListener_->OnBytesEvicted(memoryFreed);
            TotalMemoryUsageInfo_->BytesEvicted(memoryFreed);
            TotalMemoryUsageInfo_->DecreaseSize(memoryFreed);
        }

        UpdateListener_->OnSizeChanged(CurrentShardMemoryUsage_);
        EndSeqNo_ = sn.NextChunkNo();
    }

private:
    // -- read member functions -----------------------------------------------

    TMetricConstIterator ReadRawData(const TQuery& query);
    void UpdateDirtyOffset(const TQuery& query, TSeqNo dirtySeqNo);

    TReadResult Read(
            const TQuery& query,
            NMonitoring::IMetricConsumer* c,
            const TReadOptions&) override;

    TFindResult Find(
            const TQuery& query,
            NMonitoring::IMetricConsumer* c,
            const TFindOptions&) override;

    // -- write member functions ----------------------------------------------

    void Delete(const TQuery&, const TDeleteOptions&) override;

    void Commit(const TString& consumerId, TSeqNo seqNo) override;

    // -- impl specific

    void Write(TDataChunks&& chunks);

    TSeqNo EndSeqNo() const;

    // must be called under read lock
    TSeqNo CalculateOffset(const TDataChunks& data, const TMaybe<TString>& consumerId, TMaybe<TSeqNo> offset);

    bool IsValidOffset(TSeqNo seqNo) const;

    void RegisterTimerFunction();

private:
    class TMemStorageMetricsConsumer;

    TStorageShardId ShardId_;

    TBytes CurrentShardMemoryUsage_;
    TBytes MemoryLimit_;
    TBytes DataMemoryLimit_;
    const ui32 MaxChunks_;
    TMemoryUsageInfoPtr TotalMemoryUsageInfo_;

    NSync::TLightRwLock<TDataChunks> Data_;
    TConsumerOffsets Offsets_;
    TSeqNo EndSeqNo_{INITIAL_OFFSET};
    TTimerDispatcher TimerDispatcher_;

    IStorageUpdateListener* UpdateListener_;
};

/**
 * A listener which ignores every notification. Handy when nobody is interested in storage updates.
 */
class TDummyStorageListener: public IStorageUpdateListener {
public:
    TDummyStorageListener() {}

    // Returns a pointer to the instance shared by all callers
    static TDummyStorageListener* Get() {
        static TDummyStorageListener instance;
        return &instance;
    }

    // All the notifications below are intentionally no-ops

    void OnFetcherUpdated(const TString&, ui64) noexcept override {}

    void OnAdded(ui64) noexcept override {}
    void OnOverflow(ui64) noexcept override {}
    void OnRemovedReadChunks(ui64) noexcept override {}

    void OnPointsAdded(ui64) noexcept override {}
    void OnPointsRemoved(ui64) noexcept override {}

    void OnSizeChanged(TBytes) noexcept override {}
    void OnLimitSet(TBytes) noexcept override {}

    void OnBytesWritten(TBytes) noexcept override {}
    void OnBytesEvicted(TBytes) noexcept override {}
    void OnBytesRead(TBytes) noexcept override {}
};
} // namespace NSolomon::NAgent
