#include "policy.h"

#include <maps/wikimap/mapspro/libs/acl/include/aclgateway.h>
#include <maps/libs/common/include/profiletimer.h>
#include <maps/libs/introspection/include/comparison.h>
#include <maps/libs/introspection/include/hashing.h>

#include <algorithm>
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace maps::wiki::aclsrv {

namespace {

// Visitor applied to every generated sequence; returning false stops the
// enumeration early.
using Callback = std::function<bool(const Sequence&)>;
// Time budget of the exhaustive search in calculatePermittingPolicies, in seconds.
constexpr size_t SEARCH_LIMIT{10};

// Returns true when the two sorted sets share at least one element.
// Both sets are walked in lockstep (merge-style): each step advances one
// iterator, so the cost is O(|first| + |second|).
bool intersects(const std::set<acl::ID>& first, const std::set<acl::ID>& second)
{
    for (auto lhs = first.begin(), rhs = second.begin();
         lhs != first.end() && rhs != second.end();)
    {
        if (*lhs == *rhs) {
            return true;
        }
        if (*lhs < *rhs) {
            ++lhs;
        } else {
            ++rhs;
        }
    }
    return false;
}

bool
generateBinarySequencesR(
    size_t maxLength,
    size_t onesCount,
    const Sequence& curSequence,
    const Callback& callback)
{
    if (curSequence.size() == maxLength) {
        return callback(curSequence);
    }
    if (onesCount > 0) {
        Sequence next = curSequence;
        next.push_back(true);
        if (!generateBinarySequencesR(maxLength, onesCount - 1, next, callback)) {
            return false;
        }
    }
    if (onesCount + curSequence.size() < maxLength) {
        Sequence next = curSequence;
        next.push_back(false);
        if (!generateBinarySequencesR(maxLength, onesCount, next, callback)) {
            return false;
        }
    }
    return true;
}

// Key identifying a policy by the ids of its role, AOI and group.
// introspect() exposes the fields to the introspection library, whose
// operator== (imported below) and introspection::Hasher allow the key to be
// stored in a std::unordered_set.
struct PolicyWithGroupIds
{
    acl::ID roleId;
    acl::ID aoiId;
    acl::ID groupId;

    // Field tuple consumed by introspection-based comparison and hashing.
    auto introspect() const
    {
        return std::tie(roleId, aoiId, groupId);
    }
};

// Makes the introspection-provided operator== visible for PolicyWithGroupIds
// (needed by the std::unordered_set in getUniquePolicies).
using introspection::operator==;

std::vector<PolicyWithGroup>
getUniquePolicies(
    const std::vector<PolicyWithGroup>& policiesWithGroups)
{
    std::vector<PolicyWithGroup> result;
    std::unordered_set<PolicyWithGroupIds, introspection::Hasher> uniqueIds;

    for (const auto& policyWithGroup : policiesWithGroups) {
        PolicyWithGroupIds ids {
            policyWithGroup.role ? policyWithGroup.role->id() : 0,
            policyWithGroup.aoi ? policyWithGroup.aoi->id() : 0,
            policyWithGroup.group ? policyWithGroup.group->id() : 0
        };
        if (!uniqueIds.count(ids)) {
            result.push_back(policyWithGroup);
            uniqueIds.insert(ids);
        }
    }
    return result;
}

// Collects into one sorted set the leaf ids of every permission referenced by
// permissionPaths, each permission being resolved through gw.
std::set<acl::ID>
getPermissionsLeafIds(
    acl::ACLGateway gw,
    const std::vector<acl::SubjectPath>& permissionPaths)
{
    std::set<acl::ID> allLeafIds;
    for (const auto& permissionPath : permissionPaths) {
        auto leafIds = gw.permission(permissionPath).leafIds();
        allLeafIds.insert(std::begin(leafIds), std::end(leafIds));
    }
    return allLeafIds;
}

// Feeds callback every binary sequence of length maxLength that contains at
// least one `true`, grouped by increasing number of `true` values
// (1, 2, ..., maxLength). Callers rely on this order to see smaller subsets
// first. Returns false if the callback aborted the enumeration by returning
// false, true once every sequence has been visited.
bool
generateBinarySequences(size_t maxLength, const Callback& callback)
{
    for (size_t onesCount = 1; onesCount <= maxLength; ++onesCount) {
        if (!generateBinarySequencesR(maxLength, onesCount, {}, callback)) {
            return false;
        }
    }
    return true;
}

} // namespace

std::vector<Sequence>
generateBinarySequences(size_t maxLength)
{
    std::vector<Sequence> result;
    generateBinarySequences(
        maxLength,
        [&](const Sequence& sequence) { result.push_back(sequence); return true; }
    );
    return result;
}

// Returns true when every position that is set in rightSequence is also set
// in leftSequence, i.e. the `true` positions of the right sequence form a
// subset of those of the left one. Both sequences must have the same length
// (checked with ASSERT).
bool
leftBinSeqContainsRight(
    const Sequence& leftSequence,
    const Sequence& rightSequence)
{
    ASSERT(leftSequence.size() == rightSequence.size());
    const size_t length = leftSequence.size();
    size_t pos = 0;
    while (pos < length) {
        if (rightSequence[pos] && !leftSequence[pos]) {
            return false;
        }
        ++pos;
    }
    return true;
}

// Selects the user policies whose roles are needed to grant every leaf
// permission under permissionPaths. Returns {} when permissionPaths is empty.
//
// Only roles whose leaf permissions intersect the requested leaves are
// candidates.
// - With a single requested leaf, every candidate role grants it, so all of
//   their policies are returned. result.complete then keeps the default value
//   of PermittingPolicies.
// - Otherwise candidate role subsets are enumerated from smallest to largest.
//   A subset is accepted when its combined leaves cover all requested leaves
//   and it is not a superset of an already accepted subset. Policies of the
//   roles in accepted subsets are returned.
// - The enumeration stops after SEARCH_LIMIT seconds; result.complete is
//   false in that case.
//
// The returned policies are deduplicated by (role, aoi, group) ids.
PermittingPolicies
calculatePermittingPolicies(
    acl::ACLGateway gw,
    const std::vector<PolicyWithGroup>& userPolicies,
    const std::vector<acl::SubjectPath>& permissionPaths)
{
    if (permissionPaths.empty()) {
        return {};
    }

    const auto filterLeafIds = getPermissionsLeafIds(gw, permissionPaths);

    // Leaf permissions of every distinct user role that grants at least one
    // requested leaf; other roles can never be part of an answer.
    std::unordered_map<acl::ID, std::set<acl::ID>> roleLeafsById;
    std::set<acl::ID> roleIds;
    for (const auto& policyWithGroup : userPolicies) {
        if (!policyWithGroup.role) {
            continue;
        }
        const auto& role = *policyWithGroup.role;
        // insert().second is false when the role was already handled through
        // another policy: one lookup instead of count() followed by insert().
        if (!roleIds.insert(role.id()).second) {
            continue;
        }
        auto roleLeafIds = role.leafPermissionsIds();
        if (intersects(roleLeafIds, filterLeafIds)) {
            roleLeafsById.emplace(role.id(), std::move(roleLeafIds));
        }
    }

    PermittingPolicies result;

    // Appends, in userPolicies order, every policy that carries the given role.
    const auto appendPoliciesOfRole = [&](acl::ID roleId) {
        for (const auto& policyWithGroup : userPolicies) {
            if (policyWithGroup.role && policyWithGroup.role->id() == roleId) {
                result.policies.push_back(policyWithGroup);
            }
        }
    };

    if (filterLeafIds.size() == 1) {
        // Every candidate role contains the single requested leaf.
        for (const auto& [roleId, _] : roleLeafsById) {
            appendPoliciesOfRole(roleId);
        }
    } else {
        ProfileTimer profileTimer;
        std::vector<acl::ID> searchRoleIds;
        searchRoleIds.reserve(roleLeafsById.size());
        std::set<acl::ID> reachableLeafIds;
        for (const auto& [roleId, leafIds] : roleLeafsById) {
            searchRoleIds.push_back(roleId);
            reachableLeafIds.insert(leafIds.begin(), leafIds.end());
        }

        const bool coverable = std::includes(
            reachableLeafIds.begin(), reachableLeafIds.end(),
            filterLeafIds.begin(), filterLeafIds.end());

        if (!coverable) {
            // Not even all candidate roles together grant every requested
            // leaf, so no subset of them can. The exponential search would
            // find nothing, possibly only after hitting SEARCH_LIMIT.
            result.complete = true;
        } else {
            // Accepted (minimal) role subsets, as selection masks over searchRoleIds.
            std::vector<Sequence> permittingSequences;

            auto checkSequence = [&](const Sequence& sequence) -> bool {
                if (profileTimer.getElapsedTimeNumber() > SEARCH_LIMIT) {
                    return false; // time budget exhausted: abort the enumeration
                }

                // Union of the leaves granted by the roles selected in `sequence`.
                std::unordered_set<acl::ID> policySubsetLeafIds;
                for (size_t i = 0; i < sequence.size(); ++i) {
                    if (sequence[i]) {
                        const auto& roleLeafIds = roleLeafsById.at(searchRoleIds[i]);
                        policySubsetLeafIds.insert(roleLeafIds.begin(), roleLeafIds.end());
                    }
                }

                const bool containsAllLeafs = std::all_of(
                    filterLeafIds.begin(),
                    filterLeafIds.end(),
                    [&](acl::ID leafId) { return policySubsetLeafIds.count(leafId) > 0; });
                if (!containsAllLeafs) {
                    return true;
                }

                // Subsets arrive in order of increasing size, so a superset of an
                // already accepted subset is not minimal and is skipped.
                const bool alreadyFoundShorter = std::any_of(
                    permittingSequences.begin(),
                    permittingSequences.end(),
                    [&](const Sequence& permittingSequence) {
                        return leftBinSeqContainsRight(sequence, permittingSequence);
                    });
                if (alreadyFoundShorter) {
                    return true;
                }

                permittingSequences.push_back(sequence);
                for (size_t i = 0; i < sequence.size(); ++i) {
                    if (sequence[i]) {
                        appendPoliciesOfRole(searchRoleIds[i]);
                    }
                }
                return true;
            };
            result.complete = generateBinarySequences(searchRoleIds.size(), checkSequence);
        }
    }

    result.policies = getUniquePolicies(result.policies);
    return result;
}

} // namespace maps::wiki::aclsrv
