#include "m2n_decoder.h"

#include <saas/rtyserver/merger/doc_extractor.h>
#include <saas/rtyserver/unistat_signals/signals.h>

#include <library/cpp/logger/global/global.h>

// Back decoder of TM2NDecoder: maps (new cluster, new docid) to the old
// (source cluster, source docid) address.
//
// Usage as seen in this file: for every source call AddSource() and then
// AddInfo() for each of its documents; call Finalize() once; afterwards
// Decode()/GetNewClusterSize()/GetClusters()/Remap() describe the new clusters.
class TM2NDecoder::TBackDecoder {
public:
    TBackDecoder(const TOptions& options)
        : Pool(new TMemoryPool(1024 * 1024))
        , NewAddressByIdentifier(new TNewAddressByIdentifier(Pool.Get()))
        , CountMax(options.SegmentSize)
        , CountMaxDeviation(options.SegmentSize * options.SizeDeviation)
        , MaxDeadlineDocs(options.MaxDeadlineDocs)
        , Pruning(options.Pruning)
        , ReuseDocids(options.ReuseDocids)
        , EnabledSignals(options.PushSignals)
    {}

    // Pruning rank of document docId in new cluster clusterId (valid after Finalize()).
    double GetPrunRank(ui32 clusterId, ui32 docId) const {
        CHECK_WITH_LOG(NewToOld.size() > clusterId);
        return NewToOld[clusterId].GetPrunRank(docId);
    }

    // Pushes the accumulated removal statistics if signals are enabled in the options.
    void PushUnistatSignals() {
        if (EnabledSignals) {
            DecoderStats.PushUnistatSignals();
        }
    }

    // Returns the index in PoolDocs of the pool for (idDest, shard), creating the pool on first use.
    ui32 AllocNewPool(ui32 idDest, ui32 shard) {
        auto it = PoolsInfo.find(std::make_pair(idDest, shard));
        if (it == PoolsInfo.end()) {
            PoolDocs.push_back(TClusterRemap(idDest, shard, CountMax, CountMaxDeviation, MaxDeadlineDocs, Pruning));
            PoolsInfo[std::make_pair(idDest, shard)] = PoolDocs.size() - 1;
            return PoolDocs.size() - 1;
        } else {
            return it->second;
        }
    }

    // Starts a new source; its documents get old cluster id GetCurrentClusterId().
    // A null source disables identifier-based de-duplication for its documents.
    inline void AddSource(const IDDKManager* source) {
        Sources.push_back(source);
    }

    // Old cluster id of the most recently added source.
    inline ui32 GetCurrentClusterId() const {
        return Sources.size() - 1;
    }

    // Registers document oldDocid of the current (last added) source in the pool
    // of (idDest, shard) and returns its address inside PoolDocs.
    //
    // For a non-null source, documents are de-duplicated by identifier hash. The
    // newcomer replaces the previously registered document if it has a later
    // time-live start, or if the previous one has a non-zero
    // GetSourceWithNewVersion() and the newcomer does not. Otherwise the newcomer
    // is dropped and (REMAP_NOWHERE, REMAP_NOWHERE) is returned.
    inline TRTYMerger::TAddress AddInfo(ui32 oldDocid, ui32 idDest, ui32 shard, double pruningRank) {
        ui32 newClusterId = AllocNewPool(idDest, shard);
        TRTYMerger::TAddress result(REMAP_NOWHERE, REMAP_NOWHERE);
        // Points either at the local result or at the map entry for this identifier,
        // so the final assignment below also updates the de-duplication index.
        TRTYMerger::TAddress* newAddress = &result;
        if (Sources.back()) {
            TDocSearchInfo::THash identifier = Sources.back()->GetIdentifier(oldDocid);
            std::pair<TNewAddressByIdentifier::iterator, bool> prevDoc = NewAddressByIdentifier->insert(std::make_pair(identifier, TRTYMerger::TAddress(REMAP_NOWHERE, REMAP_NOWHERE)));
            newAddress = &prevDoc.first->second;
            if (!prevDoc.second) {
                ui32 timeLiveStart = Sources.back()->GetTimeLiveStart(oldDocid);
                ui32 sourceOverride = Sources.back()->GetSourceWithNewVersion(oldDocid);
                const TRTYMerger::TAddress& prevOldAdress = DecodeByPools(newAddress->ClusterId, newAddress->DocId);
                ui32 prevTimeLiveStart = Sources[prevOldAdress.ClusterId]->GetTimeLiveStart(prevOldAdress.DocId);
                ui32 prevSourceOverride = Sources[prevOldAdress.ClusterId]->GetSourceWithNewVersion(prevOldAdress.DocId);
                if ((prevTimeLiveStart < timeLiveStart) || (prevSourceOverride && !sourceOverride)) {
                    if (Replace(*newAddress, oldDocid, newClusterId, pruningRank))
                        return *newAddress;
                } else
                    return result;
            }
        }
        PoolDocs[newClusterId].Add(GetCurrentClusterId(), oldDocid, pruningRank);
        newAddress->ClusterId = newClusterId;
        newAddress->DocId = PoolDocs[newClusterId].size() - 1;
        return *newAddress;
    }

