#include <yandex/maps/wiki/validator/storage/exclusions_gateway.h>
#include <yandex/maps/wiki/validator/storage/exception.h>

#include "magic_strings.h"
#include "helpers.h"

#include <maps/libs/common/include/exception.h>

namespace maps::wiki::validator::storage {

namespace {

// Returns true if the table named by EXCLUSION_TABLE is present in the
// pg_tables catalog view.
// EXCLUSION_TABLE must be schema-qualified ("schema.table"); the part before
// the first dot is matched against schemaname, the rest against tablename.
bool exclusionsTableExists(Transaction& txn)
{
    const auto dotPos = EXCLUSION_TABLE.find('.');
    ASSERT(dotPos != std::string::npos);

    const std::string schemaName = EXCLUSION_TABLE.substr(0, dotPos);
    const std::string tableName = EXCLUSION_TABLE.substr(dotPos + 1);

    const auto rows = txn.exec(
        "SELECT 1 FROM pg_tables"
        " WHERE schemaname = " + txn.quote(schemaName) +
        " AND tablename = " + txn.quote(tableName));
    return rows.size() != 0;
}

// Builds the SQL condition (without the WHERE keyword) for an exclusions
// filter: the attributes condition, followed by an " AND "-joined condition
// for each optional restriction that is set (bbox, author, created-before).
// Note: filter.viewedBy is not handled here; see havingClause().
std::string whereClause(
    const ExclusionsFilter& filter,
    const Transaction& txn)
{
    std::string sql = whereClause(filter.attributes, txn);

    if (filter.bbox) {
        sql += " AND ";
        sql += storage::whereClause(*filter.bbox);
    }

    if (filter.createdBy) {
        // Streamed so the author id is formatted exactly as operator<< does.
        std::ostringstream createdByCondition;
        createdByCondition
            << " AND " << CREATED_BY_COLUMN_NAME << " = " << *filter.createdBy;
        sql += createdByCondition.str();
    }

    if (filter.createdBefore) {
        sql += " AND ";
        sql += CREATED_COLUMN_NAME;
        sql += " < ";
        sql += txn.quote(chrono::formatSqlDateTime(*filter.createdBefore));
    }

    return sql;
}

// Builds the HAVING clause (with a leading space) implementing the
// filter.viewedBy restriction over the aggregated viewer ids, or an empty
// string when no such restriction is set.
//
// For the non-Viewed state the emitted condition is
//   NOT <uid> = ANY (array_agg(col)) OR <uid> = ANY (array_agg(col)) IS NULL
// which, under PostgreSQL operator precedence, reads as
//   (NOT (uid = ANY ...)) OR ((uid = ANY ...) IS NULL).
// The IS NULL arm presumably covers groups whose only aggregated viewer is
// NULL (no matching view row) — TODO confirm against the caller's LEFT JOIN.
std::string havingClause(const ExclusionsFilter& filter)
{
    if (!filter.viewedBy) {
        return {};
    }

    std::ostringstream uidAmongViewers;
    uidAmongViewers
        << filter.viewedBy->uid
        << " = ANY (array_agg(" << VIEWED_BY_COLUMN_NAME << "))";
    const std::string condition = uidAmongViewers.str();

    if (filter.viewedBy->state == ViewedState::Viewed) {
        return " HAVING " + condition;
    }
    return " HAVING NOT " + condition + " OR " + condition + " IS NULL";
}

} // namespace

ExclusionsGateway::ExclusionsGateway(Transaction& txn)
    : txn_(txn)
{ }

// Creates an exclusion for `messageId`, authored by `createdBy` and stamped
// with the database's NOW(). Returns the stored message read back in the same
// statement.
// Throws ExclusionExistsError on a unique violation, and
// NonexistentMessageError on any other integrity-constraint violation.
StoredMessageDatum ExclusionsGateway::addExclusion(
    const MessageId& messageId,
    TUId createdBy,
    const revision::Snapshot& snapshot) const
{
    const std::string insertedColumns =
        MESSAGE_ID_COLUMNS + "," + EXCLUSION_INFO_COLUMNS;
    const std::string insertedValues =
        std::to_string(messageId.attributesId()) + ", "
        + std::to_string(messageId.contentId()) + ", NOW(), "
        + std::to_string(createdBy);

    // INSERT ... RETURNING inside a CTE, so the new row can be joined with the
    // message substance tables without a second round trip.
    const std::string query =
        "WITH inserted AS ( INSERT INTO " + EXCLUSION_TABLE
        + "(" + insertedColumns + ")VALUES (" + insertedValues + ")"
        + "RETURNING " + insertedColumns + ")"
        + "SELECT " + STORED_MESSAGE_COLUMNS
        + " FROM inserted" + JOIN_MESSAGE_SUBSTANCE_TABLES;

    pqxx::result rows;
    try {
        rows = txn_.exec(query);
    } catch (const pqxx::unique_violation&) {
        // Must precede the base-class handler below.
        throw ExclusionExistsError();
    } catch (const pqxx::integrity_constraint_violation&) {
        throw NonexistentMessageError();
    }

    const StoredMessageData inserted = storedMessagesFromDbRows(snapshot, rows);
    ASSERT(inserted.size() == 1);
    return inserted.front();
}

// Deletes the exclusion for `messageId` and returns the affected message.
// The returned row carries NULL creation time/author, since the exclusion no
// longer exists.
// Throws NonexistentExclusionError if no exclusion was deleted.
StoredMessageDatum ExclusionsGateway::removeExclusion(
    const MessageId& messageId,
    const revision::Snapshot& snapshot) const
{
    const std::string idTuple =
        "(" + std::to_string(messageId.attributesId()) + ", "
        + std::to_string(messageId.contentId()) + ")";

    const std::string deleteStatement =
        "DELETE FROM " + EXCLUSION_TABLE
        + " WHERE (" + MESSAGE_ID_COLUMNS + ") = " + idTuple
        + " RETURNING " + MESSAGE_ID_COLUMNS
        + ", NULL AS " + CREATED_COLUMN_NAME
        + ", NULL AS " + CREATED_BY_COLUMN_NAME;

    const std::string query =
        "WITH deleted AS (" + deleteStatement + ")"
        + "SELECT " + STORED_MESSAGE_COLUMNS
        + " FROM deleted" + JOIN_MESSAGE_SUBSTANCE_TABLES;

    const auto rows = txn_.exec(query);
    const StoredMessageData removed = storedMessagesFromDbRows(snapshot, rows);
    ASSERT(removed.size() <= 1);
    if (removed.empty()) {
        throw NonexistentExclusionError();
    }
    return removed.front();
}

// Records that user `viewedBy` has viewed the exclusion for `messageId` and
// returns that exclusion. The viewed-by array of the result contains only
// `viewedBy`.
// A view row that conflicts with an existing one is silently kept
// (ON CONFLICT DO NOTHING).
// Throws NonexistentExclusionError if there is no exclusion for `messageId`.
// This matches removeExclusion(); the check runs before anything is written.
StoredMessageDatum ExclusionsGateway::viewExclusion(
    const MessageId& messageId,
    TUId viewedBy,
    const revision::Snapshot& snapshot) const
{
    const auto attributesIdStr = std::to_string(messageId.attributesId());
    const auto contentIdStr = std::to_string(messageId.contentId());
    const auto viewedByStr = std::to_string(viewedBy);

    // The result does not depend on the view row (the viewed-by array is a
    // literal), so the exclusion is fetched first. A missing exclusion is then
    // reported cleanly instead of tripping an ASSERT or recording an orphan view.
    const std::string resultQuery =
        "SELECT " + STORED_MESSAGE_COLUMNS
        + ", ARRAY[" + viewedByStr + "] AS " + VIEWED_BY_ARR_COLUMN_NAME
        + " FROM " + EXCLUSION_TABLE + JOIN_MESSAGE_SUBSTANCE_TABLES
        + " WHERE " + ATTRIBUTES_ID_COLUMN_NAME + " = " + attributesIdStr
        + " AND " + CONTENT_ID_COLUMN_NAME + " = " + contentIdStr;
    const auto rows = txn_.exec(resultQuery);

    StoredMessageData data = storedMessagesFromDbRows(snapshot, rows);
    ASSERT(data.size() <= 1);
    if (data.empty()) {
        throw NonexistentExclusionError();
    }

    const std::string viewQuery =
        "INSERT INTO " + EXCLUSION_VIEW_TABLE
        + " (" + MESSAGE_ID_COLUMNS + "," + VIEWED_BY_COLUMN_NAME + ") "
        + "VALUES (" + attributesIdStr
                + "," + contentIdStr
                + "," + viewedByStr
        + ") ON CONFLICT DO NOTHING";
    txn_.exec(viewQuery);

    return data.front();
}

// Counts the exclusions matching `filter`, grouped by message attributes.
// Only the conditions produced by whereClause() are applied. filter.viewedBy
// has no effect here: there is no HAVING clause and no join to the view table.
MessageStatistics ExclusionsGateway::statistics(
    const ExclusionsFilter& filter) const
{
    // The substance tables are joined only when a bbox restriction is present
    // (presumably because the geometry column lives there). Otherwise the
    // attributes table is enough.
    const std::string& join = filter.bbox
        ? JOIN_MESSAGE_SUBSTANCE_TABLES
        : JOIN_MESSAGE_ATTRIBUTES_TABLE;

    std::string query;
    query += "SELECT ";
    query += MESSAGE_ATTRIBUTES_COLUMNS;
    query += ", count(*) AS ";
    query += COUNT_COLUMN_NAME;
    query += " FROM ";
    query += EXCLUSION_TABLE;
    query += join;
    query += " WHERE ";
    query += whereClause(filter, txn_);
    query += " GROUP BY ";
    query += MESSAGE_ATTRIBUTES_COLUMNS;

    MessageStatistics countsByAttributes;
    const auto rows = txn_.exec(query);
    for (const auto& row : rows) {
        const auto attributes = messageAttributesFromDbRow(row);
        countsByAttributes[attributes] = row[COUNT_COLUMN_NAME].as<size_t>();
    }
    return countsByAttributes;
}

// Returns one page of exclusions matching `filter`, paginated relative to an
// anchor exclusion. Each exclusion carries the array of its viewer ids.
//
//   filter       attribute / bbox / author / created-before / viewed-state
//                restrictions.
//   snapshot     revision snapshot passed through to row decoding.
//   startID      anchor exclusion; when empty the page starts at the head of
//                MESSAGES_ORDER and `beforeAfter` is ignored.
//   beforeAfter  with a non-empty `startID`: Before selects exclusions whose
//                creation time compares GREATER than the anchor's; After
//                selects those comparing LESS.
//   limit        maximum page size; 0 means "no limit".
//
// Returns the page together with a flag telling whether more rows exist
// beyond it.
// Throws NonexistentExclusionError if `startID` is non-empty but does not name
// an existing exclusion. Previously this case called front() on an empty
// result, which is undefined behaviour.
ExclusionsGateway::ResultBeforeAfter
ExclusionsGateway::exclusions(
    const ExclusionsFilter& filter,
    const revision::Snapshot& snapshot,
    const MessageId& startID,
    common::BeforeAfter beforeAfter,
    size_t limit) const
{
    std::string query =
        "SELECT " + STORED_MESSAGE_COLUMNS
        + ", array_agg(" + VIEWED_BY_COLUMN_NAME + ") AS " + VIEWED_BY_ARR_COLUMN_NAME
        + " FROM " + EXCLUSION_TABLE + JOIN_MESSAGE_SUBSTANCE_TABLES
        + " LEFT JOIN " + EXCLUSION_VIEW_TABLE + " USING (" + MESSAGE_ID_COLUMNS + ")"
        + " WHERE " + whereClause(filter, txn_);

    if (!startID.empty()) {
        // Resolve the anchor's creation time, which bounds the page.
        const std::string startQuery =
            "SELECT " + STORED_MESSAGE_COLUMNS +
            " FROM " + EXCLUSION_TABLE + JOIN_MESSAGE_SUBSTANCE_TABLES +
            " WHERE " + ATTRIBUTES_ID_COLUMN_NAME + " = " + std::to_string(startID.attributesId()) +
            " AND " + CONTENT_ID_COLUMN_NAME + " = " + std::to_string(startID.contentId());
        const auto startMessages = storedMessagesFromDbRows(snapshot, txn_.exec(startQuery));
        if (startMessages.empty()) {
            throw NonexistentExclusionError();
        }
        query +=
            " AND " + CREATED_COLUMN_NAME + " " +
            (beforeAfter == common::BeforeAfter::Before ? GREATER : LESS) +
            txn_.quote(startMessages.front().exclusionInfo()->createdAt);
    }

    query += " GROUP BY "
          + MESSAGE_ID_COLUMNS + ","
          + MESSAGE_ATTRIBUTES_COLUMNS + ","
          + GEOMETRY_COLUMN_NAME + ","
          + REVISION_IDS_COLUMN_NAME + ","
          + EXCLUSION_INFO_COLUMNS;
    query += havingClause(filter);

    const bool reverse = !startID.empty() && beforeAfter == common::BeforeAfter::Before;
    query += reverse ? MESSAGES_ORDER_REVERT : MESSAGES_ORDER;

    // One extra "probe" row is requested to detect whether another page exists.
    if (limit) {
        query += " LIMIT " + std::to_string(limit + 1);
    }
    const auto rows = txn_.exec(query);
    if (rows.empty()) {
        return {};
    }
    const bool hasMore = limit && rows.size() > limit;

    if (reverse) {
        // Rows arrive in MESSAGES_ORDER_REVERT and are walked backwards. The
        // probe row, if one was returned, is the first row seen from the end,
        // so it is skipped. The skip must depend on `hasMore`: with limit == 0
        // no probe row exists, and skipping there would drop a real result.
        return ResultBeforeAfter {
            storedMessagesFromDbRows(
                snapshot,
                hasMore ? rows.rbegin() + 1 : rows.rbegin(),
                rows.rend(),
                limit),
            hasMore
        };
    }

    return ResultBeforeAfter {
        storedMessagesFromDbRows(
            snapshot,
            rows.begin(),
            rows.end(),
            limit),
        hasMore
    };
}

// Returns exclusions matching `filter`, newest first by creation time, using
// plain OFFSET/LIMIT paging.
// Note: filter.viewedBy is not applied here (no HAVING clause).
StoredMessageData ExclusionsGateway::exclusions(
    const ExclusionsFilter& filter,
    const revision::Snapshot& snapshot,
    size_t offset, size_t limit) const
{
    const std::string selectPart =
        "SELECT " + STORED_MESSAGE_COLUMNS
        + " FROM " + EXCLUSION_TABLE + JOIN_MESSAGE_SUBSTANCE_TABLES;
    const std::string filterPart = " WHERE " + whereClause(filter, txn_);
    const std::string pagingPart =
        " ORDER BY " + CREATED_COLUMN_NAME + " DESC"
        + " OFFSET " + std::to_string(offset)
        + " LIMIT " + std::to_string(limit);

    const auto rows = txn_.exec(selectPart + filterPart + pagingPart);
    return storedMessagesFromDbRows(snapshot, rows);
}

// Returns the excluded messages produced by validator check `check`.
// When `aoiBbox` is set, only messages without geometry or matching the bbox
// condition are returned.
// Returns an empty list if the exclusions table does not exist.
std::vector<Message> ExclusionsGateway::exclusionsForCheck(
    const TCheckId& check,
    const boost::optional<geolib3::BoundingBox>& aoiBbox) const
{
    if (!exclusionsTableExists(txn_)) {
        return {};
    }

    // The original started this expression with a stray unary '+' before
    // "SELECT ", which compiled only because it decayed the literal to a pointer.
    std::string query =
        "SELECT " + MESSAGE_SUBSTANCE_COLUMNS
        + " FROM " + EXCLUSION_TABLE + JOIN_MESSAGE_SUBSTANCE_TABLES
        + " WHERE " + CHECK_ID_COLUMN_NAME + " = " + txn_.quote(check);

    if (aoiBbox) {
        // Geometry-less messages are kept regardless of the area of interest.
        query += " AND ("
            + GEOMETRY_COLUMN_NAME + " IS NULL OR "
            + whereClause(*aoiBbox)
            + ")";
    }

    const auto rows = txn_.exec(query);

    std::vector<Message> result;
    result.reserve(static_cast<size_t>(rows.size()));
    for (const auto& row : rows) {
        result.push_back(messageFromDbRow(row));
    }
    return result;
}

} // namespace maps::wiki::validator::storage
