#include "bucket_rate_limiter.h"

#include <util/system/yassert.h>

#include <vector>

using namespace quasar;

BucketRateLimiter::BucketRateLimiter(uint32_t bucketsCount,
                                     uint32_t limit,
                                     std::unordered_set<std::string> filteredEvents)
    : limit_(limit)
{
    Y_VERIFY(bucketsCount > 0);
    Y_VERIFY(limit > 0);

    for (const auto& event : filteredEvents) {
        counters_.emplace(event, EventCounter(bucketsCount));
    }
}

// Installs (or replaces) the callback that rotateBuckets() invokes once per event whose total
// exceeded the limit at rotation time. Passing an empty std::function disables reporting.
void BucketRateLimiter::setOverflowOnBucketRotatingCb(std::function<void(const OverflowInfo&)> overflowOnBucketRotatingCb) {
    // rotateBuckets() copies the callback under the same mutex, so the swap must be guarded too.
    std::lock_guard lock(mutex_);
    overflowOnBucketRotatingCb_ = std::move(overflowOnBucketRotatingCb);
}

void BucketRateLimiter::setRotatingBucketsPeriod(std::chrono::milliseconds rotatingBucketsPeriod) {
    rotatingBucketsPeriod_ = rotatingBucketsPeriod;
}

void BucketRateLimiter::start() {
    std::scoped_lock guard(mutex_);

    bucketRotator_ = std::make_unique<PeriodicExecutor>(
        std::bind(&BucketRateLimiter::rotateBuckets, this),
        rotatingBucketsPeriod_,
        PeriodicExecutor::PeriodicType::SLEEP_FIRST);
}

void BucketRateLimiter::stop() {
    std::scoped_lock guard(mutex_);

    bucketRotator_.reset();
}

// Tears down the rotation executor before the members go away. The executor's callback is bound
// to `this` (see start()), so it must not outlive the object.
BucketRateLimiter::~BucketRateLimiter() {
    this->stop();
}

// Records one occurrence of `event` in its current (newest) bucket.
// Returns OVERFLOWED when the event's total across all buckets now exceeds the limit.
// Returns NOT_OVERFLOWED otherwise, including for events that were not registered in the
// constructor.
BucketRateLimiter::OverflowStatus BucketRateLimiter::addEvent(const std::string& event) {
    std::scoped_lock lock(mutex_);

    const auto found = counters_.find(event);
    if (found != counters_.end()) {
        auto& eventCounter = found->second;
        eventCounter.addEvent();

        if (eventCounter.total() > limit_) {
            return BucketRateLimiter::OverflowStatus::OVERFLOWED;
        }
    }

    return BucketRateLimiter::OverflowStatus::NOT_OVERFLOWED;
}

void BucketRateLimiter::rotateBuckets() {
    std::vector<OverflowInfo> overflowInfos;
    decltype(overflowOnBucketRotatingCb_) overflowOnBucketRotatingCbCopy;

    {
        std::scoped_lock guard(mutex_);

        overflowOnBucketRotatingCbCopy = overflowOnBucketRotatingCb_;

        for (auto& [event, counter] : counters_) {
            if (overflowOnBucketRotatingCbCopy &&
                counter.total() > limit_) {
                overflowInfos.push_back({
                    .event = event,
                    .limit = limit_,
                    .eventCount = counter.total(),
                });
            }
            counter.rotateBuckets();
        }
    }

    if (overflowOnBucketRotatingCbCopy) {
        for (const auto& overflowInfo : overflowInfos) {
            overflowOnBucketRotatingCb_(overflowInfo);
        }
    }
}

// Creates a counter with `bucketsCount` empty buckets. The front bucket is the current one.
// The caller (BucketRateLimiter's constructor) guarantees bucketsCount > 0, so front()/back() are
// always valid.
BucketRateLimiter::EventCounter::EventCounter(uint32_t bucketsCount)
    : bucketCounters_(bucketsCount, 0)
{
    // Establish the invariant total_ == sum(bucketCounters_) here instead of relying on a
    // default member initializer in the header. Assigning in the body (not the init list)
    // avoids depending on member declaration order.
    total_ = 0;
}

// Counts one event in the current (front) bucket and keeps the running total in sync.
void BucketRateLimiter::EventCounter::addEvent() {
    total_ += 1;
    bucketCounters_.front() += 1;
}

// Slides the window by one bucket. The oldest (back) bucket is discarded and its count is
// subtracted from the total. A fresh zero bucket becomes the current (front) one.
// The number of buckets is unchanged.
void BucketRateLimiter::EventCounter::rotateBuckets() {
    const auto expiredCount = bucketCounters_.back();
    bucketCounters_.pop_back();
    bucketCounters_.push_front(0);
    total_ -= expiredCount;
}

uint32_t BucketRateLimiter::EventCounter::total() const {
    return total_;
}

// Memberwise equality over event name, limit and observed event count.
bool BucketRateLimiter::OverflowInfo::operator==(const BucketRateLimiter::OverflowInfo& rhs) const {
    // Compare the cheap numeric fields first; the string comparison only runs if they match.
    if (limit != rhs.limit || eventCount != rhs.eventCount) {
        return false;
    }
    return event == rhs.event;
}