    // Evicts the document currently stored at newAddress in favour of a newer copy.
    // Returns true if the slot was overwritten in place (same pool and ReuseDocids).
    // Otherwise it marks the old slot removed and returns false, and the caller appends the document.
    inline bool Replace(TRTYMerger::TAddress& newAddress, ui32 oldDocid, ui32 newClusterId, double pruningRank) {
        const bool inCluster = newAddress.ClusterId == newClusterId;
        if (inCluster) {
            if (ReuseDocids) {
                PoolDocs[newClusterId].Replace(newAddress.DocId, Sources.size() - 1, oldDocid, pruningRank);
                return true;
            }
            PoolDocs[newAddress.ClusterId].SkipDocId(newAddress.DocId); // document is added again later in AddInfo
        } else {
            PoolDocs[newAddress.ClusterId].Remove(newAddress.DocId);
        }
        return false;
    }

    // Number of documents in new cluster newClusterId (valid after Finalize()).
    inline ui32 GetNewClusterSize(ui32 newClusterId) const {
        return NewToOld[newClusterId].GetSize();
    }

    // Number of new clusters produced by Finalize().
    inline ui32 GetNewClustersCount() const {
        return NewToOld.size();
    }

    // Old address of document newDocid in new cluster newClusterId (valid after Finalize()).
    inline const TRTYMerger::TAddress& Decode(ui32 newClusterId, ui32 newDocid) const{
        return NewToOld[newClusterId][newDocid].Address;
    }

    // Old address stored in pool newClusterId at position newDocid (used before Finalize()).
    inline const TRTYMerger::TAddress& DecodeByPools(ui32 newClusterId, ui32 newDocid) const{
        return PoolDocs[newClusterId][newDocid].Address;
    }

    // Reorders new cluster iDest; see TClusterDecoder::Remap.
    void Remap(ui32 iDest, const TVector<ui32>& remap) {
        NewToOld[iDest].Remap(remap);
    }

    // Turns every pool into new clusters, merges removal statistics and frees
    // the structures that are only needed while documents are being added.
    // Must be called exactly once.
    void Finalize() {
        VERIFY_WITH_LOG(NewToOld.empty(), "Incorrect class usage");
        for (ui32 i = 0; i < PoolDocs.size(); ++i) {
            TVector<TClusterDecoder> decoder = PoolDocs[i].Finalize(Sources);
            DecoderStats.Merge(PoolDocs[i].GetRemovalStats());
            NewToOld.insert(NewToOld.end(), decoder.begin(), decoder.end());
        }
        RTY_MEM_LOG("Before clear addition structures in M2N decoder");
        NewAddressByIdentifier.Reset(nullptr);
        Pool.Reset(nullptr);
        Sources.clear();
        RTY_MEM_LOG("After clear addition structures in M2N decoder");
    }

