#include "cluster.h"

#include <maps/wikimap/mapspro/libs/acl/include/exception.h>
#include <yandex/maps/wiki/common/pg_utils.h>
#include <yandex/maps/wiki/common/pg_advisory_lock_ids.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <maps/libs/common/include/exception.h>

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <vector>

using namespace std::string_literals;

namespace maps::wiki::acl {

namespace {
// Table that holds pending cluster / user / group update requests.
const std::string QUEUE_TABLE{"acl.cluster_update_queue"};

std::set<ID>
getRolesPermissions(const std::set<ID>& roleIds, Transaction& work)
{
    if (roleIds.empty()) {
        return {};
    }
    auto rows =
        work.exec(
            "SELECT acl.array_merge_agg(p.leaf_ids) "
            "FROM acl.role_permission r, acl.permission p "
            "WHERE p.id=r.permission_id AND " +
            common::whereClause("r.role_id", roleIds));
    if (rows.empty() || rows[0][0].is_null()) {
        return {};
    }
    ASSERT(rows.size() == 1);
    auto permissionIds = common::parseSqlArray<ID>(rows[0][0].c_str());
    return std::set<ID>(permissionIds.begin(), permissionIds.end());
}

std::set<ID>
getGroupsPermissions(const std::set<ID>& groupIds, Transaction& work)
{
    if (groupIds.empty()) {
        return {};
    }
    auto rows =
        work.exec("SELECT acl.array_merge_agg(p.leaf_ids) "
            "FROM acl.policy po, acl.role_permission r, acl.permission p "
            "WHERE po.role_id = r.role_id AND p.id=r.permission_id AND " +
            common::whereClause("po.agent_id", groupIds));
    if (rows.empty() || rows[0][0].is_null()) {
        return {};
    }
    ASSERT(rows.size() == 1);
    auto permissionIds = common::parseSqlArray<ID>(rows[0][0].c_str());
    return std::set<ID>(permissionIds.begin(), permissionIds.end());
}

void lockForUpdate(Transaction& work)
{
    work.exec(
        "SELECT pg_advisory_xact_lock(" +
        std::to_string(static_cast<int64_t>(common::AdvisoryLockIds::ACL_CLUSTERS)) +
        ")");
}

} // namespace

ClusterManager::ClusterManager(Transaction& work)
    : work_(work)
{
}

ID
ClusterManager::requestCluster(
        const std::set<ID>& groupIds,
        const std::set<ID>& roleIds)
{
    if (groupIds.empty() && roleIds.empty()) {
        return false;
    }
    std::string groupIdsArray = "NULL";
    if (!groupIds.empty()) {
        groupIdsArray = "ARRAY[" + common::join(groupIds, ',') + "]::bigint[]";
    }
    std::string roleIdsArray = "NULL";
    if (!roleIds.empty()) {
        roleIdsArray = "ARRAY[" + common::join(roleIds, ',') + "]::bigint[]";
    }
    std::string query =
        "SELECT cluster_id FROM acl.cluster WHERE "
        "role_ids " + (roleIds.empty() ? " IS NULL " : ("=" + roleIdsArray)) + " AND "
        " group_ids " + (groupIds.empty() ? " IS NULL " : ("=" + groupIdsArray)) + ";";
    auto rows = work_.exec(query);
    if (!rows.empty()) {
        return rows[0][0].as<ID>();
    }
    auto rolesPermissionsLeafs = getRolesPermissions(roleIds, work_);
    auto groupsPermissionsLeafs = getGroupsPermissions(groupIds, work_);
    if (rolesPermissionsLeafs.empty() && groupsPermissionsLeafs.empty()) {
        return 0;
    }
    std::set<ID> allLeafs;
    allLeafs.insert(rolesPermissionsLeafs.begin(), rolesPermissionsLeafs.end());
    allLeafs.insert(groupsPermissionsLeafs.begin(), groupsPermissionsLeafs.end());
    std::string insert =
        "INSERT INTO acl.cluster (role_ids, group_ids, permission_leaf_ids)"
        " VALUES (" + roleIdsArray +
            "," + groupIdsArray +
            ", ARRAY[" + common::join(allLeafs, ",") +
            "]) RETURNING cluster_id";
    auto insertedRows = work_.exec(insert);
    return insertedRows[0][0].as<ID>();
}

void
ClusterManager::assignCluster(ID agentId, ID clusterId)
{
    work_.exec("DELETE FROM acl.agent_cluster WHERE agent_id=" +
        std::to_string(agentId));
    if (clusterId) {
        work_.exec(
            "INSERT INTO acl.agent_cluster (agent_id, cluster_id) "
            "VALUES (" + std::to_string(agentId) + ", " + std::to_string(clusterId) +
                ");");
    }
}

void
ClusterManager::removeClustersWithoutRefs()
{
    lockForUpdate(work_);
    const std::string FROM_CONDITION =
        " FROM acl.cluster c WHERE "
        " NOT EXISTS "
        "   (SELECT * FROM acl.agent_cluster ac WHERE"
        "    ac.cluster_id=c.cluster_id) ";
    auto rows = work_.exec(
        "SELECT cluster_id " +
        FROM_CONDITION +
        " ORDER BY cluster_id ASC FOR UPDATE;");
    if (rows.empty()) {
        return;
    }
    std::vector<ID> clusterIdsToRemove;
    clusterIdsToRemove.reserve(rows.size());
    for (const auto& row : rows) {
        clusterIdsToRemove.emplace_back(row[0].as<ID>());
    }
    work_.exec(
        "DELETE FROM acl.cluster WHERE " +
        common::whereClause("cluster_id", clusterIdsToRemove));
}

std::set<ID>
ClusterManager::keyGroupAffectedClusters(ID groupId) const
{
    auto rows =
        work_.exec(
            "SELECT cluster_id FROM acl.cluster WHERE "
            "group_ids @> ARRAY[" + std::to_string(groupId) + "]::bigint[]");
    if (rows.empty()) {
        return {};
    }
    std::set<ID> clusterIds;
    for (const auto& row : rows) {
        ASSERT(!row[0].is_null());
        clusterIds.insert(row[0].as<ID>());
    }
    return clusterIds;
}

std::set<ID>
ClusterManager::keyRoleAffectedClusters(ID roleId) const
{
    auto rows =
        work_.exec(
            "SELECT cluster_id FROM acl.cluster WHERE "
            "role_ids @> ARRAY[" + std::to_string(roleId) + "]::bigint[]");
    std::set<ID> clusterIds;
    for (const auto& row : rows) {
        ASSERT(!row[0].is_null());
        clusterIds.insert(row[0].as<ID>());
    }
    ACLGateway gw(work_);
    try {
        const auto policies = gw.role(roleId).policies();
        for (const auto& policy : policies) {
            auto byGroup = keyGroupAffectedClusters(policy.agentId());
            clusterIds.insert(byGroup.begin(), byGroup.end());
        }
    } catch (const RoleNotExists& ex) {
    }
    return clusterIds;
}

namespace {
std::set<ID>
selectAllIds(const std::string& idColumn, const std::string& table, Transaction& work)
{
    auto rows = work.exec("SELECT " + idColumn + " FROM " + table);
    std::set<ID> result;
    for (const auto& row : rows) {
        result.insert(row[0].as<ID>());
    }
    return result;
}

std::set<ID>
allRolesIds(Transaction& work)
{
    return selectAllIds("id", "acl.role", work);
}

std::set<ID>
allGroupIds(Transaction& work)
{
    return selectAllIds("id", "acl.group", work);
}

std::set<ID>
removeNotExisting(
    const std::vector<ID>& ids,
    const std::set<ID>& dbExisting,
    const std::string& table,
    Transaction& work)
{
    if (ids.empty()) {
        return {};
    }
    std::set<ID> result;
    if (!dbExisting.empty()) {
        for (const auto& id : ids) {
            if (dbExisting.count(id)) {
                result.insert(id);
            }
        }
        return result;
    }
    auto exisitingRolesRows =
            work.exec(
                "SELECT id FROM " + table +
                " WHERE " + common::whereClause("id", ids));
    for (const auto& row : exisitingRolesRows) {
        result.insert(row[0].as<ID>());
    }
    return result;
}

void updateCluster(
        ID clusterId,
        const std::set<ID>& dbExistingGroups,
        const std::set<ID>& dbExistingRoles,
        Transaction& work)
{
    auto clusterRows =
        work.exec(
            "SELECT role_ids, group_ids FROM acl.cluster WHERE cluster_id=" +
            std::to_string(clusterId));
    if (clusterRows.empty()) {
        return;
    }
    ASSERT(clusterRows.size() == 1);
    const auto& clusterRow = clusterRows[0];
    const auto& roleField = clusterRow["role_ids"];
    std::set<ID> existingRoleIds;
    if (!roleField.is_null()) {
        auto roleIds = common::parseSqlArray<ID>(roleField.c_str());
        existingRoleIds = removeNotExisting(roleIds, dbExistingRoles, "acl.role", work);
    }
    const auto& groupField = clusterRow["group_ids"];
    std::set<ID> existingGroupIds;
    if (!groupField.is_null()) {
        auto groupIds = common::parseSqlArray<ID>(groupField.c_str());
        existingGroupIds = removeNotExisting(groupIds, dbExistingGroups, "acl.group", work);
    }
    auto rolesPermissionsLeafs = getRolesPermissions(existingRoleIds, work);
    auto groupsPermissionsLeafs = getGroupsPermissions(existingGroupIds, work);
    const auto stringClusterId = std::to_string(clusterId);
    if (rolesPermissionsLeafs.empty() && groupsPermissionsLeafs.empty()) {
        work.exec(
            "DELETE FROM acl.agent_cluster WHERE cluster_id=" + stringClusterId +
            ";DELETE FROM acl.cluster WHERE cluster_id=" + stringClusterId);
        return;
    }
    std::set<ID> allLeafs;
    allLeafs.insert(rolesPermissionsLeafs.begin(), rolesPermissionsLeafs.end());
    allLeafs.insert(groupsPermissionsLeafs.begin(), groupsPermissionsLeafs.end());
    work.exec(
        "UPDATE acl.cluster SET "
        " role_ids = " +
            (existingRoleIds.empty()
                ? std::string("NULL,")
                : "ARRAY[" + common::join(existingRoleIds, ",") + "]::bigint[],") +
        " group_ids = " +
            (existingGroupIds.empty()
                ? std::string("NULL,")
                : "ARRAY[" + common::join(existingGroupIds, ",") + "]::bigint[],") +
        " permission_leaf_ids = ARRAY[" + common::join(allLeafs, ",") + "]::bigint[]"
        " WHERE cluster_id = " + stringClusterId);
}

void
updateClustersLocal(
    const std::set<ID>& clusterIds,
    Transaction& work)
{
    const auto dbAllGroupsIds = allGroupIds(work);
    const auto dbAllRolesIds = allRolesIds(work);
    for (const auto& id : clusterIds) {
        updateCluster(id, dbAllGroupsIds, dbAllRolesIds, work);
    }
}

} // namespace

void
ClusterManager::updateClusters(const std::set<ID>& clusterIds)
{
    if (clusterIds.empty()) {
        return;
    }
    lockForUpdate(work_);
    updateClustersLocal(clusterIds, work_);
}

void
ClusterManager::enqueueAll()
{
    removeClustersWithoutRefs();
    enqueueClusters(selectAllIds("cluster_id", "acl.cluster", work_));
}

void
ClusterManager::releaseClusterByAgentId(ID agentId)
{
    work_.exec(
        "DELETE FROM acl.agent_cluster WHERE agent_id=" +
        std::to_string(agentId));
}

namespace {
struct ClusterKey {
    std::set<maps::wiki::acl::ID> groupIds;
    std::set<maps::wiki::acl::ID> roleIds;

