#include "copy_schema.h"
#include "sync_table.h"
#include <maps/wikimap/mapspro/services/editor/src/sync/db_helpers.h>
#include <maps/wikimap/mapspro/services/editor/src/sync/lock_helpers.h>
#include <maps/wikimap/mapspro/services/editor/src/common.h>
#include <maps/wikimap/mapspro/services/editor/src/branch_helpers.h>

#include <yandex/maps/wiki/common/pg_retry_helpers.h>
#include <yandex/maps/wiki/common/retry_duration.h>
#include <yandex/maps/wiki/threadutils/scheduler.h>
#include <yandex/maps/wiki/threadutils/threadpool.h>

namespace maps::wiki::tool {

namespace {

// Upper bound on the number of worker threads: running too many copy tasks
// at once causes intense xlog growth in the database.
// See https://st.yandex-team.ru/NMAPS-7921
constexpr size_t MAX_THREADS_COUNT_LIMIT = 10;

// Temporary schema the branch data is copied into before it is renamed to
// the target branch schema.
// NOTE: SCHEMANAME_TMP must stay declared before the two statements below,
// which are built from it during static initialization.
const std::string SCHEMANAME_TMP{"vrevisions_tmp"};
const std::string DROP_SCHEMANAME_TMP =
    std::string{"DROP SCHEMA IF EXISTS "} + SCHEMANAME_TMP + " CASCADE;";
// Incomplete statement: the new schema name is appended at the call site.
const std::string RENAME_SCHEMANAME_TMP_TO =
    std::string{"ALTER SCHEMA "} + SCHEMANAME_TMP + " RENAME TO ";

// Labels of the database pools taking part in the copy; also used as
// prefixes in task and log names.
const std::string POOL_VIEW_NAME{"view"};
const std::string POOL_LABELS_NAME{"labels"};

// Empty placeholders passed to execCommitWithRetries: no explicit search
// path, and no plain query when the work is done in a callback instead
// (presumably empty means "use the default" - verify against the helper).
const std::string DEFAULT_SEARCH_PATH{};
const std::string NO_QUERY{};

// Returns the names of the tables in the vrevisions schema of `branchId`,
// ordered by pg_stat_user_tables.n_live_tup, largest first (presumably so
// that the heaviest tables get their copy tasks scheduled first - verify).
//
// Tables whose name matches an inheritance parent (pg_inherits.inhparent)
// are skipped. NOTE(review): that subquery is not restricted to the branch
// schema, so a same-named parent table in any schema also excludes a table.
//
// The read runs in a master read-only transaction wrapped in
// common::retryDuration.
StringVec
loadSyncTableNames(pgpool3::Pool& pgPool, TBranchId branchId)
{
    std::string sql =
        "SELECT relname, n_live_tup FROM pg_stat_user_tables WHERE schemaname='";
    sql += vrevisionsSchemaName(branchId);
    sql += "' AND relname NOT IN"
           " (SELECT relname FROM pg_inherits JOIN pg_class ON inhparent=oid)"
           " ORDER BY 2 DESC";

    auto result = common::retryDuration([&] {
        auto txn = pgPool.masterReadOnlyTransaction();
        return txn->exec(sql);
    });

    // Only the name column is kept; n_live_tup is selected for ordering only.
    StringVec names;
    for (const auto& tableRow : result) {
        names.emplace_back(tableRow[0].c_str());
    }
    return names;
}

// Schedules the copy of every table in the vrevisions schema of
// `sourceBranchId` into SCHEMANAME_TMP, for a single database pool.
//
// Per table the task chain is:
//   drop indexes -> copy data -> one task per index (independent of each
//   other) -> primary key (after all index tasks, or after the data task
//   when there are no index tasks).
//
// pgPool          - pool every query of this pool's tables runs against
// pgPoolName      - pool label, used as a prefix in task names
// executionState  - shared state handed to every scheduled task
// scheduler       - receives the tasks; nothing is executed here
// sourceBranchId  - branch whose schema is copied
// executor        - executor the scheduled tasks are run with
void fillCopyTasks(
    pgpool3::Pool& pgPool,
    const std::string& pgPoolName,
    const ExecutionStatePtr& executionState,
    Scheduler& scheduler,
    TBranchId sourceBranchId,
    Scheduler::Executor executor)
{
    auto tableNames = loadSyncTableNames(pgPool, sourceBranchId);

    auto sourceSchemaName = vrevisionsSchemaName(sourceBranchId);
    const auto& targetSchemaName = SCHEMANAME_TMP;

    for (const auto& tableName : tableNames) {
        auto queries = buildSyncTableQueries(
            pgPool, sourceSchemaName, targetSchemaName, tableName);

        auto pgPoolTableName = pgPoolName + ": " + tableName;

        auto dropTaskId = scheduler.addTask(
            SyncTableQuery(
                pgPool, executionState, pgPoolTableName + " (drop indexes)",
                queries.searchPath, queries.dropIndexes),
            executor,
            {});

        auto dataTaskId = scheduler.addTask(
            SyncTableQuery(
                pgPool, executionState, pgPoolTableName + " (data)",
                queries.searchPath, queries.createData, SyncTable::DisplayRows::Yes),
            executor,
            {dropTaskId});

        // Index tasks only depend on the data load, not on each other.
        std::vector<Scheduler::TTaskId> indexTaskIds;
        indexTaskIds.reserve(queries.indexName2definition.size());
        for (const auto& [indexName, indexDefinition] : queries.indexName2definition) {
            indexTaskIds.push_back(scheduler.addTask(
                SyncTableQuery(
                    pgPool, executionState,
                    pgPoolTableName + " (index " + indexName + ")",
                    queries.searchPath, indexDefinition),
                executor,
                {dataTaskId}));
        }

        if (!queries.addPrimaryKey.empty()) {
            // The primary key waits for all index tasks. With no index tasks
            // the dependency list would be empty and the scheduler could run
            // ADD PRIMARY KEY before (or alongside) the drop/data tasks, so
            // fall back to depending on the data task.
            if (indexTaskIds.empty()) {
                indexTaskIds.push_back(dataTaskId);
            }
            scheduler.addTask(
                SyncTableQuery(
                    pgPool, executionState, pgPoolTableName + " (primary key)",
                    queries.searchPath, queries.addPrimaryKey),
                executor,
                indexTaskIds);
        }
    }
}

} // namespace

// Copies the vrevisions data of `sourceBranch` into `targetBranch`.
//
// Both branches must differ and neither may be trunk. Work is spread over at
// most MAX_THREADS_COUNT_LIMIT worker threads:
//  - revision metadata is copied through the core pool;
//  - every table is copied into SCHEMANAME_TMP of the view pool, and also of
//    the labels pool when sync::isSplittedViewLabels(target) holds.
//
// If any task failed (executionState is not ok), the temporary schema is
// dropped in every pool and the function returns without touching the branch
// state. Otherwise, in each pool, the temporary schema is renamed to the
// target branch schema. In the view pool this is preceded by an exclusive
// view lock and a call to
// vrevisions_stable.merge_approved_to_stable_branch(target). Finally the
// target branch is switched to BranchState::Normal.
void copyDataFromBranch(
    const ExecutionStatePtr& executionState,
    const revision::Branch& sourceBranch,
    revision::Branch& targetBranch,
    size_t threadCount)
{
    const auto fromBranchId = sourceBranch.id();
    const auto toBranchId = targetBranch.id();
    ASSERT(toBranchId != fromBranchId);
    ASSERT(fromBranchId != revision::TRUNK_BRANCH_ID);
    ASSERT(toBranchId != revision::TRUNK_BRANCH_ID);

    const size_t workerCount = threadCount < MAX_THREADS_COUNT_LIMIT
        ? threadCount
        : MAX_THREADS_COUNT_LIMIT;
    INFO() << "Copy data from branch "
           << fromBranchId << " into " << toBranchId
           << " threads: " << workerCount;

    ThreadPool workers(workerCount);
    Scheduler scheduler;

    Scheduler::Executor executor =
        [&workers](Scheduler::Runner runner) { workers.push(runner); };

    scheduler.addTask(
        SyncTableRevisionMeta(
            cfg()->poolCore(), executionState, fromBranchId, toBranchId),
        executor,
        {});

    // Pools whose vrevisions data is copied, keyed by label.
    std::map<std::string, pgpool3::Pool*> dataPools{
        {POOL_VIEW_NAME, &cfg()->poolViewStable()}};
    if (sync::isSplittedViewLabels(toBranchId)) {
        dataPools.emplace(POOL_LABELS_NAME, &cfg()->poolLabelsStable());
    }

    for (const auto& [label, pool] : dataPools) {
        fillCopyTasks(*pool, label, executionState, scheduler, fromBranchId, executor);
    }

    scheduler.executeAll();
    workers.shutdown();

    if (!executionState->isOk()) {
        // Failure: discard the partial copy, leave the branch state as is.
        for (const auto& [label, pool] : dataPools) {
            common::execCommitWithRetries(
                *pool,
                label + " drop " + SCHEMANAME_TMP,
                DEFAULT_SEARCH_PATH,
                DROP_SCHEMANAME_TMP);
        }
        return;
    }

    const auto targetSchemaName = vrevisionsSchemaName(toBranchId);
    const std::string mergeQuery =
        "SELECT vrevisions_stable.merge_approved_to_stable_branch("
        + std::to_string(toBranchId) + ")";

    for (const auto& [label, pool] : dataPools) {
        // C++17 lambdas cannot capture structured bindings, hence the copy.
        common::execCommitWithRetries(
            *pool,
            label + " rename " + SCHEMANAME_TMP,
            DEFAULT_SEARCH_PATH,
            NO_QUERY,
            [&, poolLabel = label](pqxx::transaction_base& txn) {
                if (poolLabel == POOL_VIEW_NAME) {
                    sync::lockViews(txn, sync::LockType::Exclusive);
                    txn.exec(mergeQuery);
                }
                txn.exec(RENAME_SCHEMANAME_TMP_TO + targetSchemaName);
                INFO() << "renamed " << poolLabel << ": " << SCHEMANAME_TMP << " to " << targetSchemaName;
            });
    }

    common::execCommitWithRetries(
        cfg()->poolCore(),
        "set normal state for branch: " + std::to_string(toBranchId),
        DEFAULT_SEARCH_PATH,
        NO_QUERY,
        [&targetBranch](pqxx::transaction_base& txn) {
            targetBranch.setState(txn, revision::BranchState::Normal);
        });
}

} // namespace maps::wiki::tool

