#include "bitmap.h"

#include <util/generic/vector.h>
#include <util/generic/ptr.h>

#include <contrib/libs/croaring/cpp/roaring.hh>

namespace NSolomon::NSearch {
namespace {

// Bitmap that contains no ids at all. It holds no data members of its own.
class TEmptyBitmap final: public IBitmap {
private:
    // Materializes this bitmap as a default-constructed (empty) roaring bitmap.
    roaring::Roaring ToRoaring() const noexcept override {
        return roaring::Roaring();
    }

    // There is nothing to visit, so the consumer is never invoked.
    void Consume(const TIdConsumer&) const override {
    }

    // Union with the empty set leaves the accumulator untouched.
    void Or(roaring::Roaring&) const noexcept override {
    }

    // Intersection with the empty set is empty: the accumulator is
    // overwritten with a fresh empty roaring bitmap.
    void And(roaring::Roaring& r) const noexcept override {
        r = roaring::Roaring();
    }

    // Subtracting the empty set leaves the accumulator untouched.
    void AndNot(roaring::Roaring&) const noexcept override {
    }

    // Number of ids stored: always zero.
    size_t Size() const noexcept override {
        return 0;
    }

    // Memory footprint: just the object itself.
    size_t MemBytes() const noexcept override {
        return sizeof(TEmptyBitmap);
    }
};

class TSmallBitmap final: public IBitmap {
public:
    explicit TSmallBitmap(TVector<TId> ids)
        : Ids_(std::move(ids))
    {
    }

private:
    size_t MemBytes() const noexcept override {
        return sizeof(*this) + Ids_.capacity() * sizeof(TId);
    }

    size_t Size() const noexcept override {
        return Ids_.size();
    }

    void Consume(const TIdConsumer& fn) const override {
        for (TId id: Ids_) {
            fn(id);
        }
    }

    void Or(roaring::Roaring& r) const noexcept override {
        r.addMany(Ids_.size(), Ids_.data());
    }

    void And(roaring::Roaring& r) const noexcept override {
        roaring::Roaring mask(Ids_.size(), Ids_.data());
        r &= mask;
    }

    void AndNot(roaring::Roaring& r) const noexcept override {
        for (TId id: Ids_) {
            r.remove(id);
        }
    }

    roaring::Roaring ToRoaring() const noexcept override {
        roaring::Roaring result;
        result.addMany(Ids_.size(), Ids_.data());
        return result;
    }

private:
    const std::vector<TId> Ids_;
};

class TRoaringBitmap final: public IBitmap {
public:
    explicit TRoaringBitmap(roaring::Roaring r)
        : R_(std::move(r))
    {
    }

private:
    size_t MemBytes() const noexcept override {
        return sizeof(*this) + R_.getSizeInBytes();
    }

    size_t Size() const noexcept override {
        return R_.cardinality();
    }

    void Consume(const TIdConsumer& fn) const override {
        for (TId id: R_) {
            fn(id);
        }
    }

    void Or(roaring::Roaring& r) const noexcept override {
        r |= R_;
    }

    void And(roaring::Roaring& r) const noexcept override {
        r &= R_;
    }

    void AndNot(roaring::Roaring& r) const noexcept override {
        r -= R_;
    }