    bool operator <(const ClusterKey& other) const
    {
        return std::tie(groupIds, roleIds) < std::tie(other.groupIds, other.roleIds);
    }
};

ClusterKey clusterKey(const User& user)
{
    ClusterKey result;
    const auto groups = user.groups();
    for (const auto& group : groups) {
        result.groupIds.insert(group.id());
    }
    const auto policies = user.policies();
    for (const auto& policy : policies) {
        result.roleIds.insert(policy.roleId());
    }
    return result;
}

ClusterKey clusterKey(const Group& group)
{
    ClusterKey result;
    const auto policies = group.policies();
    for (const auto& policy : policies) {
        result.roleIds.insert(policy.roleId());
    }
    return result;
}

} // namespace

void
ClusterManager::updateUserCluster(ID userId)
{
    ACLGateway gw(work_);
    try {
        auto targetUser = gw.userById(userId);
        const auto clusterKeyData = clusterKey(targetUser);
        assignCluster(userId, requestCluster(clusterKeyData.groupIds, clusterKeyData.roleIds));
    } catch (const UserNotExists&) {
        assignCluster(userId, 0);
    }
}

namespace {
std::map<ID, ClusterKey>
getAgentToClusterKeyMap(Transaction& work, bool onlyMissing)
{
    std::string query = "SELECT a.id,"
            " '{'||array_to_string(array_agg(group_id), ',')||'}', "
            " '{'||array_to_string(array_agg(role_id), ',')||'}' "
            "FROM acl.agent a LEFT JOIN acl.policy po ON a.id = po.agent_id "
                "LEFT JOIN acl.group_user gu ON a.id=gu.user_id";
    if (onlyMissing) {
        query +=
            " WHERE NOT EXISTS "
            " (SELECT cluster_id FROM acl.agent_cluster ac WHERE ac.agent_id = a.id) ";
    }
    query += " GROUP BY 1";
    auto allAgentRows =
        work.exec(query);
    std::map<ID, ClusterKey> result;
    for (const auto& agentRow : allAgentRows) {
        ClusterKey key;
        const auto agentId = agentRow[0].as<ID>();
        if (!agentRow[1].is_null()) {
            const auto agentGroups = common::parseSqlArray<ID>(agentRow[1].c_str());
            key.groupIds.insert(agentGroups.begin(), agentGroups.end());
        }
        if (!agentRow[2].is_null()) {
            const auto agentRoles = common::parseSqlArray<ID>(agentRow[2].c_str());
            key.roleIds.insert(agentRoles.begin(), agentRoles.end());
        }
        result.emplace(agentId, key);
    }
    return result;
}
};

void
ClusterManager::updateAgentsClusters(bool onlyMissing)
{
    auto agentToClusterKey = getAgentToClusterKeyMap(work_, onlyMissing);
    std::map<ClusterKey, ID> createdClusters;
    std::map<ID, ID> agentToCluster;
    for (const auto& [agentId, key] : agentToClusterKey) {
        auto it = createdClusters.find(key);
        ID clusterId = 0;
        if (it == createdClusters.end()) {
            clusterId = requestCluster(key.groupIds, key.roleIds);
            createdClusters.emplace(key, clusterId);
        } else {
            clusterId = it->second;
        }
        agentToCluster.emplace(agentId, clusterId);
    }
    if (agentToCluster.empty()) {
        return;
    }
    std::string insertQuery =
        "INSERT INTO acl.agent_cluster (agent_id, cluster_id)"
        " VALUES ";
    insertQuery += common::join(agentToCluster,
        [&](const auto& agentClusterPair) {
            return "(" + std::to_string(agentClusterPair.first) +
                "," + std::to_string(agentClusterPair.second) + ")";
        },
        ",");
    if (!onlyMissing) {
        insertQuery = "DELETE FROM acl.agent_cluster;" + insertQuery;
    }
    work_.exec(insertQuery);
}

void
ClusterManager::updateGroupCluster(ID groupId)
{
    try {
        ACLGateway gw(work_);
        auto targetGroup = gw.group(groupId);
        const auto clusterKeyData = clusterKey(targetGroup);
        assignCluster(groupId, requestCluster(clusterKeyData.groupIds, clusterKeyData.roleIds));
    } catch (const GroupNotExists& ex) {
        assignCluster(groupId, 0);
    }
}

void
ClusterManager::enqueueGroupCluster(ID groupId)
{
    work_.exec(
        "INSERT INTO " + QUEUE_TABLE + " (group_id)"
        " VALUES (" + std::to_string(groupId) + ")");
}

void
ClusterManager::enqueueUserCluster(ID userId)
{
    work_.exec(
        "INSERT INTO " + QUEUE_TABLE + " (user_id)"
        " VALUES (" + std::to_string(userId) + ")");
}

void
ClusterManager::enqueueClusters(const std::set<ID>& clusterIds)
{
    if (clusterIds.empty()) {
        return;
    }
    work_.exec(
        "INSERT INTO " + QUEUE_TABLE + " (cluster_id)"
        " VALUES (" + common::join(clusterIds, "),(") + ")");
}

size_t
ClusterManager::processClustersUpdateQueue()
{
    auto rows = work_.exec(
        "SELECT id, cluster_id, user_id, group_id FROM " + QUEUE_TABLE +
        " ORDER BY enqueued_at ASC");
    std::set<ID> clusterIds;
    std::set<ID> userIds;
    std::set<ID> groupIds;
    std::set<ID> recordIds;
    for (const auto& row : rows) {
        recordIds.insert(row["id"].as<ID>());
        const auto& cluster = row["cluster_id"];
        if (!cluster.is_null()) {
            clusterIds.insert(cluster.as<ID>());
        }
        const auto& group = row["group_id"];
        if (!group.is_null()) {
            groupIds.insert(group.as<ID>());
        }
        const auto& user = row["user_id"];
        if (!user.is_null()) {
            userIds.insert(user.as<ID>());
        }
    };
    if (!clusterIds.empty()) {
        updateClustersLocal(clusterIds, work_);
    }
    if (!groupIds.empty()) {
        for (const auto groupId : groupIds) {
            updateGroupCluster(groupId);
        }
    }
    if (!userIds.empty()) {
        for (const auto userId : userIds) {
            updateUserCluster(userId);
        }
    }
    if (recordIds.empty()) {
        removeClustersWithoutRefs();
    } else {
        work_.exec("DELETE FROM " + QUEUE_TABLE + " WHERE " +
            common::whereClause("id", recordIds));
    }
    return rows.size();
}

size_t
ClusterManager::clustersUpdateQueueSize() const
{
    auto rows = work_.exec(
        "SELECT COUNT(*) FROM " + QUEUE_TABLE);
    return rows[0][0].as<size_t>();
}

} // namespace maps::wiki::acl
