#include "acquire.h"

#include "release.h"
#include <maps/wikimap/mapspro/libs/assessment/impl/units/load.h>
#include <maps/wikimap/mapspro/libs/assessment/impl/magic_strings.h>

namespace maps::wiki::assessment::samples {

namespace {

// Minimal projection of a SAMPLE_TASK row needed while acquiring work.
// Members are value-initialized so a default-constructed instance never
// carries indeterminate ids.
struct SampleTask
{
    // Value of the SAMPLE_TASK_ID column.
    TId id{};
    // Value of the UNIT_ID column.
    TId unitId{};
};

// Converts a result row into a SampleTask.
// The row must expose the SAMPLE_TASK_ID and UNIT_ID columns; both are read as TId
// (pqxx throws from `as<>` if a value cannot be converted).
SampleTask makeSampleTask(const pqxx::row& row)
{
    const TId taskId = row[sql::col::SAMPLE_TASK_ID].as<TId>();
    const TId unitId = row[sql::col::UNIT_ID].as<TId>();
    return SampleTask{taskId, unitId};
}

namespace sqlCondition {

// SQL predicate: the task has not been graded yet AND it is either not
// acquired at all or its acquisition is older than ACQUISITION_TIMEOUT_INTERVAL.
//
// The inner parentheses are essential: SQL's AND binds tighter than OR, so
// without them GRADE_ID IS NULL would only guard the "expired" branch. A
// graded row with a NULL ACQUIRED_AT would then be reported as free, and
// acquireTask (which requires GRADE_ID IS NULL) would fail on it.
//
// NOTE(review): this is a dynamically initialized namespace-scope object
// built from sql::col::* and ACQUISITION_TIMEOUT_INTERVAL. That is only safe
// if those are initialized before it (e.g. defined in the included header
// rather than in another translation unit). Confirm.
const std::string IS_FREE_TASK =
    "(" +
        sql::col::GRADE_ID + " IS NULL AND "
        "(" +
            sql::col::ACQUIRED_AT + " IS NULL OR " +
            sql::col::ACQUIRED_AT + " <= NOW() - " + ACQUISITION_TIMEOUT_INTERVAL +
        ")"
    ")";

// SQL predicate selecting tasks currently held by `acquiredBy`: not graded
// yet, acquired by that user, and acquired more recently than
// ACQUISITION_TIMEOUT_INTERVAL ago. The result is wrapped in parentheses so
// it can be AND-ed safely into a larger WHERE clause.
std::string isAcquiredBy(TUid acquiredBy)
{
    std::string condition = "(";
    condition += sql::col::GRADE_ID;
    condition += " IS NULL AND ";
    condition += sql::col::ACQUIRED_BY;
    condition += " = ";
    condition += std::to_string(acquiredBy);
    condition += " AND ";
    condition += sql::col::ACQUIRED_AT;
    condition += " > NOW() - ";
    condition += ACQUISITION_TIMEOUT_INTERVAL;
    condition += ")";
    return condition;
}

} // namespace sqlCondition

namespace sqlSubquery {

// Parenthesized sub-select returning the UNIT_ID of every unit that
// `skippedBy` has recorded in the UNIT_SKIP table. Intended for use as the
// right-hand side of `NOT IN`.
std::string unitsSkippedBy(TUid skippedBy)
{
    const std::string skipper = std::to_string(skippedBy);
    return "(SELECT " + sql::col::UNIT_ID +
        " FROM " + sql::table::UNIT_SKIP +
        " WHERE " + sql::col::SKIPPED_BY + " = " + skipper + ")";
}

// Parenthesized sub-select returning the UNIT_ID of every task in sample
// `sampleId` that was acquired by `completedBy` and already has a grade.
// Intended for use as the right-hand side of `NOT IN`.
std::string sampleUnitsCompletedBy(TId sampleId, TUid completedBy)
{
    std::string where = sql::col::SAMPLE_ID + " = " + std::to_string(sampleId);
    where += " AND " + sql::col::ACQUIRED_BY + " = " + std::to_string(completedBy);
    where += " AND " + sql::col::GRADE_ID + " IS NOT NULL";

    return "(SELECT " + sql::col::UNIT_ID +
        " FROM " + sql::table::SAMPLE_TASK +
        " WHERE " + where + ")";
}

} // namespace sqlSubquery

// Looks up the task in `sampleId` that `acquiredBy` currently holds (see
// sqlCondition::isAcquiredBy) and row-locks it with FOR UPDATE.
// Returns std::nullopt when the user holds nothing in this sample.
// REQUIRE fails if more than one held task is found.
std::optional<SampleTask> selectAcquiredTaskForUpdate(
    pqxx::transaction_base& txn,
    TId sampleId,
    TUid acquiredBy)
{
    // LIMIT 2 rather than 1: a second row would reveal a broken invariant,
    // which the REQUIRE below then reports instead of silently hiding it.
    std::string query = "SELECT " + sql::col::SAMPLE_TASK_ID + ", " + sql::col::UNIT_ID;
    query += " FROM " + sql::table::SAMPLE_TASK;
    query += " WHERE " + sql::col::SAMPLE_ID + " = " + std::to_string(sampleId);
    query += " AND " + sqlCondition::isAcquiredBy(acquiredBy);
    query += " LIMIT 2 FOR UPDATE";

    const auto rows = txn.exec(query);

    REQUIRE(
        rows.size() <= 1,
        "Selected " << rows.size() << " tasks "
        "acquired in sample " << sampleId << " by " << acquiredBy);

    if (!rows.empty()) {
        return makeSampleTask(rows[0]);
    }
    return std::nullopt;
}

// Picks one free task (see sqlCondition::IS_FREE_TASK) in `sampleId` for
// `acquireBy`. It skips units the user has skipped and units the user has
// already graded in this sample.
// The row is locked with FOR UPDATE SKIP LOCKED, so rows locked by concurrent
// transactions are passed over rather than waited on.
// Returns std::nullopt when no such task exists.
std::optional<SampleTask> selectFreeTaskForUpdate(
    pqxx::transaction_base& txn,
    TId sampleId,
    TUid acquireBy)
{
    std::string query = "SELECT " + sql::col::SAMPLE_TASK_ID + ", " + sql::col::UNIT_ID;
    query += " FROM " + sql::table::SAMPLE_TASK;
    query += " WHERE " + sql::col::SAMPLE_ID + " = " + std::to_string(sampleId);
    query += " AND " + sql::col::UNIT_ID + " NOT IN " + sqlSubquery::unitsSkippedBy(acquireBy);
    query += " AND " + sql::col::UNIT_ID + " NOT IN " +
        sqlSubquery::sampleUnitsCompletedBy(sampleId, acquireBy);
    query += " AND " + sqlCondition::IS_FREE_TASK;
    query += " LIMIT 1 FOR UPDATE SKIP LOCKED";

    const auto rows = txn.exec(query);

    REQUIRE(
        rows.size() <= 1,
        "Selected " << rows.size() << " free tasks "
        "from sample " << sampleId << " for " << acquireBy);

    if (!rows.empty()) {
        return makeSampleTask(rows[0]);
    }
    return std::nullopt;
}

// Records in UNIT_SKIP that `skipBy` skipped unit `unitId`.
//
// Skipping an already-skipped unit is a no-op. The INSERT uses
// ON CONFLICT (UNIT_ID, SKIPPED_BY) DO NOTHING, so 0 affected rows is a
// legitimate outcome. This can happen when the same unit belongs to several
// samples and the user skipped it elsewhere while still holding it here.
// Only more than one inserted row is treated as an error (REQUIRE).
void skipUnit(
    pqxx::transaction_base& txn,
    TId unitId,
    TUid skipBy)
{
    const auto result = txn.exec(
        "INSERT INTO " +
            sql::table::UNIT_SKIP + " "
        "(" +
            sql::col::UNIT_ID + ", " +
            sql::col::SKIPPED_BY +
        ") "
        "VALUES "
        "(" +
            std::to_string(unitId) + ", " +
            std::to_string(skipBy) + " "
        ")"
        "ON CONFLICT (" + sql::col::UNIT_ID + ", " + sql::col::SKIPPED_BY + ") "
        "DO NOTHING");

    // 0 rows: the pair already existed and the conflict was deliberately
    // ignored above, so it must not be an error here either.
    REQUIRE(
        result.affected_rows() <= 1,
        "Skipped " << result.affected_rows() << " tasks "
        "by unitId = " << unitId << " for " << skipBy);
}

// Marks sample task `taskId` as acquired by `acquireBy`, stamping
// ACQUIRED_AT with NOW().
// Only ungraded tasks (GRADE_ID IS NULL) are updated. REQUIRE fails unless
// exactly one row was changed, e.g. for an unknown id or an already graded task.
void acquireTask(
    pqxx::transaction_base& txn,
    TId taskId,
    TUid acquireBy)
{
    const auto result = txn.exec(
        "UPDATE " +
            sql::table::SAMPLE_TASK + " "
        "SET " +
            sql::col::ACQUIRED_BY + " = " + std::to_string(acquireBy) + ", " +
            sql::col::ACQUIRED_AT + " = NOW() "
        "WHERE " +
            sql::col::SAMPLE_TASK_ID + " = " + std::to_string(taskId) + " AND " +
            sql::col::GRADE_ID + " IS NULL");

    REQUIRE(
        result.affected_rows() == 1,
        "Acquired " << result.affected_rows() << " tasks for sampleTaskId " << taskId
            << " by " << acquireBy);
}

} // namespace

// Gives `acquireBy` a unit to assess from sample `sampleId`.
//
// If the user already holds a task in this sample, that task is kept (and
// re-stamped by acquireTask). The exception is SkipAcquired::Yes: the held
// unit is then recorded as skipped, the holding is released through
// samples::release, and a free task is looked up instead. Returns
// std::nullopt when nothing can be handed out; otherwise the unit loaded via
// units::loadById.
std::optional<Unit> acquire(
    pqxx::transaction_base& txn,
    TId sampleId,
    TUid acquireBy,
    SkipAcquired skipAcquired)
{
    std::optional<SampleTask> task = selectAcquiredTaskForUpdate(txn, sampleId, acquireBy);

    const bool dropCurrent = task.has_value() && skipAcquired == SkipAcquired::Yes;
    if (dropCurrent) {
        // Recording the skip makes selectFreeTaskForUpdate exclude this unit
        // (its NOT IN unitsSkippedBy filter) later in the same transaction.
        skipUnit(txn, task->unitId, acquireBy);
        samples::release(txn, sampleId, acquireBy);
        task.reset();
    }

    if (!task.has_value()) {
        task = selectFreeTaskForUpdate(txn, sampleId, acquireBy);
        if (!task.has_value()) {
            return std::nullopt;
        }
    }

    acquireTask(txn, task->id, acquireBy);
    return units::loadById(txn, task->unitId);
}

} // maps::wiki::assessment::samples
