#include "task_context.h"
#include "check_context_impl.h"
#include "common/exception.h"
#include "common/magic_strings.h"
#include "common/utils.h"

#include <yandex/maps/wiki/validator/result.h>
#include <yandex/maps/wiki/threadutils/scheduler.h>
#include <yandex/maps/wiki/common/string_utils.h>

#include <boost/none.hpp>
#include <vector>

namespace maps::wiki::validator {

namespace {

void runCheckPart(
    TCheckId checkId,
    TCheckPartId checkPartId,
    ICheckPartPtr runner,
    CheckContext& context)
{
    std::string checkPartNote;
    if (!checkPartId.empty()) {
        checkPartNote = "part " + checkPartId + " of ";
    }

    try {
        Timer timer;
        INFO() << "Starting " << checkPartNote << "check " << checkId;

        runner->run(&context);

        INFO() << checkPartNote << "check " << checkId << " finished in "
               << timer.elapsed().count() << "s.";
    } catch (...) {
        context.fatal("check-internal-error", boost::none, {});
        logException(log8::Level::ERROR,
            "while executing " + checkPartNote + "check " + checkId);
    }
}

/**
 * Returns every category a check part needs loaded: its declared
 * dependencies plus, transitively, the categories those depend on for
 * `loaderType`, as reported by DataSource::dependencies().
 *
 * Traversal is breadth-first, layer by layer. Each category is expanded
 * (its dependencies looked up) at most once.
 */
std::set<TCategoryId> categoriesFromCheckPart(
        const DataSource& dataSource,
        const CheckPartMeta& checkPart,
        LoaderType loaderType)
{
    std::set<TCategoryId> result;

    std::set<TCategoryId> current(
        checkPart.dependencies.begin(), checkPart.dependencies.end());
    while (!current.empty()) {
        std::set<TCategoryId> next;
        for (const auto& category : current) {
            // A category can be queued for the next layer and then reached
            // later in the current one; skip it if it was already expanded.
            if (!result.insert(category).second) {
                continue;
            }
            for (const auto& dep : dataSource.dependencies(category, loaderType)) {
                if (!result.count(dep)) {
                    next.insert(dep);
                }
            }
        }
        current.swap(next);
    }
    return result;
}

} // namespace

// Sets up everything needed to validate one revision of a branch.
//
// - dataSource_ is built from the validator config, cardinality, pool,
//   branch/commit ids, area of interest and the explicit object selection.
// - dbReadWorkers_ and categoryLoadWorkers_ are both constructed with
//   dataSource_.dbGateway().maxConnections() (presumably their pool size).
// - `checks` is copied; the checks are only scheduled later, by run().
// - canceledChecker, if set, is handed to the scheduler in runMasterThread().
//
// NOTE(review): the worker pools read dataSource_ during initialization.
// Members are initialized in class-declaration order, not in the order of
// this list, so dataSource_ must be declared before them in the header.
// Confirm before reordering.
TaskContext::TaskContext(
        const ValidatorConfig& validatorConfig,
        CheckCardinality checkCardinality,
        pgpool3::Pool& pgPool,
        DBID branchId,
        DBID commitId,
        AreaOfInterest aoi,
        const ObjectIdSet& objectIds,
        const std::vector<CheckMeta>& checks,
        size_t checkThreadsCount,
        CanceledChecker canceledChecker)
    : dataSource_(validatorConfig, checkCardinality, pgPool, branchId, commitId, std::move(aoi), objectIds)
    , dbReadWorkers_(dataSource_.dbGateway().maxConnections())
    , categoryLoadWorkers_(dataSource_.dbGateway().maxConnections())
    , checks_(checks)
    , checkThreadsCount_(checkThreadsCount)
    , canceledChecker_(std::move(canceledChecker))
{ }

// Waits for a master thread started by run() and not yet joined by
// popMessages(). Destroying a joinable std::thread calls std::terminate, so
// the join is mandatory; it blocks until all scheduled work has finished.
TaskContext::~TaskContext()
{
    if (!masterThread_.joinable()) {
        return;
    }
    masterThread_.join();
}

std::vector<TCheckId> TaskContext::checkIds() const
{
    std::vector<TCheckId> checkIds;
    checkIds.reserve(checks_.size());

    for (const auto& check : checks_) {
        checkIds.push_back(check.id);
    }

    return checkIds;
}

// Starts validation asynchronously on a dedicated master thread.
// Does nothing while a previously started master thread is still attached
// (i.e. not yet joined by popMessages() or the destructor).
void TaskContext::run()
{
    if (masterThread_.joinable()) {
        return;
    }
    masterThread_ = std::thread(&TaskContext::runMasterThread, this);
}

// Blocks until the master thread (if one was started) finishes, then drains
// the accumulated messages.
// If the master thread failed, its stored exception is rethrown, but only
// once no messages remain: any pending messages are returned first.
MessageBuffer TaskContext::popMessages()
{
    if (masterThread_.joinable()) {
        masterThread_.join();
    }

    auto messages = taskMessages_.popMessages();
    if (!messages.empty() || !lastError_) {
        return messages;
    }
    std::rethrow_exception(lastError_);
}

/**
 * Lazily registers, in a Scheduler, one task per category that loads that
 * category through TaskContext::runCategoryLoad(). Each load task depends
 * on the load tasks of the categories returned by
 * DataSource::dependencies() for the same loader type.
 *
 * Task ids are memoized, so every category is scheduled at most once.
 */
class TaskContext::CategoryLoadTasks
{
public:
    CategoryLoadTasks(
            TaskContext& taskContext,
            Scheduler& scheduler,
            LoaderType loaderType)
        : taskContext_(taskContext)
        , scheduler_(scheduler)
        , loaderType_(loaderType)
    {
    }

