#include "exclusions.h"
#include "extractors.h"
#include "helpers.h"

#include <maps/wikimap/mapspro/libs/acl/include/aclgateway.h>
#include <yandex/maps/wiki/common/batch.h>
#include <yandex/maps/wiki/common/retry_duration.h>
#include <yandex/maps/wiki/common/revision_utils.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <yandex/maps/wiki/common/moderation.h>
#include <yandex/maps/wiki/revision/revisionsgateway.h>
#include <yandex/maps/wiki/revision/historical_snapshot.h>

#include <maps/libs/common/include/profiletimer.h>
#include <maps/libs/log8/include/log8.h>

#include <future>

namespace rev = maps::wiki::revision;
namespace rf = maps::wiki::revision::filters;

namespace maps {
namespace wiki {
namespace diffalert {
namespace {

// Commit `action` values of group edits. They are classified as group edits
// only when the commit `source` equals STR_LONG_TASK (see CommitInfo ctor).
const std::string STR_GROUP_MOVED{"group-moved"};
const std::string STR_GROUP_MODIFIED_ATTRIBUTES{"group-modified-attributes"};
const std::string STR_GROUP_DELETED{"group-deleted"};
const std::string STR_LONG_TASK{"long-task"};

// ACL role name that marks a user as an outsourcer.
const std::string OUTSOURCE_ROLE{"outsource-role"};

// Commit `action` value of import commits.
const std::string STR_IMPORT{"import"};
const std::string STR_EMPTY{};

// Maximum number of ids passed to a single batched database query.
constexpr size_t BATCH_LOADING_SIZE = 500;

// Commits and the objects they affected, as loaded from one branch.
// Results from several branches are combined with operator+=.
struct LoadResult
{
    CommitIdToInfo commitIdToInfo;
    ObjectIdToCommitIds objectIdsToCommitIds;

