#include "bitmap_index.h"
#include "iterators.h"
#include "bitmap.h"

#include <solomon/libs/cpp/intern/str_pool.h>

#include <util/generic/vector.h>
#include <util/generic/utility.h>
#include <util/generic/cast.h>

#include <util/string/join.h>

#include <contrib/libs/croaring/cpp/roaring.hh>

using namespace NSolomon::NIntern;

namespace NSolomon::NSearch {
namespace {

struct TOneKeyIndex {
public:
    const TStringId LabelKey;
    const IBitmapPtr LabelValues;
    const THashMap<TStringId, IBitmapPtr> HitsByValue;
    const IBitmapPtr AllHits;

    TOneKeyIndex(TStringId labelKey, IBitmapPtr&& labelValues, THashMap<TStringId, IBitmapPtr>&& hitsByValue, IBitmapPtr&& allHits)
        : LabelKey(std::move(labelKey))
        , LabelValues(std::move(labelValues))
        , HitsByValue(std::move(hitsByValue))
        , AllHits(std::move(allHits))
    {
    }

public:
    using TPtr = THolder<TOneKeyIndex>;
};


// Mutable accumulator for one label key. Collects (value, metric id) pairs into
// roaring bitmaps and turns them into a TOneKeyIndex on Finalize().
class TOneKeyIndexBuilder {
public:
    // Records that metric `metricId` carries the interned value `labelValue` for this key.
    void Add(const TStringId labelValue, ui32 metricId) {
        LabelValues_.add(labelValue);
        HitsByValue_[labelValue].add(metricId);
        AllHits_.add(metricId);
    }

    // Builds the immutable index for `key` from everything added so far.
    // The accumulated bitmaps are moved out, leaving this builder in a moved-from state.
    TOneKeyIndex::TPtr Finalize(TStringId key) {
        IBitmapPtr labelValues = Optimize(std::move(LabelValues_));

        // one entry per distinct value, so the value count is a good bucket hint
        THashMap<TStringId, IBitmapPtr> hitsByValue(labelValues->Size());
        for (auto& entry: HitsByValue_) {
            hitsByValue[entry.first] = Optimize(std::move(entry.second));
        }

        IBitmapPtr allHits = Optimize(std::move(AllHits_));

        return MakeHolder<TOneKeyIndex>(key, std::move(labelValues), std::move(hitsByValue), std::move(allHits));
    }

private:
    roaring::Roaring LabelValues_;
    roaring::Roaring AllHits_;
    THashMap<TStringId, roaring::Roaring> HitsByValue_;
};


// Search result backed by a roaring bitmap. NextId() walks the ids in the
// bitmap's iteration order; ToVector() dumps them all at once.
class TBitmapSearchResult final: public ISearchResult {
public:
    explicit TBitmapSearchResult(roaring::Roaring roaring)
        : Roaring_(std::move(roaring))
        , Size_(Roaring_.cardinality())
        , It_(Roaring_.begin())
        , End_(Roaring_.end())
    {
    }

    // It_/End_ are obtained from this object's own Roaring_. CRoaring iterators keep
    // a pointer to their parent bitmap, so a memberwise copy would leave the copy
    // iterating the original object's bitmap. Copying is therefore disabled.
    TBitmapSearchResult(const TBitmapSearchResult&) = delete;
    TBitmapSearchResult& operator=(const TBitmapSearchResult&) = delete;

    size_t Size() const noexcept override {
        return Size_;
    }

    // Returns the next id, or npos once all ids have been handed out.
    ui32 NextId() override {
        if (It_ == End_) {
            return npos;
        }
        ui32 id = *It_;
        ++It_;
        return id;
    }

    TVector<ui32> ToVector() const override {
        TVector<ui32> result;
        result.resize(Size_);
        Roaring_.toUint32Array(result.data());
        return result;
    }

public:
    // Wraps `r` in a result, truncating it with Optimize(r, max) when it holds more
    // than `max` ids.
    static ISearchResultPtr FromRoaring(roaring::Roaring&& r, ui32 max) {
        if (r.cardinality() <= max) {
            // take over the caller's bitmap instead of copying it
            return new TBitmapSearchResult(std::move(r));
        }
        IBitmapPtr truncated = Optimize(r, max);
        return new TBitmapSearchResult(truncated->ToRoaring());
    }

    const roaring::Roaring& ToRoaring() const {
        return Roaring_;
    }