    /// Returns the load-task id for `category`, creating it (and the tasks
    /// of its transitive dependencies) on first request.
    /// Fails a REQUIRE check if the dependency graph contains a cycle.
    Scheduler::TTaskId taskId(const TCategoryId& category)
    {
        std::unordered_set<TCategoryId> pending;
        return taskIdInternal(category, pending);
    }

private:
    /// Recursive worker for taskId(). `pending` holds the categories on the
    /// current recursion path; meeting one of them again means a cycle.
    Scheduler::TTaskId taskIdInternal(
            const TCategoryId& category,
            std::unordered_set<TCategoryId>& pending)
    {
        auto taskIt = taskIds_.find(category);
        if (taskIt != taskIds_.end()) {
            return taskIt->second;
        }
        pending.insert(category);
        std::vector<Scheduler::TTaskId> dependencyTaskIds;
        for (const TCategoryId& dep :
                taskContext_.dataSource_.dependencies(category, loaderType_)) {
            REQUIRE(!pending.count(dep),
                    "cycle in dependencies for category " << category);
            dependencyTaskIds.push_back(taskIdInternal(dep, pending));
        }

        // The closures are stored in the scheduler, which may outlive this
        // object: runMasterThread() creates the Scheduler before it. So they
        // capture only the TaskContext and a copy of the loader type, never
        // `this`.
        TaskContext& taskContext = taskContext_;
        const LoaderType loaderType = loaderType_;
        auto newTaskId = scheduler_.addTask(
            [&taskContext, loaderType, category]() {
                taskContext.runCategoryLoad(category, loaderType);
            },
            [&taskContext](Scheduler::Runner runner) {
                taskContext.categoryLoadWorkers_.push(std::move(runner));
            },
            dependencyTaskIds);
        taskIds_.emplace(category, newTaskId);
        pending.erase(category);

        return newTaskId;
    }

private:
    TaskContext& taskContext_;
    Scheduler& scheduler_;
    LoaderType loaderType_;
    std::unordered_map<TCategoryId, Scheduler::TTaskId> taskIds_;
};

void TaskContext::runMasterThread()
{
    Timer timer;
    try {
        Scheduler scheduler;
        if (canceledChecker_) {
            scheduler.setIsCanceledChecker(canceledChecker_);
        }

        auto loaderType =
            dataSource_.aoi().empty()
            ? dataSource_.hasSelectedObjectIds()
                ? LoaderType::LoadBySelectedObjects
                : LoaderType::LoadAll
            : LoaderType::LoadFromAoi;

        CategoryLoadTasks categoryLoadTasks(*this, scheduler, loaderType);

        ThreadPool checkWorkers(checkThreadsCount_);

        struct CheckPartData {
            TCheckId id;
            const CheckPartMeta* checkPartMetaPtr;
            std::set<TCategoryId> categories;
        };
        std::map<TCategoryId, std::atomic<size_t>> categoryCounts;

        std::vector<CheckPartData> checks;
        for (const auto& check : checks_) {
            for (const auto& checkPart : check.parts) {
                auto categories = categoriesFromCheckPart(dataSource_, checkPart, loaderType);
                for (const auto& category : categories) {
                    ++(categoryCounts[category]);
                }
                checks.push_back(CheckPartData{
                    check.id, &checkPart, std::move(categories)});
            }
        }
        for (const auto& pair : categoryCounts) {
            INFO() << "Check category: " << pair.first << " " << pair.second;
        }

        for (const auto& check : checks) {
            const auto& checkPart = *check.checkPartMetaPtr;
            INFO() << "Check: " << check.id << " " << checkPart.id << " " << common::join(check.categories, ',');

            std::vector<Scheduler::TTaskId> dependencyTaskIds;
            for (const TCategoryId& dep : checkPart.dependencies) {
                dependencyTaskIds.push_back(categoryLoadTasks.taskId(dep));
            }

            scheduler.addTask(
                [this, check, &checkWorkers, &categoryCounts] {
                    const auto& checkPart = *check.checkPartMetaPtr;

                    CheckContext context(new CheckContext::Impl(
                        check.id, checkPart, dataSource_, taskMessages_, checkWorkers));

                    runCheckPart(check.id, checkPart.id, checkPart.runner, context);
                    for (const auto& category : check.categories) {
                        auto& cnt = categoryCounts.at(category);
                        auto prevCount = cnt.fetch_sub(1);
                        if (prevCount == 1) {
                            INFO() << "Unload category: " << category;
                            dataSource_.unload(category);
                        }
                    }
                },
                [&](Scheduler::Runner runner) {
                    checkWorkers.push(std::move(runner));
                },
                dependencyTaskIds);
        }

        scheduler.executeAll();

        INFO() << "All checks finished in " << timer.elapsed().count() << "s.";
    } catch (const ExecutionCanceled& e) {
        WARN() << "Validation canceled after " << timer.elapsed().count() << "s.";
        lastError_ = std::make_exception_ptr(ValidationCanceledException() << e.what());
    } catch (...) {
        lastError_ = std::current_exception();
        logException(log8::Level::ERROR, "while running master thread");
    }
}

// Loads one category through dataSource_, issuing its DB reads on
// dbReadWorkers_. The base-check messages returned by the load are appended
// to the task-wide message buffer.
void TaskContext::runCategoryLoad(TCategoryId category, LoaderType loaderType)
{
    auto loadMessages = dataSource_.load(category, loaderType, dbReadWorkers_);
    taskMessages_.addBaseCheckMessages(std::move(loadMessages));
}

} // namespace maps::wiki::validator
