#include "repository.h"

#include <library/cpp/consistent_hashing/consistent_hashing.h>

#include <util/digest/city.h>
#include <util/system/guard.h>

#include <cmath>
#include <memory>

namespace NRateSrv::NStorage {

// Constructs the repository, initializing Parts from partCount
// (presumably the number of independent parts that keys are distributed
// over -- confirm against the Parts declaration in repository.h).
TRepository::TRepository(size_t partCount)
    : Parts(partCount)
{}

// Visits the stored counters part by part, invoking callback(key, counter)
// for every entry.
//
// While a part is being visited its Lock is held in shared mode, and the
// per-counter lock of an entry is held for the duration of its callback.
// The traversal ends as soon as callback returns false.
void TRepository::Read(TReadCallback callback) {
    for (auto& shard : Parts) {
        std::shared_lock shardGuard(shard.Lock);

        for (const auto& [name, slot] : shard.Values) {
            auto slotGuard = Guard(slot.second);
            const bool keepGoing = callback(name, slot.first);
            if (!keepGoing) {
                // Leaving the function releases the counter lock first and
                // the part lock second, same as unwinding both loops.
                return;
            }
        }
    }
}

// Groups the given handles by the part their key maps to and applies them
// via UpdateParts.
//
// createNonExistent selects the update policy: EUpdatePolicy::Read when true
// (handles for missing keys are later re-applied under the Write policy, which
// inserts them), EUpdatePolicy::ReadWithSkipNonExistent otherwise (missing
// keys are not created).
void TRepository::Update(TUpdateHandles handles, bool createNonExistent) {
    TSeparatedUpdateHandles handlesByPart;
    for (auto& item : handles) {
        // The part index is computed before the handle (and its Key) is moved.
        const size_t target = CalcPartNumber(item.Key);
        handlesByPart[target].push_back(std::move(item));
    }

    EUpdatePolicy policy = EUpdatePolicy::ReadWithSkipNonExistent;
    if (createNonExistent) {
        policy = EUpdatePolicy::Read;
    }

    UpdateParts(std::move(handlesByPart), policy);
}

// Groups the given keys by the part they map to and hands them, together
// with callback, to EraseFromParts. There, an existing entry is removed
// only if callback returns true for its counter.
void TRepository::Erase(TEraseKeys keys, TEraseCallback callback) {
    TSeparatedEraseKeys keysByPart;
    for (auto& item : keys) {
        // The part index is computed before the key is moved out.
        const size_t target = CalcPartNumber(item);
        keysByPart[target].push_back(std::move(item));
    }

    EraseFromParts(std::move(keysByPart), std::move(callback));
}

// Returns the sum of the per-part Size counters.
//
// Each counter is read atomically (acquire) without taking the part locks,
// so the total is not a snapshot taken at a single point in time.
size_t TRepository::Size() {
    size_t total = 0;
    for (auto& shard : Parts) {
        const auto shardSize = shard.Size.load(std::memory_order_acquire);
        total += shardSize;
    }
    return total;
}

size_t TRepository::CalcPartNumber(const std::string& key) {
    ui64 hashVal = CityHash64(key.c_str(), key.size());
    size_t partNum = ConsistentHashing(hashVal, Parts.size());
    return partNum;
}

// Applies every per-part batch of handles to its part, one part after
// another, using the same policy for all of them.
void TRepository::UpdateParts(TSeparatedUpdateHandles separatedHandles, EUpdatePolicy policy) {
    for (auto& [index, batch] : separatedHandles) {
        UpdatePart(Parts[index], std::move(batch), policy);
    }
}

// Applies handles to a single part.
//
// Locking: EUpdatePolicy::Write takes part.Lock exclusively; any other policy
// takes it in shared mode. An existing counter is additionally guarded by its
// own per-counter lock while its callback runs.
//
// For a key that is already present, the counter is replaced with
// Callback(&counter). For a missing key the behaviour depends on policy:
//   - Write: a new entry is inserted with the value of Callback(nullptr);
//   - Read: the handle is set aside and, once the shared lock is released,
//     re-applied with the Write policy (the key is looked up again there, so
//     an entry created in the meantime is updated rather than inserted twice);
//   - ReadWithSkipNonExistent: Callback(nullptr) is invoked, its result is
//     discarded and nothing is created.
void TRepository::UpdatePart(TPart& part, TUpdateHandles handles, EUpdatePolicy policy) {
    std::shared_lock readGuard(part.Lock, std::defer_lock);
    std::unique_lock writeGuard(part.Lock, std::defer_lock);

    if (policy == EUpdatePolicy::Write) {
        writeGuard.lock();
    } else {
        readGuard.lock();
    }

    TUpdateHandles nonExistentCounters;

    for (auto& handle : handles) {
        auto counterIt = part.Values.find(handle.Key);
        if (counterIt != part.Values.end()) {
            auto& [counter, lock] = counterIt->second;
            auto counterGuard = Guard(lock);
            counter = handle.Callback(&counter);
            continue;
        }

        switch (policy) {
            case EUpdatePolicy::Write:
                part.Values.emplace(
                    std::move(handle.Key),
                    std::make_pair(handle.Callback(nullptr), TAdaptiveLock())
                );
                // Account for the entry right away: if a later callback
                // throws, Size still matches what was inserted into Values.
                part.Size.fetch_add(1, std::memory_order_acq_rel);
                break;
            case EUpdatePolicy::Read:
                nonExistentCounters.push_back(std::move(handle));
                break;
            case EUpdatePolicy::ReadWithSkipNonExistent:
                handle.Callback(nullptr);
                break;
        }
    }

    if (!nonExistentCounters.empty()) {
        // Handles are only deferred under the Read policy, so the shared lock
        // is the one held here. It must be dropped before the exclusive lock
        // is taken by the nested call.
        readGuard.unlock();
        UpdatePart(part, std::move(nonExistentCounters), EUpdatePolicy::Write);
    }
}

// Removes entries from the parts, one part at a time.
//
// For each part, part.Lock is held exclusively while its keys are processed.
// Keys that are not present are skipped. For a present key, callback is
// invoked with the counter and the entry is erased only if it returns true.
void TRepository::EraseFromParts(TSeparatedEraseKeys separatedKeys, TEraseCallback callback) {
    for (auto& [partNum, keys] : separatedKeys) {
        auto& part = Parts[partNum];
        std::unique_lock partGuard(part.Lock);

        for (const auto& key : keys) {
            auto counterIt = part.Values.find(key);
            if (counterIt == part.Values.end()) {
                continue;
            }
            if (callback(counterIt->second.first)) {
                part.Values.erase(counterIt);
                // Account for the removal right away: if a later callback
                // throws, Size still matches what is left in Values.
                part.Size.fetch_sub(1, std::memory_order_acq_rel);
            }
        }
    }
}

} // namespace NRateSrv::NStorage
