#include "stream_data.h"

#include <library/cpp/logger/global/global.h>
#include <library/cpp/protobuf/util/pb_io.h>

#include <util/folder/dirut.h>
#include <util/system/fs.h>

namespace NRTYServer {
    /// Loads the stream data of this index from IndexDir/Filename (text-format protobuf).
    ///
    /// - A missing file is not an error: it is only reported at INFO level.
    /// - Any exception thrown while parsing or loading clears the data and is reported
    ///   through ProcessError (it is not propagated).
    /// - Whenever the file exists but no data ends up present afterwards (parse failure,
    ///   exception, or an empty payload), the file is additionally reported as corrupted.
    /// The "NoLock" suffix suggests the caller already holds the lock — TODO confirm.
    template <class TBaseStreamData, class TDerived>
    void IIndexStreamData<TBaseStreamData, TDerived>::ReadNoLock() {
        const TFsPath file = IndexDir / Filename;

        if (!NFs::Exists(file)) {
            // For temp index, it starts out having nothing.
            // For merged index, it is created through DDK component's Merge().
            INFO_LOG << file.GetPath() << " does not exist" << Endl;
            return;
        }

        try {
            typename TDerived::TProtoType proto;
            if (TryParseFromTextFormat(file, proto)) {
                static_cast<TDerived&>(*this).LoadFromProtobuf(proto);
            }
        } catch (...) {
            // Drop whatever was partially loaded before reporting.
            TBaseStreamData::ClearNoLock();
            ProcessError("Exception while reading " + file.GetPath() + ": " + CurrentExceptionMessage());
        }

        if (!TBaseStreamData::IsPresentNoLock()) {
            ProcessError(file.GetPath() + " is corrupted");
        }
    }

    /// Persists the stream data to dir/Filename as text-format protobuf.
    ///
    /// Does nothing when no data is present (an existing file in dir, if any, is left
    /// as is). The payload is first written to dir/Filename.tmp and then renamed over
    /// the target, presumably so a reader never observes a half-written file.
    /// Errors are reported through ProcessError and not propagated; on failure the
    /// temporary file is removed so it does not linger in the index directory.
    template <class TBaseStreamData, class TDerived>
    void IIndexStreamData<TBaseStreamData, TDerived>::FlushToDirNoLock(const TFsPath& dir) const {
        if (!TBaseStreamData::IsPresentNoLock()) {
            return;
        }

        const TFsPath file(dir / Filename);
        const TFsPath tmpFile(dir / (Filename + ".tmp"));
        try {
            typename TDerived::TProtoType protobuf;
            static_cast<const TDerived&>(*this).FillProtobuf(protobuf);

            SerializeToTextFormat(protobuf, tmpFile);

            tmpFile.RenameTo(file);
            DEBUG_LOG << "Created " << file.GetPath() << Endl;
        }
        catch (...) {
            // Capture the message while the exception is still the current one,
            // then clean up the (possibly partial) temp file before reporting.
            const TString message = "Exception while creating " + file.GetPath() + ": " + CurrentExceptionMessage();
            // NFs::Remove reports failure via its return value instead of throwing,
            // which is what we want inside a catch block; a missing file is fine.
            NFs::Remove(tmpFile.GetPath());
            ProcessError(message);
        }
    }

    /// Reports a read/flush problem for this index stream data.
    /// This implementation only writes `message` to the error log; it does not throw,
    /// so callers (ReadNoLock, FlushToDirNoLock) continue after reporting.
    template <class TBaseStreamData, class TDerived>
    void IIndexStreamData<TBaseStreamData, TDerived>::ProcessError(const TString& message) const {
        ERROR_LOG << message << Endl;
    }

    // The IIndexStreamData member templates are defined in this .cpp rather than in the
    // header, so every (base, derived) pairing must be explicitly instantiated here.
    // Presumably these are the only two pairings used elsewhere — add new ones here too.
    template class IIndexStreamData<TBaseTimestamp, TIndexTimestamp>;
    template class IIndexStreamData<TBasePositions, TIndexPositions>;

    /// Folds `other` into this object, stamping every merged stream with updateTimestamp.
    /// Holds this object's write lock and `other`'s read lock for the duration. The two
    /// locks are always taken lower-address first, so two threads merging a pair of
    /// objects in opposite directions cannot deadlock. Merging an object into itself
    /// is a no-op.
    void TBaseTimestamp::Merge(const TBaseTimestamp& other, ui64 updateTimestamp) {
        if (&other == this) {
            return;
        }

        const auto doMerge = [&]() {
            MergeNoLock(other, updateTimestamp);
        };

        if (&other < this) {
            TReadGuard otherGuard(other.ReadWriteMutex);
            TWriteGuard selfGuard(ReadWriteMutex);
            doMerge();
        } else {
            TWriteGuard selfGuard(ReadWriteMutex);
            TReadGuard otherGuard(other.ReadWriteMutex);
            doMerge();
        }
    }

