#pragma once

#include "settings.h"
#include <common/botpeer.h>
#include <common/types.h>
#include <yplatform/algorithm/leaky_map.h>

namespace botserver::message_processor {

class otp_rate_limiter
{
public:
    otp_rate_limiter(
        optional<rate_limit_settings> botpeer_limit,
        optional<rate_limit_settings> uid_limit)
        : botpeer_limit(botpeer_limit)
        , uid_limit(uid_limit)
        , botpeer_counters(
              botpeer_limit ? botpeer_limit->attempts : 0u,
              botpeer_limit ? to_ms(botpeer_limit->period) : 0u)
        , uid_counters(
              uid_limit ? uid_limit->attempts : 0u,
              uid_limit ? to_ms(uid_limit->period) : 0u)
    {
    }

    bool acquire(botpeer botpeer, string uid)
    {
        lock_guard<struct mutex> lock(mutex);
        if (botpeer_limit_exceeded(botpeer) || uid_limit_exceeded(uid))
        {
            return false;
        }
        increase_botpeer_counter(botpeer);
        increase_uid_counter(uid);
        return true;
    }

private:
    uint64_t to_ms(duration d)
    {
        return duration_cast<milliseconds>(d).count();
    }

    bool botpeer_limit_exceeded(botpeer botpeer)
    {
        if (!botpeer_limit) return false;
        return botpeer_counters.get(botpeer) >= botpeer_limit->attempts;
    }

    bool uid_limit_exceeded(string uid)
    {
        if (!uid_limit) return false;
        return uid_counters.get(uid) >= uid_limit->attempts;
    }

    void increase_botpeer_counter(botpeer botpeer)
    {
        if (!botpeer_limit) return;
        botpeer_counters.add(botpeer, 1u);
    }

    void increase_uid_counter(string uid)
    {
        if (!uid_limit) return;
        uid_counters.add(uid, 1u);
    }

    optional<rate_limit_settings> botpeer_limit;
    optional<rate_limit_settings> uid_limit;
    yplatform::leaky_map<botpeer, uint64_t> botpeer_counters;
    yplatform::leaky_map<string, uint64_t> uid_counters;
    mutex mutex;
};

// Convenience alias for a shared_ptr to otp_rate_limiter.
using otp_rate_limiter_ptr = shared_ptr<otp_rate_limiter>;

}