    void Print(const char* prefix) const {
        Y_UNUSED(prefix);
//        for (TDecoder::const_iterator i = NewToOld.begin(); i != NewToOld.end(); ++i) {
//            for (ui32 j = 0; j < i->GetSize(); ++j)
//                DEBUG_LOG << prefix << (*i)[j].Address.ClusterId << "-" << (*i)[j].Address.DocId << " -> " << i - NewToOld.begin() << "-" << j << Endl;
//        }
    }

    // (shard, new cluster index) of every new cluster that belongs to destination idDest.
    TVector<TClusterInfo> GetClusters(ui32 idDest) const {
        TVector<TClusterInfo> result;
        for (ui32 i = 0; i < NewToOld.size(); ++i) {
            if (NewToOld[i].GetIdDest() == idDest) {
                result.push_back(TClusterInfo(NewToOld[i].GetShard(), i));
            }
        }
        return result;
    }

private:
    typedef TVector<const IDDKManager*> TSources;
    // Sentinel rank of removed documents; it is lower than any real rank, so sorting moves them to the end.
    const static double RemovedPruningRank;

    // One document: its old address and its pruning rank.
    struct TDoc {
        TDoc()
            : Address(REMAP_NOWHERE, REMAP_NOWHERE)
            , PruningRank(RemovedPruningRank)
        {}

        TDoc(ui32 oldClusterId, ui32 oldDocid, double pruningRank)
            : Address(oldClusterId, oldDocid)
            , PruningRank(pruningRank)
        {}

        // Descending pruning rank: sorting puts the highest ranks first.
        inline bool operator < (const TDoc& other) const {
            return PruningRank > other.PruningRank;
        }

        TRTYMerger::TAddress Address;
        double PruningRank;
    };

    // Non-owning view of one new cluster: Size documents starting at Begin.
    // The storage belongs to a TClusterRemap in PoolDocs.
    class TClusterDecoder {
    private:
        ui32 IdDest;
        ui32 Shard;
        TDoc* Begin;
        size_t Size;

    public:

        TClusterDecoder(TDoc* begin, size_t size, ui32 idDest, ui32 shard) {
            Begin = begin;
            Size = size;
            IdDest = idDest;
            Shard = shard;
        }

        // Reorders the documents of this cluster. remap[i] is the new position of
        // the current i-th document, or REMAP_NOWHERE to drop it. Surviving documents
        // must occupy a prefix of positions (checked). Survivors get synthetic
        // descending ranks, and Size shrinks by the number of dropped documents.
        void Remap(const TVector<ui32>& remap) {
            VERIFY_WITH_LOG(Size == remap.size(), "Incorrect condition: %lu != %lu", Size, remap.size());
            TVector<TDoc> result(Size);
            ui32 removeCount = 0;
            for (ui32 i = 0; i < remap.size(); ++i) {
                if (remap[i] != REMAP_NOWHERE) {
                    // Guard the write below: an out-of-range target would corrupt memory.
                    VERIFY_WITH_LOG(remap[i] < Size, "Incorrect remap target: %u >= %lu", remap[i], Size);
                    result[remap[i]] = *(Begin + i);
                } else {
                    ++removeCount;
                }
            }
            bool foundRemoved = false;
            for (ui32 i = 0; i < result.size(); ++i) {
                if (result[i].Address.DocId != REMAP_NOWHERE) {
                    CHECK_WITH_LOG(!foundRemoved);
                    result[i].PruningRank = result.size() - i + 1;
                    *(Begin + i) = result[i];
                } else {
                    foundRemoved = true;
                    (Begin + i)->PruningRank = RemovedPruningRank;
                }
            }
            Size -= removeCount;
        }

        double GetPrunRank(ui32 toDocId) const {
            CHECK_WITH_LOG(toDocId < Size);
            return Begin[toDocId].PruningRank;
        }

        const TDoc& operator[] (ui32 index) const {
            CHECK_WITH_LOG(index < Size);
            return Begin[index];
        }

        size_t GetSize() const {
            return Size;
        }

        ui32 GetIdDest() const {
            return IdDest;
        }

        ui32 GetShard() const {
            return Shard;
        }

        void SetSize(ui32 size) {
            Size = size;
        }
    };

