#include "exclusions_set.h"
#include "helpers.h"
#include "magic_strings.h"
#include <maps/wikimap/mapspro/libs/validator/common/magic_strings.h>

#include <yandex/maps/wiki/validator/storage/exclusions_gateway.h>
#include <yandex/maps/wiki/validator/storage/message_attributes_filter.h>
#include <yandex/maps/wiki/validator/storage/messages_writer.h>
#include <yandex/maps/wiki/validator/storage/results_gateway.h>
#include <yandex/maps/wiki/validator/result.h>
#include <yandex/maps/wiki/common/retry_duration.h>

#include <maps/libs/pgpool/include/pgpool3.h>
#include <maps/libs/log8/include/log8.h>

#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace maps::wiki::validator::storage {

namespace {

/// How many messages are accumulated in memory before they are handed to
/// MessagesWriter as a single batch.
constexpr size_t MESSAGES_STORE_BATCH_SIZE = 1000;

/// Collects the exclusions of every check listed in `result->checkIds()`,
/// restricted to the bounding box of `result->aoi()`.
///
/// All lookups share one transaction obtained from
/// `pgPool.masterReadOnlyTransaction()`. A null `result` yields an empty set
/// without touching the database.
ExclusionsSet loadExclusions(ResultPtr result, pgpool3::Pool& pgPool)
{
    ExclusionsSet exclusions;
    if (result) {
        auto txn = pgPool.masterReadOnlyTransaction();
        ExclusionsGateway gateway(*txn);

        for (const auto& checkId : result->checkIds()) {
            auto checkExclusions = gateway.exclusionsForCheck(
                checkId, result->aoi().boundingBox());
            exclusions.add(checkExclusions);
            INFO() << "loaded " << checkExclusions.size()
                << " exclusions for check " << checkId;
        }
    }
    return exclusions;
}

} // namespace

/// Stores the messages accumulated in `result` as results of task `taskId`.
///
/// Exclusions for all of the result's checks are loaded first; that call is
/// wrapped in common::retryDuration. Messages matched by an exclusion are
/// skipped. The rest are drained from `result` via popMessages() and written
/// through MessagesWriter in batches of MESSAGES_STORE_BATCH_SIZE.
///
/// @return true if `result` is null or every batch write succeeded. Returns
///         false as soon as a batch write fails; remaining messages are then
///         left unwritten.
bool storeResult(
    ResultPtr result,
    pgpool3::Pool& pgPool,
    TTaskId taskId)
{
    if (!result) {
        return true;
    }

    const auto exclusionsSet = common::retryDuration(
        [&] { return loadExclusions(result, pgPool); });

    std::vector<Message> pending;
    pending.reserve(MESSAGES_STORE_BATCH_SIZE);

    MessagesWriter messageWriter(pgPool, taskId);

    // Drain the result buffer by buffer until popMessages() comes back empty.
    for (Result::MessageBuffer buffer = result->popMessages();
         !buffer.empty();
         buffer = result->popMessages())
    {
        for (Message& message : buffer) {
            if (exclusionsSet.contains(message)) {
                continue;
            }
            pending.push_back(std::move(message));

            const bool batchFull = pending.size() >= MESSAGES_STORE_BATCH_SIZE;
            if (!batchFull) {
                continue;
            }
            if (!messageWriter.writeMessagesBatchWithRetries(pending)) {
                return false;
            }
            pending.clear();
            pending.reserve(MESSAGES_STORE_BATCH_SIZE);
        }
    }

    // Flush the final batch, which may be partial or even empty.
    return messageWriter.writeMessagesBatchWithRetries(pending);
}

/// Binds the gateway to the transaction `txn` used for every query and to
/// the task `taskId` whose results it reads.
/// NOTE(review): `txn_` is presumably a reference member (see the header),
/// in which case `txn` must outlive the gateway — confirm.
ResultsGateway::ResultsGateway(Transaction& txn, TTaskId taskId)
    : txn_(txn), taskId_(taskId)
{}

size_t ResultsGateway::messageCount() const
{
    return messageCount(MessageAttributesFilter{});
}