    // Adds all ids of `s` to `r`. Uses a bitmap union when `s` is itself a
    // TBitmapSearchResult, otherwise materializes `s` through ToVector().
    static void UnpackToCombine(roaring::Roaring& r, const ISearchResult& s) {
        const TBitmapSearchResult* bs = dynamic_cast<const TBitmapSearchResult*>(&s);
        if (bs) {
            r |= bs->ToRoaring();
        } else {
            TVector<ui32> vals = s.ToVector();
            r.addMany(vals.size(), vals.data());
        }
    }

private:
    const roaring::Roaring Roaring_;
    const size_t Size_;
    roaring::RoaringSetBitForwardIterator It_;
    roaring::RoaringSetBitForwardIterator End_;
};

// Read-only bitmap segment: one TOneKeyIndex per label key plus the set of all
// metric ids. Removals are applied lazily (see Removals_).
class TBitmapIndex: public IBitmapSegment {
protected:
    TAtomicSharedPtr<TStringPool> StringPool_;
    // Interned label key id -> inverted index for that key.
    THashMap<TStringId, TOneKeyIndex::TPtr> Keys_;
    // Ids of all metrics in the index that have not been removed; may be null.
    IBitmapPtr AllMetricIds_;
    // Ids dropped by Remove(). Remove() does not touch Keys_, so while this is
    // non-empty Search() additionally intersects its result with AllMetricIds_.
    roaring::Roaring Removals_;

protected:
    friend class TBitmapIndexMerger;

public:
    TBitmapIndex() = default;

    // Takes over the string pool, the key indexes and the metric ids of `builder`.
    // Removals are not transferred; `builder` is expected to have none.
    void AssignFrom(TBitmapIndex& builder) {
        StringPool_ = std::move(builder.StringPool_);
        Keys_ = std::move(builder.Keys_);
        AllMetricIds_ = std::move(builder.AllMetricIds_);
        Y_ASSERT(builder.Removals_.isEmpty());
    }

    TAtomicSharedPtr<TStringPool> GetSharedStringPool() const {
        return StringPool_;
    }

public:
    IBitmapPtr GetAllMetrics() const override {
        return AllMetricIds_;
    }

    size_t Size() const override {
        return AllMetricIds_ ? AllMetricIds_->Size() : 0;
    }

    // Removes those ids of `b` that are present in the index. Only AllMetricIds_ is
    // rewritten; the removed ids are remembered in Removals_ for Search().
    void Remove(const IBitmap& b) override {
        roaring::Roaring ids = b.ToRoaring();
        roaring::Roaring all = AllMetricIds_->ToRoaring();
        roaring::Roaring removed = ids & all;
        if (!removed.isEmpty()) {
            all -= removed;
            AllMetricIds_ = Optimize(all);
            Removals_ |= removed;
        }
    }

    // Counts the key table entries, AllMetricIds_ and Removals_.
    // NOTE(review): the bitmaps referenced from Keys_ are not included.
    size_t AllocatedBytes() const override {
        return Keys_.size() * (sizeof(TStringId) + sizeof(TOneKeyIndex::TPtr))
            + (AllMetricIds_ ? AllMetricIds_->MemBytes() : 0)
            + Removals_.getSizeInBytes();
    }

private:
    // return true iff the selector allows key to be unknown
    // it is OK if key is not present for selectors like {key=-} or {key!=*}
    // Please note that this magic does not apply to things like key!=a* or key!=*|-
    static bool AllowedUnknownKey(const TSelector& selector) {
        if (selector.Negative()) {
            return selector.Type() == EMatcherType::ANY;
        } else {
            return selector.Match(nullptr);
        }
    }

    // Ids of metrics whose value for `key` satisfies the (non-MULTI) matcher `m`.
    // Returns null for ABSENT: the caller computes the complement itself.
    IBitmapPtr Match(const TOneKeyIndex& key, const IMatcher& m) const {
        Y_ASSERT(m.Type() != EMatcherType::MULTI);

        const NIntern::TStringPool& stringPool = *StringPool_;
        switch (m.Type()) {
        case EMatcherType::EXACT: {
            TStringId id = stringPool.Find(m.Pattern());
            if (id == InvalidStringId)
                return IBitmap::Empty();
            auto it = key.HitsByValue.find(id);
            if (it == key.HitsByValue.end())
                return IBitmap::Empty();
            return it->second;
        }

        case EMatcherType::ANY: {
            return key.AllHits;
        }

        case EMatcherType::ABSENT: {
            return {}; // ids without the label will be added afterwards
        }

        case EMatcherType::GLOB:
        case EMatcherType::REGEX: {
            return OrMatch(key, m);
        }

        case EMatcherType::MULTI:
            return IBitmap::Empty();
        }

        // Not reachable for the enumerators above; guards against flowing off the
        // end of a non-void function (undefined behavior) for any other value.
        return IBitmap::Empty();
    }