    // Documents collected for one (IdDest, Shard) pair before Finalize().
    // Removed documents stay in place with RemovedPruningRank until Finalize().
    class TClusterRemap : public TVector<TDoc> {
    public:
        TClusterRemap(ui32 idDest, ui32 shard, ui32 countMax, ui32 countMaxDeviation, ui32 maxDeadlineDocs, bool pruning)
            : NeedSort(false)
            , RemoveCount(0)
            , SkippedDocIdCount(0)
            , IdDest(idDest)
            , Shard(shard)
            , CountMax(countMax)
            , CountMaxDeviation(countMaxDeviation)
            , MaxDeadlineDocs(maxDeadlineDocs)
            , Pruning(pruning)
        {}

        // Applies the deadline limit and sorts documents by descending pruning rank
        // (removed ones sink to the end). It then drops the removed documents and
        // slices the rest into TClusterDecoder views over this vector's storage.
        //
        // Each slice holds between CountMax - CountMaxDeviation and
        // CountMax + CountMaxDeviation documents; the last slice may be shorter.
        // Within that range it prefers the largest cut that does not split a run of
        // equal pruning ranks. If no such cut exists, the slice gets the maximal size.
        TVector<TClusterDecoder> Finalize(const TSources& sources) {
            ProcessMaxDeadLineDocs(sources);
            if (NeedSort || RemoveCount) {
                StableSort(begin(), end());
                NeedSort = false;
            }
            if (RemoveCount) {
                resize(size() - RemoveCount);
                RemoveCount = 0;
                SkippedDocIdCount = 0;
            }

            // Computed in 64 bits so that large segment sizes cannot wrap around.
            // Both limits are clamped to >= 1 so every slice makes progress and the
            // lower limit never exceeds the upper one.
            const ui64 upperLimit = Max<ui64>(1, static_cast<ui64>(CountMax) + CountMaxDeviation);
            const ui64 lowerLimit = Min<ui64>(upperLimit, Max<ui64>(1, CountMax > CountMaxDeviation ? CountMax - CountMaxDeviation : 0));

            TVector<TClusterDecoder> result;
            const size_t total = size();
            size_t pos = 0;
            while (pos < total) {
                const size_t remaining = total - pos;
                const size_t maxDelta = Min<ui64>(remaining, upperLimit);
                const size_t minDelta = Min<ui64>(remaining, lowerLimit);
                size_t delta = maxDelta;
                if (maxDelta < remaining) {
                    // A cut after `cut` documents is clean when the last document of this
                    // slice and the first one of the next slice have different ranks.
                    for (size_t cut = maxDelta; cut >= minDelta; --cut) {
                        if ((*this)[pos + cut - 1].PruningRank != (*this)[pos + cut].PruningRank) {
                            delta = cut;
                            break;
                        }
                    }
                }
                result.push_back(TClusterDecoder(&(*this)[pos], delta, IdDest, Shard));
                pos += delta;
            }
            return result;
        }

        const TM2NDecoder::TDecoderStats& GetRemovalStats() const {
            return RemovalStats;
        }

        // Marks a document removed because it is re-added at the end of this pool.
        inline void SkipDocId(ui32 docid) {
            Remove(docid);
            ++SkippedDocIdCount;
        }

        // Marks a document removed; it is physically dropped in Finalize().
        inline void Remove(ui32 docid) {
            at(docid).PruningRank = RemovedPruningRank;
            ++RemoveCount;
        }

        // Appends a document and records whether the descending-rank order was broken.
        inline void Add(ui32 oldClusterId, ui32 oldDocid, double pruningRank) {
            NeedSort |= !empty() && pruningRank > back().PruningRank;
            push_back(TDoc(oldClusterId, oldDocid, pruningRank));
        }

        // Overwrites a document in place. A sort is needed only if the new rank
        // breaks the order with either neighbour.
        inline void Replace(ui32 newDocid, ui32 oldClusterId, ui32 oldDocid, double pruningRank) {
            at(newDocid) = TDoc(oldClusterId, oldDocid, pruningRank);
            if (newDocid < size() - 1)
                NeedSort |= at(newDocid + 1).PruningRank > pruningRank;
            if (newDocid > 0)
                NeedSort |= pruningRank > at(newDocid - 1).PruningRank;
        }