/// Number of this task's messages whose attributes satisfy `filter`.
///
/// Sums the per-attribute counters stored in TASK_MESSAGE_STATS_TABLE.
/// When no row matches, sum() yields NULL, which `as<size_t>(0)` maps to 0.
size_t ResultsGateway::messageCount(
    const MessageAttributesFilter& filter) const
{
    std::string query = "SELECT sum(" + COUNT_COLUMN_NAME + ") ";
    query.append(" FROM ")
        .append(TASK_MESSAGE_STATS_TABLE)
        .append(JOIN_MESSAGE_ATTRIBUTES_TABLE)
        .append(" WHERE ")
        .append(TASK_ID_COLUMN_NAME)
        .append(" = ")
        .append(std::to_string(taskId_))
        .append(" AND ")
        .append(whereClause(filter, txn_));

    const auto rows = txn_.exec(query);
    return rows.at(0).at(0).as<size_t>(0);
}

/// Message counts of this task, grouped by message attributes and restricted
/// to attributes satisfying `filter`.
///
/// Each row of TASK_MESSAGE_STATS_TABLE (joined with the attributes table)
/// becomes one entry. The key is built by messageAttributesFromDbRow and the
/// value is the row's COUNT_COLUMN_NAME. A later row with the same key
/// overwrites an earlier one.
MessageStatistics ResultsGateway::statistics(
    const MessageAttributesFilter& filter) const
{
    std::string query = "SELECT " + MESSAGE_ATTRIBUTES_COLUMNS + ",";
    query.append(COUNT_COLUMN_NAME)
        .append(" FROM ")
        .append(TASK_MESSAGE_STATS_TABLE)
        .append(JOIN_MESSAGE_ATTRIBUTES_TABLE)
        .append(" WHERE ")
        .append(TASK_ID_COLUMN_NAME)
        .append(" = ")
        .append(std::to_string(taskId_))
        .append(" AND ")
        .append(whereClause(filter, txn_));

    MessageStatistics stats;
    const auto rows = txn_.exec(query);
    for (const auto& row : rows) {
        const auto attributes = messageAttributesFromDbRow(row);
        stats[attributes] = row[COUNT_COLUMN_NAME].as<size_t>();
    }
    return stats;
}

/// Returns one page of this task's stored messages whose attributes satisfy
/// `filter`.
///
/// The page is taken in MESSAGE_ID_COLUMNS order, skipping `offset` messages
/// and returning at most `limit`. Each selected row is:
/// - joined with the message substance tables;
/// - left-joined with EXCLUSION_TABLE on the message id columns;
/// - given an extra `is_viewed` column, equal to `uid` when MESSAGE_VIEW_TABLE
///   holds a view of that message in this task by `uid`, and NULL otherwise.
///
/// Rows are converted with storedMessagesFromDbRows against `snapshot`.
StoredMessageData ResultsGateway::messages(
    const MessageAttributesFilter& filter,
    const revision::Snapshot& snapshot,
    size_t offset, size_t limit,
    revision::UserID uid) const
{
    // This nested select is deliberately less elegant than a plain join:
    // PostgreSQL cannot produce an optimal plan for the plain join.
    // Planner improvements that would make the plain join viable:
    // 1. Mark the output of a nestloop join on attribute a as ordered by
    // (a, b) if the outer loop is ordered by a and the inner loop is ordered by (a, b).
    // 2. Mark the output of an index scan on index(a, b) as ordered by
    // (a, b) under the condition (a = x and b = any(..)).
    // 3. Loose index scan in merge join.

    // Page of message ids of this task. The attribute ids matching `filter`
    // are first collected into an integer[] with array_agg and then matched
    // with `= ANY(...)`; this is the nested form referred to above.
    const std::string taskMessageQuery =
        "SELECT " + MESSAGE_ID_COLUMNS
        + " FROM " + TASK_MESSAGE_TABLE

        + " WHERE " + TASK_ID_COLUMN_NAME + " = " + std::to_string(taskId_)
        + " AND " + ATTRIBUTES_ID_COLUMN_NAME + " = ANY("
            + "(SELECT array_agg(" + ATTRIBUTES_ID_COLUMN_NAME + ")"
            + " FROM " + MESSAGE_ATTRIBUTES_TABLE
            + " WHERE " + whereClause(filter, txn_) + ")::integer[])"

        + " ORDER BY " + MESSAGE_ID_COLUMNS
        + " OFFSET " + std::to_string(offset)
        + " LIMIT " + std::to_string(limit);

    // Correlated subquery: yields `uid` if this user has a view record for
    // the message (tm row) in this task, NULL otherwise.
    const std::string isViewedQuery =
        "(SELECT viewed_by FROM " + MESSAGE_VIEW_TABLE +
        " vt WHERE vt.attributes_id = tm.attributes_id AND"
        " vt.content_id = tm.content_id AND vt.task_id = " + std::to_string(taskId_) +
        " AND viewed_by=" + std::to_string(uid) +
        " LIMIT 1) as is_viewed";
    // The outer ORDER BY is required: SQL does not guarantee that the
    // subquery's ordering survives the subsequent joins.
    const std::string query =
        "SELECT " + STORED_MESSAGE_COLUMNS
        + ", " + isViewedQuery +" FROM (" + taskMessageQuery + ") tm" + JOIN_MESSAGE_SUBSTANCE_TABLES
        + " LEFT JOIN " + EXCLUSION_TABLE + " USING (" + MESSAGE_ID_COLUMNS + ")"
        " ORDER BY " + MESSAGE_ID_COLUMNS;
    return storedMessagesFromDbRows(snapshot, txn_.exec(query));
}