    // find all label values that match the predicate, then combine their lists of hits through Or()
    IBitmapPtr OrMatch(const TOneKeyIndex& key, const IMatcher& predicate) const {
        TVector<IBitmapPtr> matched;
        const NIntern::TStringPool& stringPool = *StringPool_;

        key.LabelValues->Consume([&](IBitmap::TId labelId) {
            const TStringBuf labelValue = stringPool.Find(labelId);
            if (predicate.Match(labelValue)) {
                auto p = key.HitsByValue.find(labelId);
                Y_ASSERT(p != key.HitsByValue.end());
                matched.push_back(p->second);
            }
        });

        return Or(matched);
    }

    // Union of Match() over every alternative of a MULTI matcher; ABSENT alternatives
    // contribute nothing here (Match returns null for them).
    IBitmapPtr OrMatchMulti(const TOneKeyIndex& key, const IMultiMatcher& matchers) const {
        TVector<IBitmapPtr> matched;
        for (size_t i = 0; i < matchers.Size(); ++i) {
            const IMatcher& matcher = *matchers.Get(i);
            switch (matcher.Type()) {
            case EMatcherType::EXACT:
            case EMatcherType::ANY:
            case EMatcherType::ABSENT:
            case EMatcherType::GLOB:
            case EMatcherType::REGEX: {
                IBitmapPtr res = Match(key, matcher);
                if (res)
                    matched.push_back(res);
                break;
            }
            case EMatcherType::MULTI: {
                Y_VERIFY(0, "nested MULTI Matchers are forbidden");
                return IBitmap::Empty();
            }
            }
        }

        return Or(matched);
    }

    // All live metric ids, truncated to `max`. An index without any metric ids
    // yields an empty result (consistent with Size() treating null as empty).
    ISearchResultPtr AllIds(ui32 max) const {
        if (!AllMetricIds_) {
            return new TEmptySearchResult;
        }
        return TBitmapSearchResult::FromRoaring(AllMetricIds_->ToRoaring(), max);
    }


public:
    // Evaluates the conjunction of `selectors`. Every selector contributes bitmaps to
    // `inclusive` (ids that must be present) and/or `exclusive` (ids that must not be),
    // which Combine() then reduces to at most `max` ids.
    ISearchResultPtr Search(const TSelectors& selectors, ui32 max) const override {
        if (selectors.empty()) {
            return AllIds(max);
        }

        // fast check: return empty if a selector is exact and a key is unknown
        for (const auto& selector: selectors) {
            if (selector.IsExact()) {
                Y_ASSERT(!selector.Negative());
                TStringId keyId = StringPool_->Find(selector.Key());
                if (Keys_.find(keyId) == Keys_.end()) {
                    return new TEmptySearchResult;
                }
            }
        }

        TVector<IBitmapPtr> inclusive, exclusive;
        for (const auto& selector: selectors) {
            const TOneKeyIndex* key = nullptr;
            TStringId keyId = StringPool_->Find(selector.Key());
            if (auto it = Keys_.find(keyId); it != Keys_.end()) {
                key = it->second.Get();
            } else {
                //FIXME: With IStorage, this behaves incorrectly - AllowedUnknownKey should be checked across all segments
                if (!AllowedUnknownKey(selector)) {
                    return new TEmptySearchResult;
                } else {
                    // skip because the selector allows key to be unknown
                    continue;
                }
            }

            IBitmapPtr bitmap;
            const auto selectorType = selector.Type();
            if (selectorType != EMatcherType::MULTI) {
                const IMatcher& matcher = *selector.MatcherPtr();
                bitmap = Match(*key, matcher);
            } else {
                const IMultiMatcher* mm = selector.MultiMatcherPtr();
                bitmap = OrMatchMulti(*key, *mm);
            }

            if (selector.Negative()) {
                // {key!=*} selects exactly the metrics without the key, so it is the
                // only negative selector that does not require the key to exist
                const bool negatesAny = selectorType == EMatcherType::ANY;
                if (!negatesAny) {
                    // a negative selector mandates that the key exists
                    inclusive.push_back(key->AllHits);
                }
            } else if (selectorType != EMatcherType::EXACT) {
                if (selector.MatcherPtr()->MatchesAbsent()) {
                    // selector allows the key to be absent
                    if (selectorType == EMatcherType::ABSENT) {
                        // shortcut: will calculate the negation later
                        Y_ASSERT(!bitmap);
                        exclusive.push_back(key->AllHits);
                    } else if (bitmap.Get() == key->AllHits.Get()) {
                        // shortcut: this is -|* expression - even the same pointer - do nothing
                        bitmap.Drop();
                    } else {
                        // no shortcut is possible: build (metrics without the key) | bitmap
                        roaring::Roaring r = AllMetricIds_->ToRoaring();
                        key->AllHits->AndNot(r);
                        if (bitmap)
                            bitmap->Or(r);
                        bitmap = Optimize(std::move(r));
                    }
                }
            }

            if (bitmap) {
                auto& masks = selector.Negative() ? exclusive : inclusive;
                masks.push_back(bitmap);
            }
        }

        if (inclusive.empty()) {
            if (exclusive.empty()) {
                return AllIds(max);
            } else {
                inclusive.push_back(AllMetricIds_);
            }
        } else if (!Removals_.isEmpty()) {
            // per-key bitmaps still contain removed ids; filter them out
            inclusive.push_back(AllMetricIds_);
        }

        auto bitmap = Combine(inclusive, exclusive, max);
        return new TBitmapSearchResult(bitmap->ToRoaring());
    }
};

// Accumulates labelled metrics per key, then freezes them into a TBitmapIndex.
class TBitmapIndexBuilder : public TBitmapIndex, public ISearchIndexBuilder {
private:
    THashMap<TStringId, TOneKeyIndexBuilder> KeyBuilders_;

public:
    // Interns strings into `interner`, or into a freshly created pool when it is null.
    explicit TBitmapIndexBuilder(TAtomicSharedPtr<TStringPool> interner) {
        if (!interner) {
            interner = MakeAtomicShared<TStringPool>();
        }
        StringPool_ = std::move(interner);
    }

