#include "str_pool_hash_map.h"
#include "arena.h"
#include "str_ref_count.h"

#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>
#include <util/generic/vector.h>
#include <util/str_stl.h>

#include <memory>

namespace NSolomon::NIntern {

// A TStringId is split into two bit fields: the low IdBitCount bits hold the
// id of a string inside its group, the remaining high bits hold the group index.
constexpr ui32 IdBitCount = 28;
// Mask that selects the low IdBitCount bits of a TStringId.
constexpr ui32 IdMsk = (1 << IdBitCount) - 1;
static_assert(sizeof(TStringId) == 4);
// Number of high bits of a TStringId that are left for the group index.
constexpr ui32 GroupBitCount = 32 - IdBitCount;
// Group g is created with a block size of 1 << (g + StartGroup) bytes.
constexpr ui32 StartGroup = 3; // skip 1, 2, 4
// Number of groups (block arenas) created by the pool.
constexpr ui32 GroupCount = (1 << GroupBitCount) - StartGroup;

namespace {
/*
 * Contains an arena with a fixed allocation (block) size.
 *
 * Every string handled by this class is addressed by a TStringId; the code
 * locates the block of id `i` at arena offset `i * BlockSize_` and treats the
 * first ui32 of the block as a reference counter (see AddRef/DecRef).
 *
 * Id 0 is taken by the empty string created in the constructor, so ids handed
 * out by NewId start at 1. Ids released through Remove are kept in FreeId_ and
 * are reused before new ones are issued.
 */
class TBlockArena : private TMoveOnly {
private:
    TArena Arena_;
    // Next id that has never been handed out.
    TStringId LastId_ = 1;
    // Ids released by Remove() and available for reuse (used as a stack).
    TVector<TStringId> FreeId_;
    const size_t BlockSize_;

    /*
     * Picks the arena page size for the given block size: a block that is at
     * least as large as the page gets a page of its own size, otherwise the
     * page is rounded down to a whole number of blocks.
     */
    static size_t CalcArenaSize(size_t PageSize, size_t blockSize) {
        Y_VERIFY(blockSize > 0, "Block cannot be empty");
        if (blockSize >= PageSize) {
            return blockSize;
        } else {
            return PageSize - PageSize % blockSize;
        }
    }
public:
    /*
     * Creates the arena and stores an empty string as its first block, which
     * is what lookups of unknown ids resolve to (see Find).
     */
    TBlockArena(size_t PageSize, size_t blockSize)
        : Arena_(CalcArenaSize(PageSize, blockSize), TDefaultAllocator::Instance())
        , BlockSize_(blockSize)
    {
        TStringRefCount::CreateFromString(Arena_, "", BlockSize_);
    }

    TBlockArena(TBlockArena&& oth) = default;
    // NOTE: BlockSize_ is const, so this defaulted move assignment is
    // implicitly defined as deleted; only move construction is usable.
    TBlockArena& operator=(TBlockArena&& oth) = default;

    /*
     * Stores a copy of strBuf and returns its id together with a view of the
     * stored copy. A previously released id is reused when one is available.
     */
    std::pair<TStringId, TStringBuf> NewId(TStringBuf strBuf) {
        if (!FreeId_.empty()) {
            auto id = FreeId_.back();
            FreeId_.pop_back();
            auto strCopy = TStringRefCount::UpdateOffset(Arena_, id * BlockSize_, strBuf);
            return {id, strCopy};
        }
        auto strCopy = TStringRefCount::CreateFromString(Arena_, strBuf, BlockSize_);
        return {LastId_++, strCopy};
    }

    // Increments the ui32 counter at the start of the block of stringId.
    void AddRef(TStringId stringId) {
        (*Arena_.Ptr<ui32*>(stringId * BlockSize_))++;
    }

    // Decrements the ui32 counter at the start of the block and returns the new value.
    ui32 DecRef(TStringId stringId) {
        return --(*Arena_.Ptr<ui32*>(stringId * BlockSize_));
    }

    /*
     * Returns the string stored for id. Ids that have never been issued are
     * redirected to InvalidStringId.
     */
    TStringBuf Find(TStringId id) const {
        if (id >= LastId_) {
            id = InvalidStringId;
        }
        return TStringRefCount::CreateFromOffset(Arena_, id * BlockSize_);
    }

    // Number of ids currently in use (issued ids minus released ones; id 0 is not counted).
    size_t Size() const {
        return LastId_ - FreeId_.size() - 1u;
    }

    // Overwrites the string stored in the block of id with strBuf.
    void Update(TStringId id, TStringBuf strBuf) {
        TStringRefCount::UpdateOffset(Arena_, id * BlockSize_, strBuf);
    }

    // Releases id for reuse and replaces its stored string with an empty one.
    void Remove(TStringId id) {
        FreeId_.push_back(id);
        Update(id, "");
    }