    private:
        // Orders document indices by descending GetDeadlineIfEnabled() value;
        // removed documents go last.
        struct TDeadLineComparer {
            inline TDeadLineComparer(const TClusterRemap& remap, const TSources& sources)
                : Remap(remap)
                , Sources(sources)
            {}

            inline bool operator()(const size_t& i, const size_t& j) const {
                const TDoc& jDoc = Remap.at(j);
                const TDoc& iDoc = Remap.at(i);
                if (jDoc.PruningRank == RemovedPruningRank)
                    return iDoc.PruningRank != RemovedPruningRank;
                if (iDoc.PruningRank == RemovedPruningRank)
                    return false;
                ui32 iDL = Sources[iDoc.Address.ClusterId]->GetDeadlineIfEnabled(iDoc.Address.DocId);
                ui32 jDL = Sources[jDoc.Address.ClusterId]->GetDeadlineIfEnabled(jDoc.Address.DocId);
                return iDL > jDL;
            }
            const TClusterRemap& Remap;
            const TSources& Sources;
        };

        // Keeps at most MaxDeadlineDocs live documents (0 disables the limit).
        // Without pruning, it removes the documents with the latest deadlines and
        // records them in RemovalStats. With pruning, it only sets RemoveCount, so
        // Finalize() drops the lowest-ranked documents after sorting.
        void ProcessMaxDeadLineDocs(const TSources& sources) {
            if (!sources.size() || !sources.front())
                return;
            const ui32 mdld = MaxDeadlineDocs;
            if (!mdld || size() - RemoveCount <= mdld)
                return;
            if (!Pruning) {
                TVector<size_t> idxArray;
                idxArray.reserve(size());
                for (size_t i = 0; i < size(); ++i) {
                    idxArray.push_back(i);
                }
                Sort(idxArray, TDeadLineComparer(*this, sources));
                const auto& maxDeadlineDoc = (*this)[idxArray[mdld]].Address; // mdld < size() - RemoveCount
                const TInstant maxDeadline = TInstant::Minutes(sources[maxDeadlineDoc.ClusterId]->GetDeadlineIfEnabled(maxDeadlineDoc.DocId));
                RemovalStats.AddSurplusDocs(size() - RemoveCount - mdld, maxDeadline);
                for (ui32 end = size() - RemoveCount, i = mdld; i < end; ++i) {
                    auto& entry = (*this)[idxArray[i]];
                    Y_ASSERT(entry.PruningRank != RemovedPruningRank);
                    entry.PruningRank = RemovedPruningRank;
                }
            }
            RemoveCount = size() - mdld;
        }

        bool NeedSort;
        ui32 RemoveCount;
        ui32 SkippedDocIdCount;
        ui32 IdDest;
        ui32 Shard;
        ui32 CountMax;
        ui32 CountMaxDeviation;
        ui32 MaxDeadlineDocs;
        bool Pruning;
        TDecoderStats RemovalStats;
    };

    typedef THashMap<TDocSearchInfo::THash, TRTYMerger::TAddress, THash<TDocSearchInfo::THash>, TEqualTo<TDocSearchInfo::THash>, TPoolAlloc<TRTYMerger::TAddress> > TNewAddressByIdentifier;
    typedef TVector<TClusterRemap> TMatrix;
    typedef TVector<TClusterDecoder> TDecoder;
    TMatrix PoolDocs;           // pools, indexed by the value stored in PoolsInfo
    TDecoder NewToOld;          // new clusters, filled by Finalize()
    THolder<TMemoryPool> Pool;  // backs NewAddressByIdentifier; freed in Finalize()
    THolder<TNewAddressByIdentifier> NewAddressByIdentifier; // identifier -> address in PoolDocs
    TSources Sources;           // indexed by old cluster id; cleared in Finalize()
    ui32 CountMax;
    ui32 CountMaxDeviation;
    ui32 MaxDeadlineDocs;
    bool Pruning;
    bool ReuseDocids;
    TDecoderStats DecoderStats;
    TMap<std::pair<ui32, ui32>, ui32> PoolsInfo; // (idDest, shard) -> index in PoolDocs
    bool EnabledSignals = false;
};