    // Registers every label of every target under its metric id.
    void Add(const TVector<std::pair<NMonitoring::ILabelsPtr, ui32>>& targets) override {
        for (const auto& [target, metricId]: targets) {
            for (size_t i = 0, size = target->Size(); i < size; i++) {
                const auto* label = target->Get(i);
                TStringId keyId = StringPool_->Intern(label->Name());
                TStringId valueId = StringPool_->Intern(label->Value());
                KeyBuilders_[keyId].Add(valueId, metricId);
            }
        }
    }

    // NOTE(review): Remove() dereferences AllMetricIds_, which this builder only
    // assigns in Finalize(); calling this before Finalize() dereferences a null pointer.
    void MarkRemoved(const TVector<ui32>& metricIds) override {
        roaring::Roaring r;
        r.addMany(metricIds.size(), metricIds.data());
        IBitmapPtr bitmap = Optimize(r);
        Remove(*bitmap);
    }

    // Finalizes every per-key builder and computes AllMetricIds_ as the union of all
    // keys' AllHits, then hands the data over to a new TBitmapIndex.
    ISearchIndexPtr Finalize() override {
        // Pending bitmaps are folded into one every MaxPendingBitmaps keys, presumably
        // to bound the fan-in of a single Or() call.
        constexpr size_t MaxPendingBitmaps = 16;

        TVector<IBitmapPtr> allMetrics;
        allMetrics.reserve(MaxPendingBitmaps);
        for (auto& [keyId, builder]: KeyBuilders_) {
            auto keyData = builder.Finalize(keyId);
            // Always keep the current key's hits. Folding happens afterwards, so no
            // key's hits are dropped on the iteration that triggers the merge.
            allMetrics.push_back(keyData->AllHits);
            if (allMetrics.size() >= MaxPendingBitmaps) {
                IBitmapPtr merged = Or(allMetrics);
                allMetrics.clear();
                allMetrics.push_back(std::move(merged));
            }
            Keys_[keyId] = std::move(keyData);
        }
        AllMetricIds_ = Or(allMetrics);
        return DoFinalize();
    }

private:
    virtual ISearchIndexPtr DoFinalize() {
        auto index = MakeIntrusive<TBitmapIndex>();
        index->AssignFrom(*this);
        return index;
    }
};

class TBitmapIndexMerger : public TBitmapIndex {
public:
    TBitmapIndexMerger(const TAtomicSharedPtr<TStringPool>& stringPool)
    {
        Y_ENSURE(stringPool);
        StringPool_ = stringPool;
    }