/// Marks the message `messageId` of this task as viewed by `uid` and returns
/// its stored data.
///
/// The view record is inserted into MESSAGE_VIEW_TABLE with
/// `ON CONFLICT DO NOTHING`. Presumably a unique constraint exists, so
/// marking an already-viewed message again does nothing — confirm against
/// the schema. The returned datum's `is_viewed` column is always `uid`, the
/// same value messages() reports for a message this user has viewed.
///
/// @throws std::runtime_error if this task has no message with `messageId`.
///         In that case no view record is inserted.
StoredMessageDatum ResultsGateway::messageSetViewed(
        const revision::Snapshot& snapshot,
        revision::UserID uid,
        const MessageId& messageId) const
{
    // Look the message up before recording the view. `is_viewed` is a
    // literal here, so the row does not depend on the INSERT below.
    // Checking first avoids recording a view for a nonexistent message and
    // calling front() on an empty result, which would be undefined behaviour.
    const std::string query =
        "SELECT " + STORED_MESSAGE_COLUMNS + "," + std::to_string(uid) + " as is_viewed"
        " FROM " + TASK_MESSAGE_TABLE + " tm " + JOIN_MESSAGE_SUBSTANCE_TABLES +
        " LEFT JOIN " + EXCLUSION_TABLE + " USING (" + MESSAGE_ID_COLUMNS + ")"
        " WHERE task_id = " + std::to_string(taskId_) +
        " AND attributes_id = " + std::to_string(messageId.attributesId()) +
        " AND content_id = " + std::to_string(messageId.contentId()) +
        " LIMIT 1";
    auto storedMessages = storedMessagesFromDbRows(snapshot, txn_.exec(query));
    if (storedMessages.empty()) {
        throw std::runtime_error(
            "message (attributes_id=" + std::to_string(messageId.attributesId())
            + ", content_id=" + std::to_string(messageId.contentId())
            + ") not found in task " + std::to_string(taskId_));
    }

    txn_.exec(
        "INSERT INTO " + MESSAGE_VIEW_TABLE +
        " (task_id, attributes_id, content_id, viewed_by) VALUES (" +
        std::to_string(taskId_) + "," +
        std::to_string(messageId.attributesId()) +
        "," + std::to_string(messageId.contentId())
        + "," + std::to_string(uid)
        + ") ON CONFLICT DO NOTHING;");

    return std::move(storedMessages.front());
}

/// Returns the ids of checks that produced at least one message with
/// severity Severity::Fatal in this task, excluding BASE_CHECK_ID.
///
/// Each id appears once, because the query groups by the check id column.
/// Severity is compared as its integer value (`static_cast<int>`).
std::vector<TCheckId> ResultsGateway::checkIdsWithFatalErrors() const
{
    std::ostringstream selectChecksQuery;
    selectChecksQuery
        << "SELECT " << CHECK_ID_COLUMN_NAME
        << " FROM " << TASK_MESSAGE_TABLE
        << " JOIN " << MESSAGE_ATTRIBUTES_TABLE
        << " USING (" << ATTRIBUTES_ID_COLUMN_NAME << ")"
        << " WHERE " << TASK_ID_COLUMN_NAME << " = " << taskId_
        << " AND " << SEVERITY_COLUMN_NAME << " = "
            << static_cast<int>(Severity::Fatal)
        << " GROUP BY " << CHECK_ID_COLUMN_NAME;

    std::vector<TCheckId> checkIds;
    for (const auto& row : txn_.exec(selectChecksQuery.str())) {
        // Convert the column once, compare it, and move it into the result.
        auto checkId = row[0].as<std::string>();
        if (checkId != BASE_CHECK_ID) {
            checkIds.emplace_back(std::move(checkId));
        }
    }
    return checkIds;
}

} // namespace maps::wiki::validator::storage