// Sentinel rank for removed documents: the most negative finite double.
const double TM2NDecoder::TBackDecoder::RemovedPruningRank(-Max<double>());

//TM2NDecoder
// Creates an empty decoder; the new/old mapping lives in TBackDecoder.
TM2NDecoder::TM2NDecoder(const TOptions& options)
    : BackDecoder(new TBackDecoder(options)), OldDocsCount(0)
{}

// Defined out of line, in the file where TBackDecoder is a complete type.
TM2NDecoder::~TM2NDecoder() = default;

// Number of new clusters built by Finalize().
ui32 TM2NDecoder::GetNewClustersCount() const
{
    return BackDecoder->GetNewClustersCount();
}

// Maps an old (clusterId, docId) to its new address. Unknown addresses are
// reported with FAIL_LOG; (REMAP_NOWHERE, REMAP_NOWHERE) is returned in that case.
TRTYMerger::TAddress TM2NDecoder::Decode(ui32 clusterId, ui32 docId) const {
    InitDirectDecoder();
    const bool known = DirectDecoder->size() > clusterId && DirectDecoder->at(clusterId).size() > docId;
    if (!known) {
        FAIL_LOG("Incorrect address in decoder: %u, %u", clusterId, docId);
        return TRTYMerger::TAddress(REMAP_NOWHERE, REMAP_NOWHERE);
    }
    return DirectDecoder->at(clusterId)[docId];
}

// Reorders new cluster dstClusterId according to remap and invalidates the
// cached direct table (it is rebuilt lazily on the next lookup).
void TM2NDecoder::PatchDestMap(ui32 dstClusterId, const TVector<ui32>& remap) {
    Print("PDM ");
    DEBUG_LOG << "remap order:" << Endl;
    for (ui32 from = 0; from < remap.size(); ++from) {
        DEBUG_LOG << from << " -> " << remap[from] << Endl;
    }
    BackDecoder->Remap(dstClusterId, remap);
    DirectDecoder.Reset(nullptr);
}

// True if the old address (clusterId, docId) is covered by the direct table.
bool TM2NDecoder::Check(ui32 clusterId, ui32 docId) const {
    InitDirectDecoder();
    if (clusterId >= DirectDecoder->size())
        return false;
    return docId < DirectDecoder->at(clusterId).size();
}

// Builds the new clusters, then pushes the removal statistics (if enabled).
void TM2NDecoder::Finalize()
{
    BackDecoder->Finalize();
    BackDecoder->PushUnistatSignals();
    Print("Fin ");
}

void TM2NDecoder::InitDirectDecoder() const {
    if (!DirectDecoder) {
        TGuard<TMutex> g(DirectDecoderMutex);
        if (!DirectDecoder) {
            DirectDecoder.Reset(new TDirectDecoder(OldClusterSizes.size()));
            for (size_t cl = 0; cl < DirectDecoder->size(); ++cl)
                DirectDecoder->at(cl).resize(OldClusterSizes[cl], TRTYMerger::TAddress(REMAP_NOWHERE, REMAP_NOWHERE));
            for (ui32 cl = 0; cl < BackDecoder->GetNewClustersCount(); ++cl) {
                for (ui32 docid = 0; docid < BackDecoder->GetNewClusterSize(cl); ++docid) {
                    const TRTYMerger::TAddress& oldAddr = BackDecoder->Decode(cl, docid);
                    if (DirectDecoder->at(oldAddr.ClusterId).size() <= oldAddr.DocId)
                        DirectDecoder->at(oldAddr.ClusterId).resize(oldAddr.DocId + 1, TRTYMerger::TAddress(REMAP_NOWHERE, REMAP_NOWHERE));
                    TRTYMerger::TAddress& newAddr = DirectDecoder->at(oldAddr.ClusterId)[oldAddr.DocId];
                    newAddr.ClusterId = cl;
                    newAddr.DocId = docid;
                }
            }
        }
    }
}