    /*
     * Memory footprint estimate: this object, the arena and the free-id list.
     * The free list stores TStringId elements, so its capacity is scaled by
     * sizeof(TStringId).
     */
    size_t AllocatedBytes() const {
        return sizeof(*this) + Arena_.AllocatedBytes() + FreeId_.capacity() * sizeof(TStringId);
    }
};
} // namespace

/*
 * Implementation of THashMapStringPool: a hash map from string to id plus a
 * set of block arenas (one per block-size group) that own the string copies.
 */
struct THashMapStringPool::TImpl {
private:
    // Stores a copy of strBuf in the arena of its size group and returns the
    // combined id (group index in the high bits) with a view of the copy.
    std::pair<TStringId, TStringBuf> NewId(TStringBuf strBuf) noexcept;
    // Page size passed to every block arena on construction.
    static constexpr size_t PageSize = 4096;
public:
    // Creates GroupCount block arenas.
    TImpl();

    // Returns the id of strBuf, storing it first if it is not known yet, and
    // increments its reference counter.
    TStringId Intern(TStringBuf strBuf);

    // Returns the string stored under id.
    TStringBuf Find(TStringId id) const noexcept;

    // Returns the id of strBuf, or InvalidStringId if it is not in the pool.
    TStringId Find(TStringBuf strBuf) const noexcept;

    // Sum of the sizes of all block arenas.
    size_t Size() const noexcept;

    // Estimate of the memory used by this object, the arenas and the hash map.
    size_t AllocatedBytes() const noexcept;

    // Decrements the reference counter of id and drops the string once it hits zero.
    void Remove(TStringId id);

