#include "worker.h"

#include "csv_writer.h"
#include "object_labeler.h"

#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/pgpool/include/pgpool3.h>
#include <maps/libs/xml/include/xmlexception.h>
#include <yandex/maps/wiki/common/geom.h>
#include <yandex/maps/wiki/common/default_config.h>
#include <yandex/maps/wiki/common/extended_xml_doc.h>
#include <yandex/maps/wiki/common/pgpool3_helpers.h>
#include <yandex/maps/wiki/common/retry_duration.h>
#include <yandex/maps/wiki/common/revision_utils.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <yandex/maps/wiki/tasks/task_logger.h>
#include <yandex/maps/wiki/diffalert/revision/runner.h>
#include <yandex/maps/wiki/diffalert/revision/editor_config.h>
#include <yandex/maps/wiki/diffalert/revision/diff_context.h>
#include <yandex/maps/wiki/diffalert/revision/diff_envelopes.h>
#include <yandex/maps/wiki/diffalert/storage/results_writer.h>
#include <yandex/maps/wiki/diffalert/storage/stored_message.h>
#include <yandex/maps/wiki/revision/branch.h>
#include <yandex/maps/wiki/revision/branch_manager.h>
#include <yandex/maps/wiki/revision/filters.h>
#include <yandex/maps/wiki/revision/revisionsgateway.h>
#include <yandex/maps/wiki/revision/snapshot_id.h>
#include <yandex/maps/wiki/threadutils/executor.h>
#include <yandex/maps/wiki/threadutils/threadpool.h>

#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstdint>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <vector>

namespace cfg = maps::wiki::configs::editor;
namespace common = maps::wiki::common;
namespace da = maps::wiki::diffalert;
namespace mwr = maps::wiki::revision;
namespace rf = maps::wiki::revision::filters;
namespace tasks = maps::wiki::tasks;