// Number of documents registered for old cluster clusterId; 0 for unknown clusters.
ui32 TM2NDecoder::GetSizeOfCluster(ui32 clusterId) const {
    return clusterId < OldClusterSizes.size() ? OldClusterSizes[clusterId] : 0;
}

// Total number of old documents passed to AddInfo(docId, ...).
ui32 TM2NDecoder::GetSize() const
{
    return OldDocsCount;
}

// True if clusterId refers to a source added via AddInfo(source).
bool TM2NDecoder::IsValidClusterId(ui32 clusterId) const {
    return clusterId < OldClusterSizes.size();
}

// Debug dump hook; forwards the prefix to the back decoder.
void TM2NDecoder::Print(const char* pref) const
{
    BackDecoder->Print(pref);
}

// Starts a new old cluster backed by source; its document counter starts at 0.
void TM2NDecoder::AddInfo(const IDDKManager* source)
{
    BackDecoder->AddSource(source);
    OldClusterSizes.push_back(0);
}

// Registers document docId of the current source. destinationId == -1 means the
// document is dropped; otherwise it is routed to (destinationId, segmentShard)
// with the pruning rank taken from docPlace. Requires a prior AddInfo(source).
void TM2NDecoder::AddInfo(ui32 docId, i32 destinationId, const TDocPlaceInfo& docPlace, ui32 segmentShard) {
    const ui32 clusterId = BackDecoder->GetCurrentClusterId();
    ++OldDocsCount;
    ++OldClusterSizes.back();
    if (destinationId == -1) {
        DEBUG_LOG << "Add " << clusterId << "-" << docId << "->" << "REMOVE" << Endl;
        return;
    }
    BackDecoder->AddInfo(docId, destinationId, segmentShard, docPlace.PruningRank);
}

// New clusters (shard and index) that belong to destination destId.
TVector<TClusterInfo> TM2NDecoder::GetClusters(ui32 destId) const
{
    return BackDecoder->GetClusters(destId);
}

// Size of new cluster clusterId (0 if it does not exist). Passing Max<ui32>()
// returns the total over all new clusters.
ui32 TM2NDecoder::GetNewDocsCount(ui32 clusterId) const {
    const ui32 clustersCount = BackDecoder->GetNewClustersCount();
    if (clusterId != Max<ui32>())
        return clusterId < clustersCount ? BackDecoder->GetNewClusterSize(clusterId) : 0;
    ui32 total = 0;
    for (ui32 cl = 0; cl < clustersCount; ++cl)
        total += BackDecoder->GetNewClusterSize(cl);
    return total;
}

// Writes the old address of new document (clusterId, docId) into addr.
// Returns false (leaving addr untouched) if the new address is out of range.
bool TM2NDecoder::NewToOld(ui32 clusterId, ui32 docId, TRTYMerger::TAddress& addr) const {
    const bool valid = clusterId < BackDecoder->GetNewClustersCount() && docId < BackDecoder->GetNewClusterSize(clusterId);
    if (valid)
        addr = BackDecoder->Decode(clusterId, docId);
    return valid;
}

// Accounts count documents dropped by the deadline limit and keeps the latest deadline seen.
void TM2NDecoder::TDecoderStats::AddSurplusDocs(ui32 count, TInstant maxDeadline) {
    if (MaxRemovedDeadline < maxDeadline)
        MaxRemovedDeadline = maxDeadline;
    RemovedDeadlineDocs += count;
}

// Folds another statistics object into this one.
void TM2NDecoder::TDecoderStats::Merge(const TDecoderStats& stats)
{
    AddSurplusDocs(stats.RemovedDeadlineDocs, stats.MaxRemovedDeadline);
}

// Reports the surplus-document counters to the rtyserver unistat signals.
void TM2NDecoder::TDecoderStats::PushUnistatSignals() const
{
    TSaasRTYServerSignals::UpdateSurplusDocCount(RemovedDeadlineDocs, MaxRemovedDeadline);
}
