#include <maps/wikimap/mapspro/libs/assessment/include/rate_limiter.h>
#include <maps/wikimap/mapspro/libs/assessment/impl/magic_strings.h>
#include <maps/wikimap/mapspro/libs/common/include/yandex/maps/wiki/common/string_utils.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/enum_io/include/enum_io.h>

namespace maps::wiki::assessment {

namespace {

// Config section names are formed as SECTION_PREFIX + <suffix>, where the
// suffix is either TOTAL or the string form of an Entity::Domain value.
const std::string TOTAL{"total"};
const std::string SECTION_PREFIX{"assessment-"};

// Counts how many grades the user `uid` has given within each of the
// requested time windows, measured back from the database's NOW().
//
// One SELECT is generated per interval and all of them are combined with
// UNION ALL into a single query; every result row carries the interval length
// in seconds and the matching COUNT. When `entityDomain` is set, the grade
// table is joined with the unit table and only grades of that entity domain
// are counted.
//
// Returns an empty activity without querying the database when
// `timeIntervals` is empty.
social::UserActivity loadUserActivity(
    pqxx::transaction_base& txn,
    TUid uid,
    const std::vector<std::chrono::seconds>& timeIntervals,
    std::optional<Entity::Domain> entityDomain)
{
    if (timeIntervals.empty()) {
        return {};
    }

    // Domain filtering needs both an extra JOIN and an extra WHERE condition;
    // both fragments stay empty when no domain was requested.
    std::string unitJoin;
    std::string domainCondition;
    if (entityDomain) {
        unitJoin = " JOIN " + sql::table::UNIT + " USING (" + sql::col::UNIT_ID + ")";
        domainCondition = sql::col::ENTITY_DOMAIN + "=" +
            txn.quote(std::string(toString(*entityDomain))) + " AND ";
    }

    // Everything between the selected interval literal and the time condition
    // is identical for all sub-queries.
    const std::string sharedMiddle =
        ", COUNT(1) FROM " + sql::table::GRADE + unitJoin +
        " WHERE " + sql::col::GRADED_BY + "=" + std::to_string(uid) + " AND " +
        domainCondition;

    std::vector<std::string> subQueries;
    for (const auto& window : timeIntervals) {
        const auto seconds = std::to_string(window.count());
        // The interval length is selected as a literal so that each result
        // row can be matched back to the window it belongs to.
        subQueries.push_back(
            "SELECT " + seconds + sharedMiddle +
            sql::col::GRADED_AT + " > (NOW() - INTERVAL '" + seconds + " sec')");
    }

    const auto rows = txn.exec(common::join(subQueries, " UNION ALL "));

    social::UserActivity activity;
    for (const auto& row : rows) {
        const std::chrono::seconds window(row[0].as<size_t>());
        const auto gradeCount = row[1].as<size_t>();
        activity.emplace(window, gradeCount);
    }
    return activity;
}

// Shared implementation behind the public RateLimiter checks.
//
// Loads the user's grading activity for the time intervals listed in `limits`
// (optionally restricted to `entityDomain`) and hands it to
// social::findLimitExceeded, whose result is returned unchanged.
//
// A zero `uid` is rejected via REQUIRE before any query is issued.
std::optional<std::chrono::seconds> checkLimitExceededInternal(
    pqxx::transaction_base& txn,
    TUid uid,
    const social::RateLimiterConfigLoader::Limits& limits,
    std::optional<Entity::Domain> entityDomain)
{
    REQUIRE(uid, "Invalid uid");

    social::UserActivity recentActivity =
        loadUserActivity(txn, uid, limits.timeIntervals, entityDomain);

    return social::findLimitExceeded(limits, recentActivity);
}

} // namespace

// Reads rate limits from the given config node.
//
// The total limits come from the section named SECTION_PREFIX + TOTAL; limits
// for each Entity::Domain value come from SECTION_PREFIX + <domain string>.
// Domains whose loaded limits contain no time intervals are not stored.
//
// A null node is accepted: nothing is loaded and the members keep their
// default-constructed state.
RateLimiter::RateLimiter(const maps::xml3::Node& node)
{
    if (!node.isNull()) {
        social::RateLimiterConfigLoader configLoader(node);
        totalLimits_ = configLoader.loadLimits(SECTION_PREFIX + TOTAL);

        for (auto domain : enum_io::enumerateValues<Entity::Domain>()) {
            auto domainLimits = configLoader.loadLimits(
                SECTION_PREFIX + std::string(toString(domain)));
            if (domainLimits.timeIntervals.empty()) {
                // Nothing configured for this domain.
                continue;
            }
            entityDomainLimits_.emplace(domain, std::move(domainLimits));
        }
    }
}

// Checks the user's grading activity against the total limits, i.e. without
// restricting the counted grades to any particular entity domain.
//
// Returns whatever social::findLimitExceeded reports for that activity
// (presumably the interval whose limit was exceeded, or std::nullopt —
// confirm against the social library).
std::optional<std::chrono::seconds> RateLimiter::checkTotalLimitExceeded(
    pqxx::transaction_base& txn,
    TUid uid) const
{
    // An empty optional means "no entity domain filter".
    const std::optional<Entity::Domain> noDomainFilter;
    return checkLimitExceededInternal(txn, uid, totalLimits_, noDomainFilter);
}

// Checks the user's grading activity within `entityDomain` against the limits
// stored for that domain.
//
// Returns std::nullopt straight away, without querying the database, when no
// limits are stored for the domain; otherwise returns the result of the
// shared limit check restricted to that domain.
std::optional<std::chrono::seconds> RateLimiter::checkLimitExceeded(
    pqxx::transaction_base& txn,
    TUid uid,
    Entity::Domain entityDomain) const
{
    if (const auto found = entityDomainLimits_.find(entityDomain);
        found != entityDomainLimits_.end())
    {
        return checkLimitExceededInternal(txn, uid, found->second, entityDomain);
    }
    return std::nullopt;
}

} // namespace maps::wiki::assessment