    void MergeFrom(const TBitmapIndex& source) {
        Y_ENSURE(source.StringPool_.Get() == StringPool_.Get(), "merged segments should use a shared interner");

        IBitmapPtr hits1Filter = !Removals_.isEmpty() ? AllMetricIds_ : IBitmapPtr{};
        IBitmapPtr hits2Filter = !source.Removals_.isEmpty() ? source.AllMetricIds_ : IBitmapPtr{};
        MergeKeys(Keys_, hits1Filter, source.Keys_, hits2Filter);
        if (AllMetricIds_) {
            AllMetricIds_ = Or({AllMetricIds_, source.AllMetricIds_});
        } else {
            AllMetricIds_ = source.AllMetricIds_;
        }

        Removals_ = {};
    }

    virtual TIntrusivePtr<IBitmapSegment> Finalize() {
        auto index = MakeIntrusive<TBitmapIndex>();
        index->AssignFrom(*this);
        return index;
    }

private:
    using TKeys = THashMap<TStringId, TOneKeyIndex::TPtr>;
    using TValues = THashMap<TStringId, IBitmapPtr>;

    static TOneKeyIndex::TPtr MergeKey(const TOneKeyIndex& key1, const IBitmapPtr& hits1Filter, const TOneKeyIndex& key2, const IBitmapPtr& hits2Filter) {
        Y_ASSERT(key1.LabelKey == key2.LabelKey);
        roaring::Roaring labelValues;
        TValues mergedMap;
        for (const auto& [l, hits1]: key1.HitsByValue) {
            IBitmapPtr hitsA = !hits1Filter ? hits1 : And({hits1, hits1Filter});

            IBitmapPtr merged;
            auto it = key2.HitsByValue.find(l);
            if (it == key2.HitsByValue.end()) {
                merged = std::move(hitsA);
            } else {
                const IBitmapPtr& hits2 = it->second;
                IBitmapPtr hitsB = !hits2Filter ? hits2 : And({hits2, hits2Filter});
                merged = Or({hitsA, hitsB});
            }
            if (merged->Size()) {
                mergedMap[l] = merged;
                labelValues.add(l);
            }
        }
        for (const auto& [l, hits2]: key2.HitsByValue) {
            auto it = mergedMap.find(l);
            if (it == mergedMap.end()) {
                IBitmapPtr hitsB = !hits2Filter ? hits2 : And({hits2, hits2Filter});
                if (hitsB->Size()) {
                    mergedMap[l] = hitsB;
                    labelValues.add(l);
                }
            }
        }

        IBitmapPtr allHitsA = !hits1Filter ? key1.AllHits : And({key1.AllHits, hits1Filter});
        IBitmapPtr allHitsB = !hits2Filter ? key2.AllHits : And({key2.AllHits, hits2Filter});
        IBitmapPtr allHits = Or({allHitsA, allHitsB});
        return MakeHolder<TOneKeyIndex>(key1.LabelKey, Optimize(labelValues), std::move(mergedMap), std::move(allHits));
    }

    static TOneKeyIndex::TPtr MergeKey(const TOneKeyIndex& key1, const IBitmapPtr& hits1Filter) {
        Y_ASSERT(hits1Filter);
        roaring::Roaring labelValues;
        TValues mergedMap;
        for (const auto& [l, hits1]: key1.HitsByValue) {
            IBitmapPtr hits = !hits1Filter ? hits1 : And({hits1, hits1Filter});
            if (hits->Size()) {
                mergedMap[l] = hits;
                labelValues.add(l);
            }
        }
        return MakeHolder<TOneKeyIndex>(key1.LabelKey, Optimize(labelValues), std::move(mergedMap), And({key1.AllHits, hits1Filter}));
    }

    static TOneKeyIndex::TPtr CloneKey(const TOneKeyIndex& key2) {
        TStringId labelKey = key2.LabelKey;
        IBitmapPtr labelValues = key2.LabelValues;
        THashMap<TStringId, IBitmapPtr> hitsByValue = key2.HitsByValue;
        IBitmapPtr allHits = key2.AllHits;
        return MakeHolder<TOneKeyIndex>(labelKey, std::move(labelValues), std::move(hitsByValue), std::move(allHits));
    }

