#include "conflicts_checker.h"
#include "sql_strings.h"
#include "helpers.h"

#include <yandex/maps/wiki/revision/exception.h>

#include <boost/format.hpp>

namespace maps::wiki::revision {

using namespace helpers;

namespace {

// Conflict-check query for the trunk branch.
//   %1% - revision filter (makeCheckConflictsExpression substitutes the
//         commit_id-based filter built from the checked ids).
// Yields a single row: the number of matching revisions, and how many of them
// have next_commit_id > 0 (presumably: already superseded by a newer revision,
// which is what the caller treats as a conflict).
const boost::format CHECK_CONFLICTS_IN_TRUNK_BRANCH{
    "SELECT COUNT(1),"
    " SUM((next_commit_id > 0)::int)"
    " FROM revision.object_revision"
    " WHERE %1%"
};

// Conflict-check query for stable/archive branches.
//   %1% - revision filter for the checked revisions (makeCheckConflictsExpression
//         substitutes the commit_id-based filter);
//   %2% - stable branch id;
//   %3% - filter built on prev_commit_id from the same ids (presumably selects
//         revisions that were created on top of the checked ones).
// Yields a single row: the number of matching revisions, and how many of them
// belong to an object that also has such a revision in a non-trunk commit of
// the given stable branch.
// The boolean `object_id IN (...)` is parenthesized explicitly before the
// ::int cast (as in the trunk query) instead of relying on operator precedence.
const boost::format CHECK_CONFLICTS_IN_STABLE_OR_ARCHIVE_BRANCH (
    "SELECT COUNT(1), SUM((object_id IN ("
    "   SELECT object_id"
    "       FROM revision.object_revision"
    "       WHERE %3% AND commit_id IN ("
    "           SELECT id"
    "               FROM revision.commit"
    "               WHERE stable_branch_id = %2% AND NOT trunk)))::int)"
    " FROM revision.object_revision WHERE %1%"
);


std::string makeCheckConflictsExpression(
    const std::string& queryIdsStr,
    BranchType branchType,
    DBID branchId,
    const CommitIdToObjectIds& revisionIds)
{
    if (branchType == BranchType::Trunk) {
        boost::format query = CHECK_CONFLICTS_IN_TRUNK_BRANCH;
        query % queryIdsStr;
        return query.str();
    }
    boost::format query = CHECK_CONFLICTS_IN_STABLE_OR_ARCHIVE_BRANCH;
    query % queryIdsStr;
    query % branchId;
    query % QueryGenerator::buildSpecialFilter(revisionIds, sql::col::PREV_COMMIT_ID);
    return query.str();
}

} // namespace

ConflictsChecker::ConflictsChecker(Transaction& work)
    : work_(work)
{}

/**
 * Checks the given revisions for conflicts in the target branch.
 *
 * @param branchType type of the target branch; Approved is rejected.
 * @param branchId   target branch id (used by the query only for non-trunk branches).
 * @param ids        revisions to check, grouped as commit id -> object ids.
 *
 * @throws ConflictsFoundException if the check query reports at least one
 *         conflicting revision.
 * Fails a REQUIRE check if the branch is Approved, or if the number of
 * revisions found in the database differs from the number requested.
 */
void
ConflictsChecker::checkConflicts(
    BranchType branchType,
    DBID branchId,
    const CommitIdToObjectIds& ids) const
{
    REQUIRE(branchType != BranchType::Approved,
            "Conflicts checking is forbidden for approved branch");

    auto queryIdsStr = QueryGenerator::buildSpecialFilter(ids, sql::col::COMMIT_ID);

    auto query = makeCheckConflictsExpression(queryIdsStr, branchType, branchId, ids);
    auto result = work_.exec(query);
    // Both query variants are aggregates without GROUP BY: exactly one row.
    ASSERT(result.size() == 1);

    const auto& row = result[0];
    auto allCount = row[0].as<size_t>();
    auto errCount = row[1].as<size_t>(0); // SUM(NULL) -> 0
    if (errCount) {
        throw ConflictsFoundException() <<
            "found " << errCount << " conflicts, "
            "count: " << allCount << " ids: " << queryIdsStr;
    }

    const auto rowsCount = QueryGenerator::countRows(ids);
    // Counts are unsigned: check the "too many rows" case separately so the
    // difference reported below can never wrap around to a huge bogus value.
    REQUIRE(allCount <= rowsCount, "found " << allCount - rowsCount <<
            " unexpected extra revisions while checking conflicts, ids: " << queryIdsStr);
    REQUIRE(allCount == rowsCount, "found " << rowsCount - allCount <<
            " non-existent revisions while checking conflicts, ids: " << queryIdsStr);
}

} // namespace maps::wiki::revision