    /// Overwrites the range of stream `streamId` so that min == max == timestamp.
    /// The existing portion count is kept and the sum is rescaled to
    /// AvgsPortions * low(timestamp), i.e. the stored average becomes the low part of
    /// timestamp. The cross-stream aggregate is then rebuilt from scratch, because
    /// narrowing a range cannot be expressed as an incremental update.
    void TBaseTimestamp::SetNoLock(TStreamId streamId, TTimestampValue timestamp, ui64 updateTimestamp) {
        TTs& entry = Snapshot[streamId];
        entry.MinValue = timestamp;
        entry.MaxValue = timestamp;
        entry.AvgsSum = entry.AvgsPortions * GetLow(timestamp);
        entry.UpdateTimestamp = updateTimestamp;

        RecalculateCommonSnapshot();
    }

    /// Rebuilds CommonSnapshot (the aggregate over all streams) from the per-stream
    /// Snapshot entries. Each stream contributes its max (together with its average
    /// data) and its min. UpdateCommonData skips values that are not plain timestamps.
    void TBaseTimestamp::RecalculateCommonSnapshot() {
        CommonSnapshot = TTs();
        for (const auto& [streamId, streamTs] : Snapshot) {
            UpdateCommonData(CommonSnapshot, streamTs.MaxValue, streamTs.UpdateTimestamp, streamTs.AvgsSum, streamTs.AvgsPortions);
            UpdateCommonData(CommonSnapshot, streamTs.MinValue, streamTs.UpdateTimestamp);
        }
    }

    /// Folds one observation into `ts`:
    /// - widens [MinValue, MaxValue] to include `value` (a zero MinValue counts as
    ///   "not set yet", so the first value becomes the minimum);
    /// - when avgPortions is non-zero, adds avgSum/avgPortions to the running average;
    /// - keeps the latest of the stored and given update timestamps.
    void TBaseTimestamp::UpdateData(TTs& ts, TTimestampValue value, ui64 updateTimestamp, TTimestampValue avgSum, ui64 avgPortions) {
        ts.UpdateTimestamp = Max(ts.UpdateTimestamp, updateTimestamp);

        ts.MaxValue = Max(ts.MaxValue, value);
        if (ts.MinValue) {
            ts.MinValue = Min(ts.MinValue, value);
        } else {
            ts.MinValue = value;
        }
        Y_ASSERT(ts.MinValue <= ts.MaxValue);

        if (avgPortions) {
            ts.AvgsSum = ts.AvgsSum + avgSum;
            ts.AvgsPortions += avgPortions;
        }
    }

    /// Like UpdateData, but for the cross-stream aggregate: values with a non-zero
    /// high part (i.e. larger than Max<ui64>(), not plain timestamps) are ignored.
    /// Temporary fix, see SAAS-4148.
    void TBaseTimestamp::UpdateCommonData(TTs& ts, TTimestampValue value, ui64 updateTimestamp, TTimestampValue avgSum, ui64 avgPortions) {
        if (GetHigh(value)) {
            return;
        }
        UpdateData(ts, value, updateTimestamp, avgSum, avgPortions);
    }

    /// Applies one observation both to stream `streamId` (created on first use) and to
    /// the cross-stream aggregate, keeping the two incrementally in sync.
    void TBaseTimestamp::UpdateStream(TStreamId streamId, TTimestampValue value, ui64 updateTimestamp, TTimestampValue avgSum, ui64 avgPortions) {
        TTs& streamTs = Snapshot[streamId];
        UpdateData(streamTs, value, updateTimestamp, avgSum, avgPortions);
        UpdateCommonData(CommonSnapshot, value, updateTimestamp, avgSum, avgPortions);
    }

    /// Lock-free entry point for recording an observation; forwards to UpdateStream.
    /// The caller is presumably responsible for holding the write lock — TODO confirm.
    void TBaseTimestamp::UpdateNoLock(TStreamId streamId, TTimestampValue value, ui64 updateTimestamp, TTimestampValue avgSum, ui64 avgPortions) {
        this->UpdateStream(streamId, value, updateTimestamp, avgSum, avgPortions);
    }

