#include "load.h"
#include "status.h"
#include "grade_status.h"

#include <maps/wikimap/mapspro/libs/assessment/impl/sql_helpers.h>
#include <maps/wikimap/mapspro/libs/assessment/impl/magic_strings.h>

#include <yandex/maps/wiki/common/string_utils.h>

#include <boost/lexical_cast.hpp>
#include <fmt/format.h>

namespace maps::wiki::assessment::units {

using namespace fmt::literals;

namespace {

// Builds a Unit from one result row containing the unit columns
// (id, entity id/domain, action name/author/time).
Unit makeUnit(const pqxx::row& row)
{
    // Columns are read in the same order the aggregate below is initialised.
    const auto unitId = row[sql::col::UNIT_ID].as<TId>();
    const char* const entityId = row[sql::col::ENTITY_ID].c_str();
    const auto entityDomain =
        enum_io::fromString<Entity::Domain>(row[sql::col::ENTITY_DOMAIN].c_str());
    const char* const actionName = row[sql::col::ACTION].c_str();
    const auto actionBy = row[sql::col::ACTION_BY].as<TUid>();
    const auto actionAt = chrono::parseSqlDateTime(row[sql::col::ACTION_AT].c_str());

    return {
        .id = unitId,
        .entity = {entityId, entityDomain},
        .action {actionName, actionBy, actionAt}
    };
}

// Parses a one-dimensional PostgreSQL array field into a vector of optional
// values; SQL NULL elements become std::nullopt, other elements are converted
// with boost::lexical_cast<T>. If `requiredSize` is given, the element count
// is asserted to match it. Nested arrays and any other parser token throw.
template<typename T>
std::vector<std::optional<T>> parseArrayNullable(
    const pqxx::field& field,
    std::optional<size_t> requiredSize = {})
{
    using juncture = pqxx::array_parser::juncture;

    std::vector<std::optional<T>> result;
    auto parser = field.as_array();

    // The parser may finish immediately (no array at all): nothing to collect.
    auto elem = parser.get_next();
    if (elem.first == juncture::done) {
        return result;
    }
    ASSERT(elem.first == juncture::row_start);

    // Consume elements until the closing bracket of the (single) row.
    while ((elem = parser.get_next()).first != juncture::row_end) {
        switch (elem.first) {
            case juncture::string_value:
                result.emplace_back(boost::lexical_cast<T>(elem.second));
                break;
            case juncture::null_value:
                result.emplace_back(std::nullopt);
                break;
            default:
                throw RuntimeError() << "Unexpected juncture " << elem.first << " in " << field.c_str();
        }
    }

    // Nothing but the end marker may follow the row.
    elem = parser.get_next();
    ASSERT(elem.first == juncture::done);

    ASSERT(!requiredSize || result.size() == *requiredSize);
    return result;
}

// Same as parseArrayNullable, but every element is asserted to be non-NULL
// and the values are returned unwrapped.
template<typename T>
std::vector<T> parseArray(const pqxx::field& field, std::optional<size_t> requiredSize = {})
{
    const auto nullable = parseArrayNullable<T>(field, requiredSize);

    std::vector<T> values;
    for (size_t i = 0; i < nullable.size(); ++i) {
        const auto& element = nullable[i];
        ASSERT(element);
        values.push_back(*element);
    }
    return values;
}

// Rebuilds the grades of one unit from the parallel ARRAY_AGG columns of a row
// (ids, authors, timestamps, values, comments, qualifications). All arrays are
// asserted to have the same length as the id array.
GradeVec makeGradeVec(const pqxx::row& row)
{
    const auto ids = parseArrayNullable<TId>(row[sql::col::GRADE_ID]);

    // An aggregate consisting of a single NULL id is treated as "no grades".
    if (ids.size() == 1 && !ids.front()) {
        return {};
    }

    const auto count = ids.size();
    const auto gradedBy = parseArray<TUid>(row[sql::col::GRADED_BY], count);
    const auto gradedAt = parseArray<std::string>(row[sql::col::GRADED_AT], count);
    const auto values = parseArray<Grade::Value>(row[sql::col::VALUE], count);
    const auto comments = parseArrayNullable<std::string>(row[sql::col::COMMENT], count);
    const auto qualifications = parseArray<Qualification>(row[sql::col::QUALIFICATION], count);

    GradeVec grades;
    grades.reserve(count);
    for (size_t i = 0; i != count; ++i) {
        grades.push_back(Grade{
            *ids[i],
            gradedBy[i],
            chrono::parseSqlDateTime(gradedAt[i]),
            values[i],
            comments[i],
            qualifications[i]});
    }
    return grades;
}

// Combines the unit, its grades and the per-row permission flags
// (`fixable`, `refutation_acceptable` columns) into a GradedUnit.
GradedUnit makeGradedUnit(const pqxx::row& row)
{
    const bool isFixable = row["fixable"].as<bool>();
    const bool isRefutationAcceptable = row["refutation_acceptable"].as<bool>();

    UnitPermissionFlags permissions;
    if (isFixable) {
        permissions.set(UnitPermission::Fixable);
    }
    if (isRefutationAcceptable) {
        permissions.set(UnitPermission::RefutationAcceptable);
    }

    return {makeUnit(row), makeGradeVec(row), permissions};
}

auto makeGradedUnitVec(const pqxx::result& rows)
{
    GradedUnitVec result;
    result.reserve(rows.size());

    for (const auto& row : rows) {
        result.emplace_back(makeGradedUnit(row));
    }
    return result;
}

} // namespace

namespace sqlClause {

namespace {

const std::string UNIT_COLUMNS =
    sql::col::UNIT_ID + ", " +
    sql::col::ENTITY_ID + ", " +
    sql::col::ENTITY_DOMAIN + ", " +
    sql::col::ACTION_BY + ", " +
    sql::col::ACTION_AT + ", " +
    sql::col::ACTION;

// Returns `ARRAY_AGG(<column> ORDER BY <grade_id> DESC) as <column>`, so that
// all aggregated grade arrays of a unit share the same element order.
std::string arrayAggOrderedByGradeId(const std::string& column)
{
    std::string expression = "ARRAY_AGG(";
    expression += column;
    expression += " ORDER BY ";
    expression += sql::col::GRADE_ID;
    expression += " DESC) as ";
    expression += column;
    return expression;
}

const std::string AGGREGATED_GRADE_COLUMNS =
    sqlClause::arrayAggOrderedByGradeId(sql::col::GRADE_ID) + ", " +
    sqlClause::arrayAggOrderedByGradeId(sql::col::GRADED_BY) + ", " +
    sqlClause::arrayAggOrderedByGradeId(sql::col::GRADED_AT) + ", " +
    sqlClause::arrayAggOrderedByGradeId(sql::col::VALUE) + ", " +
    sqlClause::arrayAggOrderedByGradeId(sql::col::COMMENT) + ", " +
    sqlClause::arrayAggOrderedByGradeId(sql::col::QUALIFICATION);

// NMAPS-14927 Optimize unit feed
//
// With `graded-by={uid}` only units graded by `{uid}` are selected. To make use of
// the `(graded_by, unit_id)` index this has to be expressed as `WHERE graded_by = {uid}`.
// Feed/filter conditions are therefore applied in a subquery that only picks the
// page of unit ids; applying them to the outer query would also drop the unit's
// other grades from the result. The outer query then aggregates all grades of the
// selected units that are visible to the requester (`where_grades_visible`), so
// grades hidden by the visibility mode are not returned.
//
// The template is formatted twice: `{name}` placeholders are substituted here,
// while `{{name}}` ones survive as `{name}` for the per-request fmt::format call.
const std::string SELECT_GRADED_UNITS_FMT = fmt::format(
    "SELECT "
        "{unit_columns}, "
        "{aggregated_grade_columns}, "
        "{{fixable}} as fixable, "
        "BOOL_OR({{refutation_acceptable}}) as refutation_acceptable "
    "FROM {unit} JOIN {grade} USING ({unit_id}) "
    "WHERE "
        "({{where_grades_visible}}) AND "
        "{unit_id} IN ("
            "SELECT "
                "{unit_id} "
            "FROM {unit} "
            "JOIN {grade} USING ({unit_id}) "
            "WHERE "
                "({{where_feed_params}}) AND "
                "({{where_filter}}) AND "
                "({{where_grades_visible}}) "
            "GROUP BY {unit_id} "
            "ORDER BY {unit_id} {{order}} "
            "LIMIT {{limit}}"
        ") "
    "GROUP BY {unit_id} "
    "ORDER BY {unit_id} {{order}}",

    "unit_id"_a=sql::col::UNIT_ID,
    "unit_columns"_a=UNIT_COLUMNS,
    "aggregated_grade_columns"_a=AGGREGATED_GRADE_COLUMNS,

    "unit"_a=sql::table::UNIT,
    "grade"_a=sql::table::GRADE);

// Runtime fmt template selecting `{columns}` of at most two units matching
// `{where}`. Fetching two rows lets callers detect an ambiguous match.
// The space after `{where}` is required: without it the condition is glued to
// `LIMIT` (e.g. `unit_id = 42LIMIT 2`), which PostgreSQL 15+ rejects.
const std::string SELECT_UP_TO_TWO_UNITS_FMT =
    "SELECT "
        "{columns} "
    "FROM " + sql::table::UNIT + " "
    "WHERE {where} "
    "LIMIT 2";

// Paging condition on unit id. The feed is in newest-first (descending id)
// order, so `before` selects larger ids and `after` selects smaller ones;
// `before` takes precedence, and with neither set everything matches.
std::string where(const UnitFeedParams& params)
{
    const auto cursor = params.before() ? params.before() : params.after();
    if (!cursor) {
        return "TRUE";
    }
    const char* const comparison = params.before() ? " > " : " < ";
    return "(" + sql::col::UNIT_ID + comparison + std::to_string(cursor) + ")";
}

// Translates a UnitFilter into a SQL conjunction: `TRUE` followed by one
// ` AND <condition>` per set filter field.
std::string where(pqxx::transaction_base& txn, const UnitFilter& filter)
{
    std::stringstream query;
    query << "TRUE";

    // Opens the next conjunct and returns the stream to complete it.
    const auto conjunct = [&query]() -> auto& { return query << " AND "; };

    if (filter.entityDomain()) {
        conjunct() << sql::col::ENTITY_DOMAIN << " = " << sqlEnumToString(txn, *filter.entityDomain());
    }
    if (filter.entityIds()) {
        REQUIRE(!filter.entityIds()->empty(), "Empty set of entityIds is requested");
        const auto quotedIds = common::join(
            *filter.entityIds(),
            [&txn](const auto& id) { return txn.quote(id); },
            ", ");
        conjunct() << sql::col::ENTITY_ID << " IN (" << quotedIds << ")";
    }
    if (filter.actionBy()) {
        conjunct() << sql::col::ACTION_BY << " = " << *filter.actionBy();
    }
    if (filter.gradedBy()) {
        conjunct() << sql::col::GRADED_BY << " = " << *filter.gradedBy();
    }
    if (filter.gradeValue()) {
        conjunct() << sql::col::VALUE << " = " << sqlEnumToString(txn, *filter.gradeValue());
    }
    if (filter.gradeConfirmed()) {
        // Only non-expert grades; `true` keeps those equal to the last expert
        // value, `false` keeps those that differ from it.
        const auto comparison = *filter.gradeConfirmed() ? " = " : " != ";
        conjunct() << sql::col::QUALIFICATION << " != " << sqlEnumToString(txn, Qualification::Expert);
        conjunct() << sql::col::VALUE << comparison << sql::col::LAST_EXPERT_VALUE;
    }
    if (filter.gradeStatus()) {
        conjunct() << gradeStatusEquals(*filter.gradeStatus());
    }
    if (filter.status()) {
        conjunct() << statusEquals(*filter.status());
    }

    return query.str();
}

// SQL condition selecting the grades visible to `uid` under `visibility`:
// all grades, only grades given by `uid`, or grades given by `uid` plus grades
// on units whose action was performed by `uid`.
// Throws RuntimeError for a value outside the known enumerators.
std::string where(TUid uid, GradesVisibility visibility)
{
    switch (visibility) {
    case GradesVisibility::All:
        return "TRUE";

    case GradesVisibility::My:
        return sql::col::GRADED_BY + " = " + std::to_string(uid);

    case GradesVisibility::MyOrReceivedByMe:
        return "(" +
            sql::col::GRADED_BY + " = " + std::to_string(uid) + " OR " +
            sql::col::ACTION_BY + " = " + std::to_string(uid) +
        ")";
    }
    // Reachable only with an out-of-range enum value; previously control fell
    // off the end of this non-void function, which is undefined behavior.
    throw RuntimeError() << "Unexpected grades visibility " << static_cast<int>(visibility);
}

// SQL condition matching the unit of `entity` and `action`: entity id and
// domain, action author and name; feedback units are additionally matched on
// the action time.
std::string unitMatches(pqxx::transaction_base& txn, const Entity& entity, const Action& action)
{
    std::stringstream query;
    query << sql::col::ENTITY_ID << " = " << txn.quote(entity.id);
    query << " AND " << sql::col::ENTITY_DOMAIN << " = " << sqlEnumToString(txn, entity.domain);
    query << " AND " << sql::col::ACTION_BY << " = " << action.by;
    query << " AND " << sql::col::ACTION << " = " << txn.quote(action.name);

    const bool isFeedback = entity.domain == Entity::Domain::Feedback;
    if (isFeedback) {
        query << " AND " << sql::col::ACTION_AT << " = " << txn.quote(chrono::formatSqlDateTime(action.at));
    }
    return query.str();
}

} // namespace

} // namespace sqlClause

// Looks up the id of the unit matching `entity` and `action`.
// Returns std::nullopt if there is none; throws if more than one unit matches.
std::optional<TId> get(pqxx::transaction_base& txn, const Entity& entity, const Action& action)
{
    const auto query = fmt::format(
        sqlClause::SELECT_UP_TO_TWO_UNITS_FMT,
        "columns"_a = sql::col::UNIT_ID,
        "where"_a = sqlClause::unitMatches(txn, entity, action));
    const auto result = txn.exec(query);

    REQUIRE(result.size() <= 1, "Multiple unitIds selected by entity & action");
    if (result.empty()) {
        return std::nullopt;
    }
    return result[0][sql::col::UNIT_ID].as<TId>();
}

// Loads the unit with id `unitId`; throws unless exactly one row is found.
Unit loadById(pqxx::transaction_base& txn, TId unitId)
{
    const auto idMatches = sql::col::UNIT_ID + " = " + std::to_string(unitId);
    const auto query = fmt::format(
        sqlClause::SELECT_UP_TO_TWO_UNITS_FMT,
        "columns"_a = sqlClause::UNIT_COLUMNS,
        "where"_a = idMatches);
    const auto result = txn.exec(query);

    REQUIRE(result.size() == 1, "Got " << result.size() << " units by " << unitId << " unitId");
    return makeUnit(result[0]);
}

// Loads one page of graded units for `uid`, newest (largest id) first.
// `params.before()` / `params.after()` act as an exclusive id cursor (at most
// one of them may be set); `filter` and `gradesVisibility` restrict the units.
// The returned feed reports whether more units exist beyond this page.
UnitFeed loadFeed(
    pqxx::transaction_base& txn,
    TUid uid,
    const UnitFeedParams& params,
    const UnitFilter& filter,
    GradesVisibility gradesVisibility)
{
    REQUIRE(!params.before() || !params.after(), "Both before and after are specified");

    // One extra unit is requested to find out whether another page exists.
    // With `before` rows come in ascending id order (closest to the cursor
    // first), otherwise in descending order.
    const auto units = makeGradedUnitVec(txn.exec(fmt::format(
        sqlClause::SELECT_GRADED_UNITS_FMT,
        "fixable"_a = sqlClause::fixable(uid),
        "refutation_acceptable"_a = sqlClause::refutationAcceptable(uid),

        "where_feed_params"_a = sqlClause::where(params),
        "where_filter"_a = sqlClause::where(txn, filter),
        "where_grades_visible"_a = sqlClause::where(uid, gradesVisibility),

        "order"_a = (params.before() ? "ASC" : "DESC"),
        "limit"_a = params.perPage() + 1)));

    const auto size = std::min(units.size(), params.perPage());
    const auto hasMore = units.size() > params.perPage();

    if (params.before()) {
        // Keep the first `size` (ascending) units, i.e. drop the trailing probe
        // row, and emit them newest first. [crend() - size, crend()) walks
        // elements size-1 .. 0; the old crbegin()-based slice dropped the unit
        // nearest to the cursor and returned the probe row instead.
        return UnitFeed(GradedUnitVec(units.crend() - size, units.crend()), hasMore);
    }
    return UnitFeed(GradedUnitVec(units.cbegin(), units.cbegin() + size), hasMore);
}

} // namespace maps::wiki::assessment::units