    static void MergeKeys(TKeys& keys1, const IBitmapPtr& hits1Filter, const TKeys& keys2, const IBitmapPtr& hits2Filter) {
        for (auto&& [k, idx]: keys1) {
            auto it = keys2.find(k);
            if (it == keys2.end()) {
                if (!hits1Filter) {
                    continue;
                } else {
                    idx = MergeKey(*idx, hits1Filter);
                }
            } else {
                TOneKeyIndex& idx2 = *it->second;
                idx = MergeKey(*idx, hits1Filter, idx2, hits2Filter);
            }
        }

        for (auto&& [k, idx]: keys2) {
            auto it = keys1.find(k);
            if (it == keys1.end()) {
                TOneKeyIndex::TPtr clone = !hits2Filter ? CloneKey(*idx) : MergeKey(*idx, hits2Filter);
                keys1.emplace(k, std::move(clone));
            }
        }
    }
};


} // namespace

// Creates a bitmap index builder that interns into `interner`; the builder
// allocates its own string pool when `interner` is null.
THolder<ISearchIndexBuilder> CreateBitmapIndexBuilder(const TAtomicSharedPtr<TStringPool>& interner) {
    THolder<ISearchIndexBuilder> builder = MakeHolder<TBitmapIndexBuilder>(interner);
    return builder;
}

// Views `object` as a bitmap segment; the result is null when `object` is not one.
TIntrusivePtr<IBitmapSegment> CastToSegment(const TIntrusivePtr<ISearchIndex>& object) {
    IBitmapSegment* segment = dynamic_cast<IBitmapSegment*>(object.Get());
    return TIntrusivePtr<IBitmapSegment>(segment);
}

// Merges non-empty `segments` (all TBitmapIndex instances sharing one string pool)
// into a single segment. Throws (Y_ENSURE) on an empty input or a foreign segment type.
TIntrusivePtr<IBitmapSegment> MergeSegments(const TVector<IBitmapSegment::TPtr>& segments) {
    Y_ENSURE(!segments.empty());
    THolder<TBitmapIndexMerger> merger;

    for (const auto& pSegment: segments) {
        const TBitmapIndex* impl = dynamic_cast<const TBitmapIndex*>(pSegment.Get());
        Y_ENSURE(impl, "MergeSegments expects TBitmapIndex args");

        if (!merger) {
            // the first segment's pool is used; MergeFrom() checks the others match it
            merger = MakeHolder<TBitmapIndexMerger>(impl->GetSharedStringPool());
        }
        merger->MergeFrom(*impl);
    }
    return merger->Finalize();
}

// Unions per-segment results. Empty results are ignored. When there is only one
// non-empty result, or the first one already has `max` ids or more, it is returned
// as is. Otherwise the ids are merged until `max` is reached and then truncated to `max`.
ISearchResultPtr CombineSegmentResults(const TVector<ISearchResultPtr>& segmentResults, ui32 max) {
    const size_t count = segmentResults.size();
    size_t pos = 0;

    while (pos < count && segmentResults[pos]->Size() == 0) {
        ++pos;
    }
    if (pos == count) {
        return new TEmptySearchResult;
    }
    const ISearchResultPtr& first = segmentResults[pos];

    // look for a second non-empty result
    ++pos;
    while (pos < count && segmentResults[pos]->Size() == 0) {
        ++pos;
    }
    if (pos == count || first->Size() >= max) {
        return first;
    }

    // more than one non-empty result
    roaring::Roaring combined;
    TBitmapSearchResult::UnpackToCombine(combined, *first);
    for (; pos < count; ++pos) {
        const ISearchResultPtr& part = segmentResults[pos];
        if (part->Size() == 0) {
            continue;
        }
        TBitmapSearchResult::UnpackToCombine(combined, *part);
        if (combined.cardinality() >= max) {
            break;
        }
    }
    return TBitmapSearchResult::FromRoaring(std::move(combined), max);
}

// Re-packs `segmentResult` into a fresh bitmap result holding at most `limit` ids.
ISearchResultPtr Truncate(const ISearchResultPtr& segmentResult, ui32 limit) {
    roaring::Roaring ids;
    TBitmapSearchResult::UnpackToCombine(ids, *segmentResult);
    return TBitmapSearchResult::FromRoaring(std::move(ids), limit);
}
} // namespace NSolomon::NSearch