    /// Folds every stream of `other` into this object without taking locks.
    /// The stream's max carries its average data and its min is added separately.
    /// All merged entries are stamped with the given updateTimestamp rather than
    /// `other`'s per-stream timestamps. Must not be called with `&other == this`:
    /// that would iterate Snapshot while updating it (Merge guards against it).
    void TBaseTimestamp::MergeNoLock(const TBaseTimestamp& other, ui64 updateTimestamp) {
        for (const auto& [streamId, otherTs] : other.Snapshot) {
            UpdateNoLock(streamId, otherTs.MaxValue, updateTimestamp, otherTs.AvgsSum, otherTs.AvgsPortions);
            UpdateNoLock(streamId, otherTs.MinValue, updateTimestamp);
        }
    }

    void TBaseTimestamp::ClearNocLock() {
        IStreamData<TTimestampSnapshot>::ClearNoLock();
        RecalculateCommonSnapshot();
    }


    /// Returns the stored position for `key` in stream `streamId` under a read lock,
    /// or 0 when the stream or the key is unknown.
    TPositionValue TBasePositions::Get(TStreamId streamId, const TString& key) const {
        TReadGuard guard(ReadWriteMutex);
        const auto* streamPositions = Snapshot.FindPtr(streamId);
        if (!streamPositions) {
            return 0;
        }
        return streamPositions->Value(key, 0);
    }

    /// Thread-safe wrapper around UpdateNoLock: takes the write lock and records
    /// `value` for `key` in stream `streamId`, keeping the larger of old and new.
    void TBasePositions::Update(TStreamId streamId, const TString& key, TPositionValue value) {
        TWriteGuard guard(ReadWriteMutex);
        this->UpdateNoLock(streamId, key, value);
    }

    /// Folds every (stream, key, position) of `other` into this object, keeping the
    /// maximum position per key. Holds this object's write lock and `other`'s read lock.
    ///
    /// The two locks are always acquired lower-address first, as in
    /// TBaseTimestamp::Merge. Always taking our own lock first instead would let
    /// concurrent `a.Merge(b)` and `b.Merge(a)` each hold one write lock while waiting
    /// for the other object's lock, which is a deadlock.
    /// Merging an object into itself is a no-op.
    void TBasePositions::Merge(const TBasePositions& other) {
        if (this == &other) {
            return;
        }

        const auto mergeNoLock = [this, &other]() {
            for (const auto& [streamId, positions] : other.Snapshot) {
                for (const auto& [key, value] : positions) {
                    UpdateNoLock(streamId, key, value);
                }
            }
        };

        if (this < &other) {
            TWriteGuard gw(ReadWriteMutex);
            TReadGuard gr(other.ReadWriteMutex);
            mergeNoLock();
        } else {
            TReadGuard gr(other.ReadWriteMutex);
            TWriteGuard gw(ReadWriteMutex);
            mergeNoLock();
        }
    }

    /// Records `value` for `key` in stream `streamId` without locking. The stream entry
    /// is created if needed; an existing key keeps the larger of its old value and
    /// `value`, and a new key stores `value` as is.
    void TBasePositions::UpdateNoLock(TStreamId streamId, const TString& key, TPositionValue value) {
        auto& streamPositions = Snapshot[streamId];
        if (auto* current = streamPositions.FindPtr(key)) {
            *current = Max(*current, value);
        } else {
            streamPositions[key] = value;
        }
    }

    /// Creates the timestamp data of the index located in `indexDir` (file name
    /// "timestamp") and loads it right away via Read().
    TIndexTimestamp::TIndexTimestamp(const TString& indexDir)
        : IIndexStreamData<TBaseTimestamp, TIndexTimestamp>(indexDir, "timestamp")
    {
        Read();
    }

    /// Serializes every stream's timestamp range into `protobuf`, one StreamTimestamp
    /// per stream. A 128-bit value is split into a low part (Min/MaxValue) and a high
    /// part (Min/MaxValueEx); the high part is written only when it is non-zero.
    /// NOTE(review): only the low 64 bits of AvgsSum are persisted, and LoadFromProtobuf
    /// reads the sum back with a zero high part. The sum is written as 0 when there are
    /// no portions.
    void TIndexTimestamp::FillProtobuf(TTimestamp& protobuf) const {
        for (const auto& [streamId, ts] : Snapshot) {
            const ui64 avgSum = ts.AvgsPortions ? GetLow(ts.AvgsSum) : 0;
            const auto maxHigh = GetHigh(ts.MaxValue);
            const auto minHigh = GetHigh(ts.MinValue);

            TStreamTimestamp& entry = *protobuf.AddStreamTimestamp();
            entry.SetStreamId(streamId);

            entry.SetMaxValue(GetLow(ts.MaxValue));
            if (maxHigh) {
                entry.SetMaxValueEx(maxHigh);
            }

            entry.SetMinValue(GetLow(ts.MinValue));
            if (minHigh) {
                entry.SetMinValueEx(minHigh);
            }

            entry.SetAvgSumValue(avgSum);
            entry.SetAvgPortions(ts.AvgsPortions);
            entry.SetUpdateTimestamp(ts.UpdateTimestamp);
        }
    }