namespace maps::wiki::diffalert_worker {

namespace {

// Thread count used by calculateThreadsCount when the
// hardware_concurrency()-based estimate comes out as zero.
const size_t DEFAULT_THREADS_COUNT = 4;

// Database / pool ids passed to common::PoolHolder.
// The revision ("long-read") DB is also the fallback for view access when the
// view-stable DB is not configured (see createViewStablePoolHolder).
const std::string REVISION_DB_ID = "long-read";
const std::string REVISION_POOL_ID = "diffalert";
const std::string VIEW_STABLE_DB_ID = "view-stable";
const std::string VIEW_STABLE_POOL_ID = "diffalert";
// Database that check results are written to via da::ResultsWriter.
const std::string RESULTS_DB_ID = "validation";
const std::string RESULTS_POOL_ID = "grinder";

// Attribute that marks diffalert region objects, and the attribute holding a
// region's numeric priority (cast to da::RegionPriority).
const std::string CAT_DIFFALERT_REGION = "cat:diffalert_region";
const std::string DIFFALERT_REGION_PRIORITY = "diffalert_region:priority";

// XPath of the config section listing the AOI object ids used for the CSV report.
const std::string AOIS_XPATH = "/config/services/tasks/diffalert/aois";

// Child node name and its id attribute inside the AOIS_XPATH section.
const std::string AOI_NODE = "aoi";
const std::string ID_ATTR = "id";

// Accepted values of TaskData::userFilter; any other value is rejected.
const std::string ALL_USER_FILTER = "all";
const std::string COMMON_USER_FILTER = "common";
const std::string COMMON_OR_OUTSOURCER_USER_FILTER = "common-or-outsourcer";

// Attribute names read from AOI objects in the revision DB.
namespace attr {
const std::string CAT_AOI = "cat:aoi";
const std::string AOI_NAME = "aoi:name";
} // namespace attr

// Thrown by per-diff processing once the task's checkCanceled() callback
// returns true; Worker::run turns it into Status::Canceled.
struct TaskCanceledException : maps::Exception {};

// The two (branch, snapshot) points a task's diff is computed between.
// NOTE: member order matters; getSnapshotsPair builds this by aggregate
// initialization in exactly this order.
struct SnapshotsPair {
    mwr::Branch oldBranch;
    mwr::SnapshotId oldSnapshotId;
    mwr::Branch newBranch;
    mwr::SnapshotId newSnapshotId;
};

/// Resolves the task's old/new (branch id, commit id) into loaded branches and
/// snapshot ids.
///
/// All reads happen in a single read-only master transaction; the whole read
/// is retried via common::retryDuration.
SnapshotsPair getSnapshotsPair(
    const TaskData& taskData,
    maps::pgpool3::Pool& revisionPool)
{
    auto loadSnapshots = [&]() -> SnapshotsPair {
        auto txn = revisionPool.masterReadOnlyTransaction();
        mwr::BranchManager branches(*txn);

        auto fromBranch = branches.load(taskData.oldBranchId);
        auto fromSnapshot = mwr::SnapshotId::fromCommit(
            taskData.oldCommitId, fromBranch.type(), *txn);

        auto toBranch = branches.load(taskData.newBranchId);
        auto toSnapshot = mwr::SnapshotId::fromCommit(
            taskData.newCommitId, toBranch.type(), *txn);

        return {fromBranch, fromSnapshot, toBranch, toSnapshot};
    };

    return common::retryDuration(loadSnapshots);
}

/// Envelope to attach to a message.
///
/// For GeomDiff-scoped messages this is the union of diffEnvelopes.removed
/// and diffEnvelopes.added; for every other scope it is the union of
/// diffEnvelopes.before and diffEnvelopes.after.
da::Envelope
msgEnvelope(const da::Message& msg, const da::DiffEnvelopes& diffEnvelopes)
{
    da::Envelope result;

    if (msg.scope() != da::Message::Scope::GeomDiff) {
        result.expandToInclude(&diffEnvelopes.before);
        result.expandToInclude(&diffEnvelopes.after);
        return result;
    }

    result.expandToInclude(&diffEnvelopes.removed);
    result.expandToInclude(&diffEnvelopes.added);
    return result;
}

/// Loads all diffalert regions from the trunk head snapshot.
///
/// Selects non-relation, non-deleted objects that have a geometry and the
/// CAT_DIFFALERT_REGION attribute. A region's priority is taken from the
/// DIFFALERT_REGION_PRIORITY attribute (parsed as uint32_t, then cast to
/// da::RegionPriority); objects without that attribute get
/// da::RegionPriority::Low. The result is sorted with std::sort (operator<).
///
/// Throws (REQUIRE) when a selected object has no attributes or geometry, or
/// when its priority value cannot be parsed as an unsigned integer.
std::vector<da::DiffalertRegion> getDiffalertRegions(
    maps::pgpool3::Pool& revisionPool)
{
    auto revisionFilter = rf::Attr(CAT_DIFFALERT_REGION).defined()
                           && rf::ObjRevAttr::isNotRelation()
                           && rf::ObjRevAttr::isNotDeleted()
                           && rf::Geom::defined();

    auto revisions = common::retryDuration([&] {
        auto txn = revisionPool.masterReadOnlyTransaction();

        auto branch = mwr::BranchManager(*txn).loadTrunk();
        mwr::RevisionsGateway gateway(*txn, branch);
        auto commitId = gateway.headCommitId();
        auto snapshot = gateway.snapshot(commitId);
        return snapshot.objectRevisionsByFilter(revisionFilter);
    });

    std::vector<da::DiffalertRegion> regions;
    for (const auto& rev : revisions) {
        REQUIRE(rev.data().attributes, "Object " << rev.id() << " has no attributes");
        REQUIRE(rev.data().geometry, "Object " << rev.id() << " has no geometry");

        common::Geom geom(*rev.data().geometry);
        auto priority = da::RegionPriority::Low; // for old regions without the attribute

        const auto& attributes = *rev.data().attributes;
        const auto priorityIt = attributes.find(DIFFALERT_REGION_PRIORITY);
        if (priorityIt != attributes.end()) {
            // Same parse rules as boost::lexical_cast<uint32_t>, but a failure is
            // reported with the offending object instead of a bare bad_lexical_cast.
            uint32_t rawPriority = 0;
            const bool parsed =
                boost::conversion::try_lexical_convert(priorityIt->second, rawPriority);
            REQUIRE(parsed,
                    "Object " << rev.id() << " has invalid " << DIFFALERT_REGION_PRIORITY
                    << " value '" << priorityIt->second << "'");
            // NOTE(review): the value is not checked against the range of
            // da::RegionPriority before the cast.
            priority = static_cast<da::RegionPriority>(rawPriority);
        }

        regions.push_back(da::DiffalertRegion{std::move(geom), priority});
    }

    std::sort(regions.begin(), regions.end());

    return regions;
}

/// Loads the AOIs (areas of interest) listed in the config, for the CSV report.
///
/// AOI object ids are read from the AOI_NODE children of the AOIS_XPATH
/// section. An empty list is returned when the section is missing or lists no
/// ids. Objects are loaded from the trunk head snapshot. Each AOI is named by
/// its AOI_NAME attribute, or by its id when that attribute is absent.
/// Configured ids that yield no object are skipped with a warning.
///
/// Throws (REQUIRE) when a loaded object has no attributes, is not CAT_AOI, or
/// has no geometry.
std::vector<da::Aoi> getAoiRegions(
    const common::ExtendedXmlDoc& config,
    maps::pgpool3::Pool& revisionPool)
{
    auto aoisNode = config.node(AOIS_XPATH, /*quiet =*/true);
    if (aoisNode.isNull()) {
        return {};
    }

    std::set<mwr::DBID> aoiIds;
    auto aoiNodes = aoisNode.nodes(AOI_NODE, /*quiet =*/true);
    for (size_t i = 0; i < aoiNodes.size(); ++i) {
        aoiIds.insert(aoiNodes[i].attr<mwr::DBID>(ID_ATTR));
    }

    if (aoiIds.empty()) {
        return {};
    }

    auto revisions = common::retryDuration([&] {
        auto txn = revisionPool.masterReadOnlyTransaction();

        auto branch = mwr::BranchManager(*txn).loadTrunk();
        mwr::RevisionsGateway gateway(*txn, branch);
        auto commitId = gateway.headCommitId();
        auto snapshot = gateway.snapshot(commitId);
        return snapshot.objectRevisions(aoiIds);
    });

    // Only ids present in `revisions` become AOIs below. Warn about the missing
    // ones so that a config mistake does not silently shrink the report.
    for (const auto aoiId : aoiIds) {
        if (!revisions.count(aoiId)) {
            WARN() << "Aoi " << aoiId << " from config not found in trunk";
        }
    }

    std::vector<da::Aoi> aois;
    for (const auto& [id, rev] : revisions) {
        REQUIRE(rev.data().attributes, "Object " << rev.id() << " has no attributes");
        REQUIRE(rev.data().attributes->count(attr::CAT_AOI), "Object " << rev.id() << " is not aoi");

        std::string name;
        auto it = rev.data().attributes->find(attr::AOI_NAME);
        if (it != rev.data().attributes->end()) {
            name = it->second;
        } else {
            name = std::to_string(id);
        }

        REQUIRE(rev.data().geometry, "Object " << rev.id() << " has no geometry");
        aois.emplace_back(da::Aoi{std::move(name), common::Geom(*rev.data().geometry)});
    }

    return aois;
}

/// Adjusts pgpool3 pool constants for this worker: turns off
/// timeoutEarlyOnMasterUnavailable so that connections try to wait for the
/// master to become available. All other constants are returned unchanged.
maps::pgpool3::PoolConstants overridePgPoolConstants(maps::pgpool3::PoolConstants constants)
{
    constants.timeoutEarlyOnMasterUnavailable = false;
    return constants;
}

/// Creates the pool holder used for view access.
///
/// Prefers the VIEW_STABLE_DB_ID database. If its config node is missing
/// (maps::xml3::NodeNotFound), logs a warning and falls back to the revision
/// database (REVISION_DB_ID / REVISION_POOL_ID). Both variants are created with
/// overridePgPoolConstants.
std::unique_ptr<common::PoolHolder> createViewStablePoolHolder(
    const common::ExtendedXmlDoc& config)
{
    const auto makeHolder = [&config](const std::string& dbId, const std::string& poolId) {
        return std::make_unique<common::PoolHolder>(
            overridePgPoolConstants, config, dbId, poolId);
    };

    try {
        return makeHolder(VIEW_STABLE_DB_ID, VIEW_STABLE_POOL_ID);
    } catch (const maps::xml3::NodeNotFound&) {
        WARN() << "Couldn't init " << VIEW_STABLE_DB_ID
            << " database for view access, trying " << REVISION_DB_ID;
    }

    return makeHolder(REVISION_DB_ID, REVISION_POOL_ID);
}

/// Chooses how many worker threads run the checks.
///
/// Starts from 3/4 of std::thread::hardware_concurrency(), using
/// DEFAULT_THREADS_COUNT when that comes out as zero. The value is then capped
/// by the slaveMaxSize of both pools (presumably so each worker can obtain a
/// slave connection) and is never less than one.
size_t calculateThreadsCount(
    maps::pgpool3::Pool& revisionPool,
    maps::pgpool3::Pool& viewStablePool)
{
    size_t threadsCount = std::thread::hardware_concurrency() * 3 / 4;
    if (!threadsCount) {
        // hardware_concurrency() returns 0 when unknown; a value of 1 also
        // rounds down to 0 here.
        threadsCount = DEFAULT_THREADS_COUNT;
    }
    threadsCount = std::min(
        threadsCount,
        revisionPool.state().constants.slaveMaxSize);
    threadsCount = std::min(
        threadsCount,
        viewStablePool.state().constants.slaveMaxSize);
    // A pool configured with slaveMaxSize == 0 would otherwise produce a
    // thread pool with no workers, which cannot run any queued check.
    return std::max<size_t>(threadsCount, 1);
}

void generateDiffAlertResults(
    const common::ExtendedXmlDoc& config,
    const cfg::ConfigHolder& editorConfig,
    const TaskData& taskData,
    std::function<bool()> checkCanceled,
    tasks::TaskPgLogger& taskLogger,
    maps::pgpool3::Pool& revisionPool,
    maps::pgpool3::Pool& viewStablePool,
    da::CsvWriter& csvWriter)
{
    da::DiffLabeler labeler(editorConfig);

    INFO() << "Task " << taskData.taskId << ": loading diff contexts";
    const auto snapshotsPair = getSnapshotsPair(taskData, revisionPool);

    taskLogger.logInfo() << "Old branch " << snapshotsPair.oldBranch.id();
    taskLogger.logInfo() << "New branch " << snapshotsPair.newBranch.id();

    da::CommitFilter commitFilter;

    //always exclude group move commits
    da::ActionTypeFlags excludedActionTypes{da::ActionType::GroupMove};
    if (!taskData.withImportedObjects) {
        excludedActionTypes.set(da::ActionType::Import);
        excludedActionTypes.set(da::ActionType::GroupEditAttributes);
        excludedActionTypes.set(da::ActionType::GroupDelete);
    }
    commitFilter.setExcludedActionTypes(excludedActionTypes);

    da::UserTypeFlags includedUserTypes;
    if (taskData.userFilter == COMMON_USER_FILTER) {
        includedUserTypes.set(da::UserType::Common);
    } else if (taskData.userFilter == COMMON_OR_OUTSOURCER_USER_FILTER) {
        includedUserTypes.set(da::UserType::Common);
        includedUserTypes.set(da::UserType::Outsourcer);
    } else {
        REQUIRE(taskData.userFilter == ALL_USER_FILTER,
                "Wrong userFilter parameter '" << taskData.userFilter << "'");
        includedUserTypes.flip();
    }
    commitFilter.setIncludedUserTypes(includedUserTypes);

    da::EditorConfig editorCategoriesConfig(editorConfig.doc());
    auto compareResult = da::LongtaskDiffContext::compareSnapshots(
            snapshotsPair.oldBranch, snapshotsPair.oldSnapshotId,
            snapshotsPair.newBranch, snapshotsPair.newSnapshotId,
            revisionPool,
            viewStablePool,
            editorCategoriesConfig,
            commitFilter);

    if (!compareResult.badObjects().empty()) {
        taskLogger.logError() << "Bad objects: "
            << common::join(compareResult.badObjects(), ", ");
    }
    if (!compareResult.badRelations().empty()) {
        taskLogger.logError() << "Bad relations: "
            << common::join(compareResult.badRelations(), ", ");
    }

    std::map<da::TId, const da::LongtaskDiffContext*> diffContextById;
    for (const auto& d : compareResult.diffContexts()) {
        diffContextById.insert({d.objectId(), &d});
    }

    std::atomic<std::size_t> counter{0};

    INFO() << "Task " << taskData.taskId << ": running checks";
    common::PoolHolder resultsPoolHolder(
        overridePgPoolConstants, config, RESULTS_DB_ID, RESULTS_POOL_ID);
    da::ResultsWriter dbWriter(
        taskData.taskId,
        resultsPoolHolder.pool(),
        getDiffalertRegions(revisionPool));

    auto calcObjectDiffEnvelopes = [&](const da::LongtaskDiffContext& diffContext) -> da::DiffEnvelopes {
        return common::retryDuration([&] {
            auto txn = common::getReadTransactionForCommit(
                revisionPool,
                snapshotsPair.oldBranch.id(),
                snapshotsPair.oldSnapshotId.commitId(),
                [](const std::string& msg) { INFO() << msg; });

            mwr::RevisionsGateway rg(*txn, snapshotsPair.oldBranch);
            auto oldSnapshot = rg.snapshot(snapshotsPair.oldSnapshotId);

            return da::calcObjectDiffEnvelopes(
                diffContext.objectId(), diffContextById, oldSnapshot);
        });
    };

    auto processDiff = [&](const da::LongtaskDiffContext& diffContext)
    {
        if (checkCanceled()) {
            throw TaskCanceledException();
        }

        REQUIRE(!dbWriter.failed(), "DB writer failed");

        auto count = counter++;
        if (count % 1000 == 0) {
            INFO() << "Diff context processed " << count << "/" << compareResult.diffContexts().size();
        }

        auto messages = da::runLongTaskChecks(diffContext);
        if (messages.empty()) {
            return;
        }

        auto diffEnvelopes = calcObjectDiffEnvelopes(diffContext);

        auto [objectLabel, hasOwnName] = labeler.nameForObject(diffContext);

        std::list<da::StoredMessage> messagesToStore;
        for (const auto& msg : messages) {
            auto envelope = msgEnvelope(msg, diffEnvelopes);
            if (!std::isnormal(envelope.getWidth())
                || !std::isnormal(envelope.getHeight())) {
                ERROR() << "Invalid envelope for object " << diffContext.objectId();
                envelope.setToNull();
            }
            messagesToStore.emplace_back(msg, diffContext.categoryId(), objectLabel, hasOwnName, envelope);
        }
        csvWriter.put(diffContext, messagesToStore);
        dbWriter.put(std::move(messagesToStore));
    };

    maps::wiki::Executor runnersExecutor;
    for (const auto& diffContext : compareResult.diffContexts()) {
        runnersExecutor.addTask([&]{ processDiff(diffContext); });
    }

    auto threadsCount = calculateThreadsCount(revisionPool, viewStablePool);
    INFO() << "Threads count " << threadsCount;

    maps::wiki::ThreadPool runnersPool(threadsCount);
    runnersExecutor.executeAllInThreads(runnersPool);
    runnersPool.shutdown();

    INFO() << "Task " << taskData.taskId << ": syncing results";
    dbWriter.finish();
}

} // namespace


/// Creates the worker and its database pool holders.
///
/// The revision pool always uses REVISION_DB_ID / REVISION_POOL_ID. The view
/// pool is built by createViewStablePoolHolder (the view-stable DB, falling
/// back to the revision DB). Both pools are configured through
/// overridePgPoolConstants.
/// NOTE: the pool holders are initialized from config_, not from `config`, so
/// config_ must be declared before them in worker.h (members are initialized
/// in declaration order).
Worker::Worker(
        const common::ExtendedXmlDoc& config,
        const configs::editor::ConfigHolder& editorConfig)
    : config_(config)
    , editorConfig_(editorConfig)
    , revisionPoolHolder_(overridePgPoolConstants, config_, REVISION_DB_ID, REVISION_POOL_ID)
    , viewStablePoolHolder_(createViewStablePoolHolder(config_))
{}

/// Runs one diffalert task end to end.
///
/// Writes the task parameters to the task log, creates the CSV writer for the
/// configured AOIs, generates the check results and publishes the CSV to MDS.
///
/// Returns:
/// - Status::Ok on success;
/// - Status::Canceled when a TaskCanceledException escapes;
/// - Status::Failed for any other maps::Exception or std::exception raised
///   inside the main try block.
/// Exceptions thrown before that block (for example while creating the task
/// logger) propagate to the caller.
Worker::Status Worker::run(
    const TaskData& taskData,
    std::function<bool()> checkCanceled)
{
    auto& revPool = revisionPoolHolder_.pool();
    auto& viewPool = viewStablePoolHolder_->pool();

    tasks::TaskPgLogger taskLogger(revPool, taskData.taskId);
    taskLogger.logInfo() << "Task started. Grinder task id: " << taskData.grinderTaskId;
    if (!taskData.ytOperationId.empty()) {
        taskLogger.logInfo() << "yt: " << taskData.ytOperationId;
    }
    taskLogger.logInfo() << "With imported objects " << taskData.withImportedObjects;
    taskLogger.logInfo() << "User filter " << taskData.userFilter;

    // Shared tail of both failure handlers: record the failure in the task log.
    auto markFailed = [&taskLogger] {
        taskLogger.logError() << "Task failed";
        return Status::Failed;
    };

    try {
        da::CsvWriter csvWriter(editorConfig_, getAoiRegions(config_, revPool));

        generateDiffAlertResults(
            config_, editorConfig_, taskData, checkCanceled,
            taskLogger, revPool, viewPool, csvWriter);

        csvWriter.publishToMds(taskData.taskId, config_, revPool, taskLogger);
        INFO() << "Task " << taskData.taskId << ": completed successfully";
        taskLogger.logInfo() << "Task finished";
        return Status::Ok;
    } catch (const TaskCanceledException&) {
        WARN() << "Task " << taskData.taskId << ": cancelled";
        taskLogger.logInfo() << "Task cancelled";
        return Status::Canceled;
    } catch (const maps::Exception& ex) {
        ERROR() << "Task " << taskData.taskId << " failed: " << ex;
        return markFailed();
    } catch (const std::exception& ex) {
        ERROR() << "Task " << taskData.taskId << " failed: " << ex.what();
        return markFailed();
    }
}

/// Constructs the TaskData for taskId from a read-only master transaction on
/// the revision pool. The read is retried via common::retryDuration.
TaskData Worker::loadTaskData(DBID taskId)
{
    auto readTask = [this, taskId] {
        auto txn = revisionPoolHolder_.pool().masterReadOnlyTransaction();
        return TaskData(*txn, taskId);
    };
    return common::retryDuration(readTask);
}

} // namespace maps::wiki::diffalert_worker
