#include "worker.h"

#include <maps/libs/common/include/exception.h>
#include <maps/wikimap/mapspro/libs/acl/include/aclgateway.h>
#include <yandex/maps/wiki/validator/result.h>
#include <yandex/maps/wiki/validator/validator_config.h>
#include <yandex/maps/wiki/validator/validator.h>
#include <yandex/maps/wiki/validator/changed_objects.h>
#include <yandex/maps/wiki/validator/storage/results_gateway.h>
#include <yandex/maps/wiki/common/extended_xml_doc.h>
#include <yandex/maps/wiki/common/pgpool3_helpers.h>
#include <yandex/maps/wiki/common/retry_duration.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <yandex/maps/wiki/tasks/task_logger.h>
#include <yandex/maps/wiki/tasks/task_manager.h>
#include <maps/libs/log8/include/log8.h>

#include <maps/wikimap/mapspro/libs/acl_utils/include/moderation.h>
#include <maps/wikimap/mapspro/services/tasks/validator-checks/splitter/checks_splitter.h>

#include <boost/range/algorithm_ext/erase.hpp>

#include <thread>

namespace storage = maps::wiki::validator::storage;

namespace maps::wiki::validation {

namespace {

// Config XPath of the tasks service URL (used to start validation subtasks).
const std::string TASKS_URL_XPATH{"/config/services/tasks/url"};

// Check ids that this worker treats specially:
//  - BASE_WITH_CARDINALITY_CHECK_ID turns on the validator's cardinality check;
//  - RD_TOPOLOGY_CHECK_ID may not be split into subtasks for outsourcers.
const std::string BASE_WITH_CARDINALITY_CHECK_ID{"base_with_cardinality"};
const std::string RD_TOPOLOGY_CHECK_ID{"rd_topology"};

// Prefix of the per-task coverage directory; the grinder task id is appended.
const std::string COVERAGE_DIR_PREFIX{"/var/tmp/yandex-maps-mpro-validator-coverage."};

// Check threads for a light (non-heavy) task when no explicit limit is set.
constexpr size_t LIGHT_VALIDATION_CHECK_THREADS_COUNT{8};

// Long-read database and its pools: regular and heavy workers use different
// pools of the same database.
const std::string LONGREAD_DB_ID{"long-read"};
const std::string VALIDATOR_POOL_ID{"validator"};
const std::string VALIDATOR_HEAVY_POOL_ID{"validator.heavy"};

// Database and pool where validation results are stored and cleaned.
const std::string RESULTS_DB_ID{"validation"};
const std::string RESULTS_POOL_ID{"grinder"};


// Tells whether the user with the given uid is an outsourcer according to
// acl_utils. The ACL lookup runs in a read-only master transaction and is
// wrapped in common::retryDuration.
bool isOutsourcer(pgpool3::Pool& pool, Uid uid)
{
    return common::retryDuration([&pool, uid] {
        auto readTxn = pool.masterReadOnlyTransaction();
        auto aclUser = acl::ACLGateway(*readTxn).user(uid);
        const bool outsourcer = acl_utils::isOutsourcer(aclUser);
        return outsourcer;
    });
}

// Fetches the ACL status of the user with the given uid. The lookup runs in
// a read-only master transaction and is wrapped in common::retryDuration.
acl::User::Status getUserStatus(pgpool3::Pool& pool, Uid uid)
{
    auto fetchStatus = [&pool, uid] {
        auto readTxn = pool.masterReadOnlyTransaction();
        acl::ACLGateway gateway(*readTxn);
        return gateway.user(uid).status();
    };
    return common::retryDuration(fetchStatus);
}

// Configures how many threads `validator` uses to run checks.
//
// Priority:
//  1. a non-zero `maxCheckThreads` is used as is;
//  2. heavy tasks get half of the hardware threads;
//  3. light tasks get LIGHT_VALIDATION_CHECK_THREADS_COUNT.
//
// std::thread::hardware_concurrency() may return 0 when the value is not
// computable, and a single-thread report halves to 0 as well. The heavy
// branch is therefore clamped to at least one thread, so a zero count is
// never requested.
void setCheckThreads(
    validator::Validator& validator,
    const TaskData& taskData,
    size_t maxCheckThreads)
{
    const size_t hwThreads = std::thread::hardware_concurrency();
    INFO() << "Hardware threads: " << hwThreads;

    if (maxCheckThreads) {
        validator.setCheckThreadsCount(maxCheckThreads);
    } else if (taskData.isHeavy) {
        // Half of the hardware threads, but never fewer than one.
        const size_t heavyThreads = hwThreads >= 2 ? hwThreads / 2 : 1;
        validator.setCheckThreadsCount(heavyThreads);
    } else {
        validator.setCheckThreadsCount(LIGHT_VALIDATION_CHECK_THREADS_COUNT);
    }
    INFO() << "Check threads count: " << validator.checkThreadsCount();
}

} // anonymous namespace


// Builds a worker from the service config and the validator config.
//
// The validator is constructed from the worker's own validatorConfig_
// member. The long-read pool holder uses the LONGREAD_DB_ID database with
// the VALIDATOR_HEAVY_POOL_ID pool for heavy workers and VALIDATOR_POOL_ID
// otherwise. All validator modules are initialized eagerly in the
// constructor body.
//
// NOTE(review): validator_(validatorConfig_) reads another data member. This
// is only well-defined if validatorConfig_ is declared before validator_ in
// worker.h; confirm the declaration order there before reordering members.
Worker::Worker(
        const common::ExtendedXmlDoc& config,
        const validator::ValidatorConfig& validatorConfig,
        bool isHeavy)
    : config_(config)
    , validatorConfig_(validatorConfig)
    , validator_(validatorConfig_)
    , longReadPoolHolder_(
        config_,
        LONGREAD_DB_ID, isHeavy ? VALIDATOR_HEAVY_POOL_ID : VALIDATOR_POOL_ID)
{
    validator_.initModules();
}

// Writes an info message to the service log. It also goes to the task log
// once the task logger has been created.
void Worker::dualInfo(const std::string& str) const
{
    INFO() << str;
    if (!taskLogger_) {
        return;
    }
    taskLogger_->logInfo() << str;
}

// Writes a warning to the service log. It also goes to the task log once
// the task logger has been created.
void Worker::dualWarn(const std::string& str) const
{
    WARN() << str;
    if (!taskLogger_) {
        return;
    }
    taskLogger_->logWarn() << str;
}

// Writes an error to the service log. It also goes to the task log once the
// task logger has been created.
void Worker::dualError(const std::string& str) const
{
    ERROR() << str;
    if (!taskLogger_) {
        return;
    }
    taskLogger_->logError() << str;
}

// Creates the task logger for taskData.taskId on the long-read pool and
// writes the task header: grinder task id, optional yt operation id,
// branch/commit, and optional parent task id.
// Only the first call does anything; if a logger already exists, later
// calls return immediately.
void Worker::initTaskLogger(const TaskData& taskData)
{
    if (taskLogger_ != nullptr) {
        return;
    }

    auto& pool = longReadPoolHolder_.pool();
    taskLogger_ = std::make_unique<tasks::TaskPgLogger>(pool, taskData.taskId);

    dualInfo("Task started. Grinder task id: " + taskData.grinderTaskId);
    if (!taskData.ytOperationId.empty()) {
        dualInfo("yt: " + taskData.ytOperationId);
    }

    const std::string branchAndCommit =
        "branch: " + std::to_string(taskData.branchId) +
        ", commit: " + std::to_string(taskData.commitId);
    dualInfo(branchAndCommit);

    if (taskData.parentTaskId) {
        dualInfo("parent task id: " + std::to_string(taskData.parentTaskId));
    }
}

// Validates an incoming task before it is run.
//
// The requesting user must have the Active ACL status. Otherwise the task
// logger is created, the reason is logged, and maps::RuntimeError is thrown.
// Check ids that no loaded validator module provides are logged as
// unsupported and removed from taskData.checks in place.
void Worker::prepareTaskData(TaskData& taskData)
{
    INFO() << "Received task: " << taskData.taskId << ". "
           << "Grinder task id: " << taskData.grinderTaskId;

    const auto userStatus = getUserStatus(longReadPoolHolder_.pool(), taskData.uid);
    INFO() << "User uid: " << taskData.uid << " status: " << userStatus;
    if (userStatus != acl::User::Status::Active) {
        initTaskLogger(taskData);
        std::ostringstream message;
        message << "Forbidden user status: " << userStatus;
        const std::string text = message.str();
        dualError(text);
        throw maps::RuntimeError() << text;
    }

    // Every check id offered by any loaded validator module.
    CheckIds supportedChecks;
    for (const auto& validatorModule : validator_.modules()) {
        for (const auto& moduleCheckId : validatorModule.checkIds()) {
            supportedChecks.insert(moduleCheckId);
        }
    }

    const auto isUnsupported = [&](const CheckId& checkId) {
        if (supportedChecks.count(checkId) > 0) {
            return false;
        }
        dualError("Unsupported check '" + checkId + "'");
        return true;
    };
    boost::range::remove_erase_if(taskData.checks, isUnsupported);
}

bool Worker::splitTasks(const TaskData& taskData)
{
    if (!taskData.canSplit()) {
        return false;
    }

    if (taskData.hasCheck(RD_TOPOLOGY_CHECK_ID) &&
        isOutsourcer(longReadPoolHolder_.pool(), taskData.uid))
    {
        initTaskLogger(taskData);
        throw maps::RuntimeError() << "Forbidden operation";
    }

    CheckIds checksSorted{taskData.checks.begin(), taskData.checks.end()};
    INFO() << "Checks: " << common::join(checksSorted, ',');

    ChecksSplitter checksSplitter(validator_);
    auto groupChecks = checksSplitter.split(checksSorted);
    if (groupChecks.size() <= 1) {
        return false;
    }

    initTaskLogger(taskData);
    tasks::TaskManager taskManager(config_.get<std::string>(TASKS_URL_XPATH), taskData.uid);

    for (const auto& checks : groupChecks) {
        INFO() << "Subtask data: "
               << taskData.uid << "," << taskData.branchId << "," << taskData.commitId << ","
               << taskData.regionId << "," << common::join(checks, ',');

        auto subtask = taskManager.startValidation(
            taskData.branchId, checks, tasks::NO_AOI,
            taskData.regionId
                ? boost::optional<tasks::ObjectId>(taskData.regionId)
                : tasks::NO_REGION,
            taskData.commitId, taskData.taskId);

        dualInfo("Subtask ID: " + std::to_string(subtask.id()));
    }
    return true;
}

// Runs the validation described by taskData and stores its results in the
// validation database under taskData.taskId.
//
// The validation scope is picked in this order: AOI geometry, AOI ids,
// region id, changed objects (findChangedObjectsInReleaseBranch), and
// finally the whole branch at the given commit.
//
// Returns:
//   Ok        - validation finished and results were stored;
//   NeedRetry - validator::InitializationError was thrown;
//   Canceled  - validator::ValidationCanceledException was thrown
//               (presumably after checkCanceled reported cancellation);
//   Failed    - storing the results failed, or any other exception occurred.
// Unexpected exceptions are written in full to the service log only; the
// task log just gets a generic "Task failed" line.
Worker::Status Worker::run(
    const TaskData& taskData,
    std::function<bool()> checkCanceled)
{
    initTaskLogger(taskData);
    setCheckThreads(validator_, taskData, maxCheckThreads_);

    auto& pool = longReadPoolHolder_.pool();
    validator_.setIsCanceledChecker(checkCanceled);

    try {
        if (taskData.hasCheck(BASE_WITH_CARDINALITY_CHECK_ID)) {
            // Relation cardinality constraints are defined in the editor
            // config; they are enabled only when this check id is requested.
            validator_.enableCardinalityCheck();
        }

        auto coverageDir = COVERAGE_DIR_PREFIX + taskData.grinderTaskId;

        // Calls the validator overload that matches the task scope. Returns
        // an empty result if there are no changed objects to validate.
        auto validate = [&]() -> validator::ResultPtr {
            if (taskData.aoiGeom) {
                return validator_.run(
                    taskData.checks, pool, taskData.branchId, taskData.commitId,
                    *taskData.aoiGeom, coverageDir, taskData.aoiBuffer);
            }
            if (!taskData.aoiIds.empty()) {
                return validator_.run(
                    taskData.checks, pool, taskData.branchId, taskData.commitId,
                    taskData.aoiIds, coverageDir, taskData.aoiBuffer);
            }
            if (taskData.regionId) {
                return validator_.run(
                    taskData.checks, pool, taskData.branchId, taskData.commitId,
                    {taskData.regionId}, coverageDir, validator::ZERO_AOI_BUFFER);
            }
            if (!taskData.onlyChangedObjects) {
                return validator_.run(
                    taskData.checks, pool, taskData.branchId, taskData.commitId);
            }

            auto objectIds = validator::findChangedObjectsInReleaseBranch(
                pool, taskData.branchId, taskData.commitId);
            if (objectIds.empty()) {
                dualWarn("No changed objects found");
                return validator::ResultPtr();
            }
            return validator_.run(
                taskData.checks, pool, taskData.branchId, taskData.commitId,
                objectIds);
        };

        auto result = validate();
        if (!storage::storeResult(result, validationPool(), taskData.taskId)) {
            dualError("Writing to DB failed");
            return Status::Failed;
        }
    } catch (const validator::InitializationError& e) {
        ERROR() << e;
        dualWarn("Initialization failed, trying to restart task");
        return Status::NeedRetry;
    } catch (const validator::ValidationCanceledException&) {
        dualInfo("Task canceled");
        return Status::Canceled;
    } catch (const maps::Exception& e) {
        ERROR() << "Task " << taskData.taskId << " failed: " << e;
        taskLogger_->logError() << "Task failed";
        return Status::Failed;
    } catch (const std::exception& e) {
        ERROR() << "Task " << taskData.taskId << " failed: " << e.what();
        taskLogger_->logError() << "Task failed";
        return Status::Failed;
    }

    dualInfo("Task finished");

    // Also record completion of this subtask in the parent task's log.
    if (taskData.parentTaskId) {
        tasks::TaskPgLogger parentLogger(pool, taskData.parentTaskId);
        parentLogger.logInfo() << "Subtask " << taskData.taskId << ". Task finished";
    }
    return Status::Ok;
}

// Loads the description of task `taskId` through a read-only master
// transaction on the long-read pool.
TaskData Worker::loadTaskData(DBID taskId)
{
    auto& pool = longReadPoolHolder_.pool();
    auto readTxn = pool.masterReadOnlyTransaction();
    return TaskData(*readTxn, taskId);
}

// Returns the pool of the validation results database. The pool holder is
// created on first use and reused on later calls.
pgpool3::Pool& Worker::validationPool()
{
    if (validationPoolHolder_) {
        return validationPoolHolder_->pool();
    }

    validationPoolHolder_ = std::make_unique<common::PoolHolder>(
        config_, RESULTS_DB_ID, RESULTS_POOL_ID);
    return validationPoolHolder_->pool();
}

// Deletes every stored validation message and message statistic for
// `taskId`. Both deletes run in one writeable transaction, which is
// committed only after both succeed. taskId is a numeric DBID, so building
// the SQL text with std::to_string cannot inject SQL.
void Worker::cleanResults(DBID taskId)
{
    auto txn = validationPool().masterWriteableTransaction();
    for (const char* table : {"task_message", "task_message_stats"}) {
        const std::string query =
            std::string("DELETE FROM validation.") + table +
            " WHERE task_id = " + std::to_string(taskId);
        txn->exec(query);
    }
    txn->commit();
}

} // namespace maps::wiki::validation