    /// Replaces the per-stream timestamps with the contents of `protobuf`.
    ///
    /// Every entry is validated before any state is touched. A malformed entry
    /// (min > max) throws yexception while Snapshot and the cross-stream aggregate are
    /// still unmodified. Previously the throw happened mid-loop, after earlier streams
    /// had already been merged into both. Duplicate stream ids in the input are merged
    /// through UpdateNoLock rather than overwritten.
    /// @throws yexception if some stream's minimum is greater than its maximum.
    void TIndexTimestamp::LoadFromProtobuf(const TTimestamp& protobuf) {
        for (const auto& stream : protobuf.GetStreamTimestamp()) {
            const TStreamId streamId = stream.GetStreamId();
            const TTimestampValue min = { stream.GetMinValueEx(), stream.GetMinValue() };
            const TTimestampValue max = { stream.GetMaxValueEx(), stream.GetMaxValue() };
            if (min > max) {
                throw yexception() << "incorrect timestamps for stream " << streamId << ": minimum is greater than maximum";
            }
        }

        Snapshot.clear();

        for (const auto& stream : protobuf.GetStreamTimestamp()) {
            const TStreamId streamId = stream.GetStreamId();
            const TTimestampValue min = { stream.GetMinValueEx(), stream.GetMinValue() };
            const TTimestampValue max = { stream.GetMaxValueEx(), stream.GetMaxValue() };
            // Only the low 64 bits of the average sum are stored (see FillProtobuf).
            const TTimestampValue avg = { 0, stream.GetAvgSumValue() };
            const ui64 portions = stream.GetAvgPortions();
            const ui64 updateTimestamp = stream.GetUpdateTimestamp();

            // The average data travels with the min; the max only widens the range.
            UpdateNoLock(streamId, min, updateTimestamp, avg, portions);
            UpdateNoLock(streamId, max, updateTimestamp);
        }

        for (const auto& [stream, value]: Snapshot) {
            const TTimestampValue min = value.MinValue;
            const TTimestampValue max = value.MaxValue;
            DEBUG_LOG << "Read timestamp of " << IndexDir.Basename() << ": " << stream << "-" << min << "-" << max << Endl;
        }
    }

    /// Creates the positions data of the index located in `indexDir` (file name
    /// "positions") and loads it right away via Read().
    TIndexPositions::TIndexPositions(const TString& indexDir)
        : IIndexStreamData<TBasePositions, TIndexPositions>(indexDir, "positions")
    {
        Read();
    }

    /// Serializes all positions into `protobuf`: one StreamPositions per stream, each
    /// holding one Position (key, value) per stored key.
    void TIndexPositions::FillProtobuf(TPositions& protobuf) const {
        for (const auto& [streamId, streamPositions] : Snapshot) {
            TStreamPositions& streamProto = *protobuf.AddStreamPositions();
            streamProto.SetStreamId(streamId);

            for (const auto& [key, value] : streamPositions) {
                TPosition& entry = *streamProto.AddPositions();
                entry.SetKey(key);
                entry.SetValue(value);
            }
        }
    }
    /// Replaces all positions with the contents of `protobuf`. Repeated (stream, key)
    /// pairs keep the maximum value, because every entry goes through UpdateNoLock.
    /// Each loaded entry is logged at DEBUG level.
    void TIndexPositions::LoadFromProtobuf(const TPositions& protobuf) {
        Snapshot.clear();

        for (const auto& streamProto : protobuf.GetStreamPositions()) {
            const TStreamId streamId = streamProto.GetStreamId();
            for (const auto& entry : streamProto.GetPositions()) {
                const auto& key = entry.GetKey();
                const auto value = entry.GetValue();
                UpdateNoLock(streamId, key, value);
                DEBUG_LOG << "Read position of " << IndexDir.Basename() << ": stream=" << streamId << " " << key << "->" << value << Endl;
            }
        }
    }

}