    // Merges `rhs` into this result.
    // - Commit infos are inserted with map insert semantics: an entry whose
    //   commit id is already present here is kept unchanged.
    // - For every object, the commit-id sets of both sides are united.
    // Returns *this, as compound assignment operators conventionally do.
    LoadResult& operator+=(const LoadResult& rhs)
    {
        commitIdToInfo.insert(rhs.commitIdToInfo.begin(), rhs.commitIdToInfo.end());
        for (const auto& [objectId, commitIds] : rhs.objectIdsToCommitIds) {
            objectIdsToCommitIds[objectId].insert(commitIds.begin(), commitIds.end());
        }
        return *this;
    }
};

// Collects, for every object touched by the given commits, the subset of
// those commits that touched it.
//
// Non-relation revisions map their own object id to their commit. Relation
// revisions map both the master and the slave object id to their commit.
// Revisions are read from a historical snapshot bounded by the smallest and
// largest commit id in `commitIds`, in batches of BATCH_LOADING_SIZE ids.
//
// txn       - transaction used for reading revisions
// branch    - branch whose revisions are read
// commitIds - commits of interest; an empty set yields an empty result
// returns   - object id -> ids of the commits from `commitIds` affecting it
ObjectIdToCommitIds loadObjectIdsAffectedByCommitIds(
    pqxx::transaction_base& txn,
    const rev::Branch& branch,
    const std::set<CommitId>& commitIds)
{
    ObjectIdToCommitIds result;

    // Dereferencing begin()/rbegin() of an empty set is undefined behavior.
    if (commitIds.empty()) {
        return result;
    }

    const auto minCommitId = *commitIds.begin();
    const auto maxCommitId = *commitIds.rbegin();

    rev::RevisionsGateway gateway(txn, branch);
    // NOTE(review): the `+ 1` suggests the upper bound is exclusive - confirm
    // against RevisionsGateway::historicalSnapshot.
    auto snapshot = gateway.historicalSnapshot(minCommitId, maxCommitId + 1);

    // load revision ids
    common::applyBatchOp(
        commitIds,
        BATCH_LOADING_SIZE,
        [&](const auto& batchCommitIds) {
            auto revIds = snapshot.revisionIdsByFilter(
                rf::ObjRevAttr::commitId().in(batchCommitIds)
                    && rf::ObjRevAttr::isNotRelation());
            for (const auto& revId : revIds) {
                result[revId.objectId()].insert(revId.commitId());
            }
        });

    // load relations: a relation change affects both of its ends
    common::applyBatchOp(
        commitIds,
        BATCH_LOADING_SIZE,
        [&](const auto& batchCommitIds) {
            auto relations = snapshot.relationsByFilter(
                rf::ObjRevAttr::commitId().in(batchCommitIds));
            for (const auto& rel : relations) {
                const auto& relData = *rel.data().relationData;
                result[relData.masterObjectId()].insert(rel.id().commitId());
                result[relData.slaveObjectId()].insert(rel.id().commitId());
            }
        });

    return result;
}

// Loads the commits of `branch` that match `branchDifferencingFilter` and
// whose id does not exceed the commit id of `snapshotId`, together with the
// objects those commits affected. The whole load (including opening the read
// transaction) is retried via common::retryDuration.
// Returns an empty LoadResult when no commit matches.
LoadResult loadCommitsFromBranch(
    const rf::ProxyFilterExpr& branchDifferencingFilter,
    const rev::Branch& branch,
    const rev::SnapshotId& snapshotId,
    pgpool3::Pool& tdsConnPool)
{
    return common::retryDuration([&] {
        LoadResult loaded;

        auto txn = common::getReadTransactionForCommit(
            tdsConnPool, branch.id(), snapshotId.commitId(),
            [](const std::string& msg) { INFO() << msg; });

        auto commits = rev::Commit::load(
            *txn,
            branchDifferencingFilter && rf::CommitAttr::id() <= snapshotId.commitId());

        if (!commits.empty()) {
            std::set<CommitId> loadedCommitIds;
            for (const auto& commit : commits) {
                loadedCommitIds.insert(commit.id());
                loaded.commitIdToInfo.emplace(commit.id(), CommitInfo(commit));
            }
            loaded.objectIdsToCommitIds =
                loadObjectIdsAffectedByCommitIds(*txn, branch, loadedCommitIds);
        } else {
            INFO() << "No commits found in branch " << branch.id();
        }

        return loaded;
    });
}

// Master object id -> ids of its slaves (restricted to the requested ids).
using MasterIdToSlaveIds = std::unordered_map<ObjectId, std::unordered_set<ObjectId>>;

// Reads, from snapshot `snapshotId` of `branch`, every non-deleted relation
// whose slave is one of `geomPartIds`, and groups the slave ids by master id.
// Relations are queried in batches of BATCH_LOADING_SIZE slave ids. The whole
// load is retried via common::retryDuration.
MasterIdToSlaveIds loadMasterToSlaveRelationsFromBranch(
    const std::unordered_set<ObjectId>& geomPartIds,
    const rev::Branch& branch,
    const rev::SnapshotId& snapshotId,
    pgpool3::Pool& tdsConnPool)
{
    return common::retryDuration([&] {
        auto txn = common::getReadTransactionForCommit(
            tdsConnPool, branch.id(), snapshotId.commitId(),
            [](const std::string& msg) { INFO() << msg; });

        rev::RevisionsGateway gateway(*txn, branch);
        auto snapshot = gateway.snapshot(snapshotId);

        MasterIdToSlaveIds slavesByMaster;
        common::applyBatchOp(
            geomPartIds,
            BATCH_LOADING_SIZE,
            [&](const auto& batch) {
                auto relations = snapshot.relationsByFilter(
                    rf::ObjRevAttr::slaveObjectId().in(batch)
                    && rf::ObjRevAttr::isNotDeleted());

                for (const auto& relation : relations) {
                    const auto& data = *relation.data().relationData;
                    auto& slaves = slavesByMaster[data.masterObjectId()];
                    slaves.insert(data.slaveObjectId());
                }
            });

        return slavesByMaster;
    });
}

// Resolves the user type of each uid from ACL roles.
// - A user whose first applicable role (among OUTSOURCE_ROLE and the
//   cartographer moderation status) is the cartographer status maps to
//   UserType::Cartographer.
// - A user whose first applicable role is OUTSOURCE_ROLE maps to
//   UserType::Outsourcer.
// - Users with neither role are absent from the result.
// Uids are processed in batches of BATCH_LOADING_SIZE. Each batch is
// retried independently via common::retryDuration.
std::unordered_map<TUId, UserType> loadUidToRole(
    pgpool3::Pool& tdsConnPool,
    const std::unordered_set<TUId>& uids)
{
    std::unordered_map<TUId, UserType> roleByUid;

    const auto loadBatch = [&](const auto& batchUids) {
        common::retryDuration([&] {
            auto txn = tdsConnPool.slaveTransaction();
            acl::ACLGateway aclGateway(*txn);

            auto batchUsers = aclGateway.users(
                std::vector<TUId>(batchUids.begin(), batchUids.end()));

            auto idToRole = aclGateway.firstApplicableRoles(
                batchUsers,
                {OUTSOURCE_ROLE, common::MODERATION_STATUS_CARTOGRAPHER},
                {});

            for (const auto& user : batchUsers) {
                const auto roleIt = idToRole.find(user.id());
                if (roleIt == idToRole.end()) {
                    continue;
                }

                const auto& role = roleIt->second;
                if (role == common::MODERATION_STATUS_CARTOGRAPHER) {
                    roleByUid.emplace(user.uid(), UserType::Cartographer);
                } else if (role == OUTSOURCE_ROLE) {
                    roleByUid.emplace(user.uid(), UserType::Outsourcer);
                }
            }
        });
    };

    common::applyBatchOp(uids, BATCH_LOADING_SIZE, loadBatch);
    return roleByUid;
}

// Sets CommitInfo::userType of every commit from the ACL role of its author
// (see loadUidToRole). Authors without a recognised role get
// UserType::Common.
void fillUpUserTypesForCommits(
    pgpool3::Pool& tdsConnPool,
    CommitIdToInfo& commitIdToInfo)
{
    std::unordered_set<TUId> authorUids;
    for (const auto& entry : commitIdToInfo) {
        authorUids.insert(entry.second.createdBy);
    }

    const auto roleByUid = loadUidToRole(tdsConnPool, authorUids);

    for (auto& entry : commitIdToInfo) {
        auto& info = entry.second;
        const auto roleIt = roleByUid.find(info.createdBy);
        if (roleIt == roleByUid.end()) {
            info.userType = UserType::Common;
        } else {
            info.userType = roleIt->second;
        }
    }
}

} // namespace

// Captures the id and author of `commit` and classifies its action.
// - Commits whose action is STR_IMPORT become ActionType::Import,
//   regardless of their source.
// - Group-edit actions are recognised only for commits whose source is
//   STR_LONG_TASK. Group edits made in the editor have no `source`
//   attribute and therefore stay ActionType::Other.
// userType starts as UserType::Common; it is filled in separately
// (see fillUpUserTypesForCommits).
CommitInfo::CommitInfo(const rev::Commit& commit)
    : commitId(commit.id())
    , createdBy(commit.createdBy())
    , userType(UserType::Common)
    , actionType(ActionType::Other)
{
    const auto& action = commit.action();

    if (action == STR_IMPORT) {
        actionType = ActionType::Import;
    } else if (commit.source() == STR_LONG_TASK) {
        if (action == STR_GROUP_MOVED) {
            actionType = ActionType::GroupMove;
        } else if (action == STR_GROUP_MODIFIED_ATTRIBUTES) {
            actionType = ActionType::GroupEditAttributes;
        } else if (action == STR_GROUP_DELETED) {
            actionType = ActionType::GroupDelete;
        }
    }
}

// True when this commit's action type is among the filter's excluded ones.
bool CommitInfo::isCommitExcluded(const CommitFilter& commitFilter) const
{
    const auto& excludedActions = commitFilter.excludedActionTypes();
    return excludedActions.isSet(actionType);
}

// True when this commit's author type is among the filter's included ones.
bool CommitInfo::isCommitIncluded(const CommitFilter& commitFilter) const
{
    const auto& includedUsers = commitFilter.includedUserTypes();
    return includedUsers.isSet(userType);
}

//========================== ExclusionFinder =========================

ExclusionFinder::ExclusionFinder(
        pgpool3::Pool& tdsConnPool,
        const CommitFilter& commitFilter,
        const revision::Branch& oldBranch,
        const revision::SnapshotId& oldSnapshotId,
        const revision::Branch& newBranch,
        const revision::SnapshotId& newSnapshotId)
    : tdsConnPool_(tdsConnPool)
    , commitFilter_(commitFilter)
    , oldBranch_(oldBranch)
    , oldSnapshotId_(oldSnapshotId)
    , newBranch_(newBranch)
    , newSnapshotId_(newSnapshotId)
{
}

// Loads the commits of both branches, together with the objects they
// affected, and resolves the authors' user types.
// Does nothing when the commit filter is not set. When neither branch
// yields a commit, a warning is logged and the stored data is left untouched.
void ExclusionFinder::loadCommits(
    const revision::filters::ProxyFilterExpr& oldSnapshotFilter,
    const revision::filters::ProxyFilterExpr& newSnapshotFilter)
{
    if (!commitFilter_) {
        return;
    }

    ProfileTimer timer;

    // The new branch is loaded on another thread while the old one is loaded
    // here. A future obtained from std::async blocks in its destructor, so the
    // by-reference capture stays valid even if the old-branch load throws.
    auto newBranchFuture = std::async(std::launch::async, [&] {
        return loadCommitsFromBranch(newSnapshotFilter, newBranch_, newSnapshotId_, tdsConnPool_);
    });

    auto merged = loadCommitsFromBranch(oldSnapshotFilter, oldBranch_, oldSnapshotId_, tdsConnPool_);
    merged += newBranchFuture.get();

    if (merged.commitIdToInfo.empty()) {
        WARN() << "No commits loaded";
        return;
    }

    commitIdToInfo_ = std::move(merged.commitIdToInfo);
    objectIdsToCommitIds_ = std::move(merged.objectIdsToCommitIds);
    fillUpUserTypesForCommits(tdsConnPool_, commitIdToInfo_);

    INFO() << "Exclusion data has been loaded in " << timer.getElapsedTime();
}

// Makes commits that touched geometry parts count as commits of their
// masters as well, over two levels:
// 1. geometry parts present in the diff propagate to their masters;
// 2. those masters that are present in the diff and are themselves geometry
//    parts propagate further to their own masters.
// Does nothing when the commit filter is not set or no commits are loaded.
void ExclusionFinder::propagateCommitsToMasters(
    const std::map<TId, LongtaskDiffContext::Impl>& diffContextImpls)
{
    if (!commitFilter_ || commitIdToInfo_.empty()) {
        return;
    }

    ProfileTimer timer;

    const auto isGeomPart = [](const auto& impl) {
        return isGeomPartCategory(impl.objectDiff.anyObject().categoryId());
    };

    std::unordered_set<ObjectId> firstLevelGeomPartIds;
    for (const auto& [objectId, impl] : diffContextImpls) {
        if (isGeomPart(impl)) {
            firstLevelGeomPartIds.insert(objectId);
        }
    }

    const auto firstLevelMasterIds = propagateCommitsToMasters(firstLevelGeomPartIds);

    std::unordered_set<ObjectId> secondLevelGeomPartIds;
    for (const auto masterId : firstLevelMasterIds) {
        const auto implIt = diffContextImpls.find(masterId);
        if (implIt != diffContextImpls.end() && isGeomPart(implIt->second)) {
            secondLevelGeomPartIds.insert(masterId);
        }
    }

    propagateCommitsToMasters(secondLevelGeomPartIds);

    INFO() << "Propagation has been performed in " << timer.getElapsedTime();
}

// Adds the commits of every geometry part in `geomPartIds` to the commit set
// of each master that holds it as a slave (per the old or the new snapshot).
// Returns the ids of all masters found; an empty input yields an empty set.
//
// Masters always receive an entry in objectIdsToCommitIds_. Geometry parts
// without known commits are only read and never inserted. An empty
// placeholder would make isObjectExcluded treat them as excluded (all_of over
// an empty range is true) instead of reporting the missing data.
std::unordered_set<ObjectId> ExclusionFinder::propagateCommitsToMasters(
    const std::unordered_set<ObjectId>& geomPartIds)
{
    if (geomPartIds.empty()) {
        return {};
    }

    // Load both branches concurrently. A future obtained from std::async
    // blocks in its destructor, so the by-reference capture stays valid even
    // if the old-branch load throws.
    auto newMasterToSlaveRelations = std::async(std::launch::async,
        [&]() {
            return loadMasterToSlaveRelationsFromBranch(geomPartIds, newBranch_, newSnapshotId_, tdsConnPool_);
        });

    auto masterToSlaveRelations = loadMasterToSlaveRelationsFromBranch(geomPartIds, oldBranch_, oldSnapshotId_, tdsConnPool_);

    // merge loading results
    for (const auto& [masterId, slaveIds] : newMasterToSlaveRelations.get()) {
        masterToSlaveRelations[masterId].insert(slaveIds.begin(), slaveIds.end());
    }

    std::unordered_set<ObjectId> masterIds;

    // propagate commits from geom parts to masters
    for (const auto& [masterId, slaveIds] : masterToSlaveRelations) {
        auto& masterCommitIds = objectIdsToCommitIds_[masterId];
        for (auto slaveId : slaveIds) {
            const auto slaveIt = objectIdsToCommitIds_.find(slaveId);
            if (slaveIt == objectIdsToCommitIds_.end()) {
                continue;
            }
            const auto& slaveCommitIds = slaveIt->second;
            masterCommitIds.insert(slaveCommitIds.begin(), slaveCommitIds.end());
        }
        masterIds.insert(masterId);
    }

    return masterIds;
}

// Tells whether changes of `objectId` should be excluded from alerting.
// Returns false when the commit filter is not set. Otherwise the object is
// excluded when either
//  - every commit that touched it has an action type the filter excludes, or
//  - at least one commit that touched it was made by a user type the filter
//    does not include.
// Fails via REQUIRE when no commits are known for `objectId`.
bool ExclusionFinder::isObjectExcluded(ObjectId objectId) const
{
    if (!commitFilter_) {
        return false;
    }

    const auto objectIt = objectIdsToCommitIds_.find(objectId);
    REQUIRE(objectIt != objectIdsToCommitIds_.end(), "Failed to find commits for object id " << objectId);
    const auto& commitIds = objectIt->second;

    // Every commit id attached to an object is expected to have been loaded.
    const auto infoOf = [this](const auto& commitId) -> const CommitInfo& {
        const auto infoIt = commitIdToInfo_.find(commitId);
        ASSERT(infoIt != commitIdToInfo_.end());
        return infoIt->second;
    };

    const bool allCommitsExcluded = std::all_of(commitIds.begin(), commitIds.end(),
        [&](const auto& commitId) {
            return infoOf(commitId).isCommitExcluded(commitFilter_);
        });
    if (allCommitsExcluded) {
        return true;
    }

    return std::any_of(commitIds.begin(), commitIds.end(),
        [&](const auto& commitId) {
            return !infoOf(commitId).isCommitIncluded(commitFilter_);
        });
}

} // namespace diffalert
} // namespace wiki
} // namespace maps