    // Maps a string length to the index of the group whose arena stores it.
    ui32 GetGroup(size_t strSize);

private:
    // Keys are the views returned by NewId, i.e. they point at the copies
    // stored in Strings_.
    absl::flat_hash_map<TStringBuf, TStringId, THash<TStringBuf>> StringToId_;
    // One block arena per group, indexed by the group bits of a TStringId.
    TVector<TBlockArena> Strings_;
    // Upper bound for string length plus TStringRefCount::NonStringSize(),
    // checked in GetGroup.
    constexpr static size_t MaxBlockSize = static_cast<size_t>(Max<ui16>()) + 1;
};

/*
 * Creates one block arena per group. Block sizes are consecutive powers of
 * two, starting at 1 << StartGroup.
 */
THashMapStringPool::TImpl::TImpl() {
    for (size_t shift = StartGroup; shift < StartGroup + GroupCount; ++shift) {
        Strings_.emplace_back(PageSize, 1u << shift);
    }
}

/*
 * Returns the index of the group whose block size is the smallest power of
 * two that fits a string of strSize bytes together with the per-string
 * overhead (TStringRefCount::NonStringSize()).
 *
 * Aborts if the string does not fit into the largest block. The limit is the
 * smaller of MaxBlockSize and the block size of the last group actually
 * created (1 << (StartGroup + GroupCount - 1)); checking MaxBlockSize alone
 * would let through sizes whose group index lies past the end of Strings_.
 */
ui32 THashMapStringPool::TImpl::GetGroup(size_t strSize) {
    constexpr size_t smallestBlock = size_t{1} << StartGroup;
    constexpr size_t largestBlock = size_t{1} << (StartGroup + GroupCount - 1);
    constexpr size_t limit = MaxBlockSize < largestBlock ? MaxBlockSize : largestBlock;

    const size_t blockSize = strSize + TStringRefCount::NonStringSize();
    Y_VERIFY(blockSize <= limit,
             "Cannot store such long string. Str length: %zu, max length %zu",
             strSize, limit - TStringRefCount::NonStringSize());

    // Everything that fits into the smallest block goes to group 0. Handling
    // it here also keeps the unsigned arithmetic below from wrapping around
    // for very small sizes.
    if (blockSize <= smallestBlock) {
        return 0;
    }
    return MostSignificantBit(blockSize - 1) - StartGroup + 1;
}

/*
 * Stores a copy of strBuf in the arena of its size group.
 *
 * Returns the pool-wide id, which is the group index in the high bits and the
 * id inside the group in the low IdBitCount bits, together with a view of the
 * stored copy.
 *
 * Aborts if the id issued by the group no longer fits into IdBitCount bits:
 * such an id would spill into the group bits and alias a string of another
 * group.
 */
std::pair<TStringId, TStringBuf> THashMapStringPool::TImpl::NewId(TStringBuf strBuf) noexcept {
    const ui32 groupId = GetGroup(strBuf.size());
    auto [id, str] = Strings_[groupId].NewId(strBuf);
    Y_VERIFY(id <= IdMsk, "String id %u does not fit into %u bits", id, IdBitCount);
    return {(groupId << IdBitCount) | id, str};
}

/*
 * Returns the id of strBuf and increments its reference counter.
 *
 * A string that is not in the pool yet is copied into the arena of its size
 * group first; the hash map key is the view of that copy, not strBuf itself,
 * so the caller's buffer does not have to outlive the call.
 */
NIntern::TStringId THashMapStringPool::TImpl::Intern(TStringBuf strBuf) {
    TStringId newId;
    if (auto it = StringToId_.find(strBuf); it == StringToId_.end()) {
        auto [id, strCopy] = NewId(strBuf);
        const bool success = StringToId_.try_emplace(strCopy, id).second;
        // The view carries an explicit length, so print it with %.*s rather
        // than relying on a terminating NUL; the id is printed with %u.
        Y_VERIFY(success, "Can't emplace into hash map: str=%.*s, id=%u",
                 static_cast<int>(strCopy.size()), strCopy.data(), id);
        newId = id;
    } else {
        newId = it->second;
    }
    Strings_[newId >> IdBitCount].AddRef(newId & IdMsk);
    return newId;
}

/*
 * Returns the string stored under id.
 *
 * The high bits of id select the group and the low IdBitCount bits the id
 * inside it. An id whose group index does not correspond to any arena yields
 * an empty buffer instead of indexing Strings_ out of range.
 */
TStringBuf THashMapStringPool::TImpl::Find(TStringId id) const noexcept {
    const size_t groupId = id >> IdBitCount;
    if (groupId >= Strings_.size()) {
        return {};
    }
    return Strings_[groupId].Find(id & IdMsk);
}

/*
 * Looks strBuf up in the hash map and returns its id, or InvalidStringId when
 * the string is not in the pool. The reference counter is not touched.
 */
TStringId THashMapStringPool::TImpl::Find(TStringBuf strBuf) const noexcept {
    const auto found = StringToId_.find(strBuf);
    if (found != StringToId_.end()) {
        return found->second;
    }
    return InvalidStringId;
}

/*
 * Returns the total reported by all block arenas, i.e. the sum of
 * TBlockArena::Size() over every group.
 */
size_t THashMapStringPool::TImpl::Size() const noexcept {
    size_t total = 0;
    for (size_t groupId = 0; groupId < Strings_.size(); ++groupId) {
        total += Strings_[groupId].Size();
    }
    return total;
}

/*
 * Estimates the memory used by the pool: this object, every block arena and
 * the hash map, the latter approximated as capacity times the size of a key
 * plus the size of a value.
 */
size_t THashMapStringPool::TImpl::AllocatedBytes() const noexcept {
    const size_t mapBytes = StringToId_.capacity() * (sizeof(TStringBuf) + sizeof(TStringId));

    size_t arenaBytes = 0;
    for (size_t groupId = 0; groupId < Strings_.size(); ++groupId) {
        arenaBytes += Strings_[groupId].AllocatedBytes();
    }

    return sizeof(*this) + arenaBytes + mapBytes;
}

/*
 * Drops one reference to id. When the counter reaches zero the string is
 * erased from the hash map and its id is released in the owning arena.
 *
 * Aborts when the low bits of id equal InvalidStringId.
 */
void THashMapStringPool::TImpl::Remove(TStringId id) {
    const auto localId = id & IdMsk;
    Y_VERIFY(localId != InvalidStringId, "Can't remove InvalidStringId");

    auto& arena = Strings_[id >> IdBitCount];
    if (arena.DecRef(localId) != 0) {
        return;
    }

    // The map key is a view of the stored string, so it has to be erased
    // before the arena overwrites the block.
    auto key = arena.Find(localId);
    StringToId_.erase(key);
    arena.Remove(localId);
}

// Creates an empty pool backed by a freshly allocated TImpl.
THashMapStringPool::THashMapStringPool()
    : Impl_(new TImpl())
{
}

// Defaulted here, after TImpl has been defined in this file.
// NOTE(review): presumably required because Impl_ owns a TImpl that is only
// forward-declared in the header — confirm against str_pool_hash_map.h.
THashMapStringPool::~THashMapStringPool() = default;

// Forwards to TImpl::Intern and returns the id it produced.
TStringId THashMapStringPool::Intern(TStringBuf strBuf) {
    return Impl_->Intern(strBuf);
}

// Forwards to TImpl::Find(TStringId): returns the string stored under id.
TStringBuf THashMapStringPool::Find(TStringId id) const noexcept {
    return Impl_->Find(id);
}

// Forwards to TImpl::Size.
size_t THashMapStringPool::Size() const noexcept {
    return Impl_->Size();
}

// Forwards to TImpl::AllocatedBytes.
size_t THashMapStringPool::AllocatedBytes() const noexcept {
    return Impl_->AllocatedBytes();
}

// Forwards to TImpl::Remove, which drops one reference to id.
void THashMapStringPool::Remove(TStringId id) {
    Impl_->Remove(id);
}

// Forwards to TImpl::Find(TStringBuf): returns the id of str, or
// InvalidStringId when the string is not in the pool.
TStringId THashMapStringPool::Find(TStringBuf str) const noexcept {
    return Impl_->Find(str);
}

} // namespace NSolomon::NIntern