    roaring::Roaring ToRoaring() const noexcept override {
        return R_;
    }

private:
    const roaring::Roaring R_;
};


// Shrinks `result` in place so that it keeps only its `limit` smallest ids.
//
// select() looks up the id with zero-based rank `limit`; when no such id
// exists (the bitmap has at most `limit` ids) the bitmap is left untouched.
// Otherwise everything outside the half-open range [0, that id) is dropped.
void TruncateR(roaring::Roaring& result, ui32 limit) {
    uint32_t firstDropped = 0;
    if (!result.select(limit, &firstDropped)) {
        return;
    }

    roaring::Roaring keep;
    keep.addRange(0, firstDropped);
    result &= keep;
}
} // namespace

// Creates a new bitmap instance that contains no ids.
IBitmapPtr IBitmap::Empty() {
    IBitmapPtr empty = new TEmptyBitmap();
    return empty;
}

// Picks the representation for the given set of ids based on its cardinality:
//   - no ids             -> TEmptyBitmap
//   - up to 16 ids       -> TSmallBitmap holding the ids in a vector
//   - anything larger    -> TRoaringBitmap over a run-optimized, shrunk copy
// The input bitmap itself is never modified.
IBitmapPtr Optimize(const roaring::Roaring& r) {
    constexpr size_t smallBitmapMaxSize = 16;

    const size_t cardinality = r.cardinality();
    if (cardinality == 0) {
        return new TEmptyBitmap();
    }

    if (cardinality > smallBitmapMaxSize) {
        roaring::Roaring compact(r);
        compact.runOptimize();
        compact.shrinkToFit();
        return new TRoaringBitmap(std::move(compact));
    }

    // Few enough ids: dump them into a flat array.
    TVector<IBitmap::TId> ids;
    ids.resize(cardinality);
    r.toUint32Array(ids.data());
    return new TSmallBitmap(std::move(ids));
}

// Same as Optimize(r), but the result holds at most `max` ids: when `r` is
// larger, a copy of it is truncated with TruncateR() before optimizing.
IBitmapPtr Optimize(const roaring::Roaring& r, ui32 max) {
    if (r.cardinality() > max) {
        roaring::Roaring truncated(r);
        TruncateR(truncated, max);
        return Optimize(truncated);
    }

    return Optimize(r);
}

// Computes the union of the given bitmaps.
// An empty list yields an empty bitmap; a single bitmap is returned as is
// (no copy is made); otherwise the union is accumulated and optimized.
IBitmapPtr Or(const TVector<IBitmapPtr>& bitmaps) {
    switch (bitmaps.size()) {
        case 0:
            return new TEmptyBitmap();
        case 1:
            return bitmaps[0];
        default:
            break;
    }

    roaring::Roaring united;
    for (size_t i = 0; i < bitmaps.size(); ++i) {
        bitmaps[i]->Or(united);
    }
    return Optimize(std::move(united));
}

// Computes the intersection of the given bitmaps.
// An empty list yields an empty bitmap; a single bitmap is returned as is
// (no copy is made); otherwise the first bitmap is materialized and every
// following one is intersected into it before optimizing.
IBitmapPtr And(const TVector<IBitmapPtr>& bitmaps) {
    const size_t count = bitmaps.size();
    if (count == 0) {
        return new TEmptyBitmap();
    }
    if (count == 1) {
        return bitmaps[0];
    }

    auto it = bitmaps.begin();
    roaring::Roaring intersection = (*it)->ToRoaring();
    for (++it; it != bitmaps.end(); ++it) {
        (*it)->And(intersection);
    }

    return Optimize(std::move(intersection));
}

// Intersects all `inclusive` bitmaps, subtracts every `exclusive` bitmap from
// the intersection and caps the result at `limit` ids.
//
// Exclusions without any inclusion are not supported here (asserted); with
// no inclusive bitmaps the result is an empty bitmap.
IBitmapPtr Combine(
    const TVector<IBitmapPtr>& inclusive,
    const TVector<IBitmapPtr>& exclusive,
    ui32 limit)
{
    Y_ASSERT(!inclusive.empty() || exclusive.empty());

    // the "only exclusion" case should be handled by the caller
    if (inclusive.empty()) {
        return new TEmptyBitmap;
    }

    IBitmapPtr combined = (inclusive.size() == 1) ? inclusive[0] : And(inclusive);

    // Nothing to subtract and already within the limit: reuse the bitmap.
    const bool reusable = exclusive.empty() && combined->Size() <= limit;
    if (reusable) {
        return combined;
    }

    auto remaining = combined->ToRoaring();
    for (size_t i = 0; i < exclusive.size(); ++i) {
        exclusive[i]->AndNot(remaining);
    }

    if (remaining.cardinality() > limit) {
        TruncateR(remaining, limit);
    }
    return Optimize(std::move(remaining));
}

// TODO(yrum): Invert() is not used
//
// Materializes `bitmap`, flips it over the range passed to flip() as
// (0, maxCount) and returns the optimized result. Per the CRoaring API the
// range is half-open, so ids at or above maxCount are presumably left as they
// were -- confirm if this function is ever put to use.
IBitmapPtr Invert(const IBitmap& bitmap, ui32 maxCount) {
    roaring::Roaring flipped(bitmap.ToRoaring());
    flipped.flip(0, maxCount);
    return Optimize(std::move(flipped));
}


} // namespace NSolomon::NSearch
