#include <yandex/maps/wiki/social/region_feed.h>

#include <yandex/maps/wiki/common/string_utils.h>

#include "feed_helpers.h"
#include "helpers.h"


namespace maps::wiki::social {

namespace {

// Which end of the event_id search interval stays fixed while the binary search
// in recursiveQuery moves the other end.
enum class IntervalEdge {
    Left,
    Right,
};

// Scalar SQL subquery used as the default upper search bound: one past the largest
// event_id in social.commit_event (not filtered by branch). On an empty table
// MAX() is NULL, so the whole expression evaluates to NULL.
const std::string CEILING_EVENT_ID_QUERY{
    "SELECT MAX(event_id)+1 FROM social.commit_event"};

std::string
initialValues(
    std::optional<TId> minEventId,
    std::optional<TId> maxEventId,
    size_t limit,
    IntervalEdge fixedEdge
)
{
    const auto minEventIdStr = minEventId ? std::to_string(*minEventId) : "0";
    const auto maxEventIdStr =
        maxEventId ? std::to_string(*maxEventId) : CEILING_EVENT_ID_QUERY;

    const auto fixedEdgeValueStr =
        (fixedEdge == IntervalEdge::Right) ? maxEventIdStr : minEventIdStr;

    std::ostringstream os;
    os << "0::int, "
          "(" << fixedEdgeValueStr <<     ")::bigint, " // fixed_boundary immutable boundary for event_id.
          "(" << minEventIdStr <<         ")::bigint, " // lft start value (0 by default or min_event_id).
          "(" << maxEventIdStr <<         ")::bigint, " // rgt start value equal to max_event_id.
          "0::bigint, "                                 // commits count at previous iteration.
          "(" << limit         <<         ") "          // necessary number of commits to look for.
          ;
    return os.str();
}

// Returns the (left, right) SQL expressions that bound event_id in commitEventSubquery.
// One side is the immutable `fixed_boundary` column of the recursive CTE. The other is
// the midpoint of the current [lft, rgt] search window.
// A Left fixed edge yields (fixed_boundary, midpoint); any other value yields
// (midpoint, fixed_boundary).
std::pair<std::string, std::string>
eventIdBoundaries(IntervalEdge fixedEdge)
{
    static const std::string FIXED_COLUMN = "fixed_boundary";
    static const std::string MIDPOINT_EXPR = "(lft + rgt) / 2";

    using Boundaries = std::pair<std::string, std::string>;
    return fixedEdge == IntervalEdge::Left
        ? Boundaries{FIXED_COLUMN, MIDPOINT_EXPR}
        : Boundaries{MIDPOINT_EXPR, FIXED_COLUMN};
}

// Builds a subquery that selects commit events matching `whereClause` whose event_id
// lies strictly between `leftBoundary` and `rightBoundary`. Both boundaries are SQL
// expressions, e.g. columns of the recursive CTE.
// The result is capped at `asked + 1` rows, so `asked` must be resolvable in the
// enclosing query.
// NOTE(review): the LIMIT has no ORDER BY. When more than asked + 1 rows match, which
// rows come back is unspecified. Callers presumably rely only on the count, or on the
// binary search having converged to <= asked rows; confirm.
std::string commitEventSubquery(
    const std::string& whereClause,
    const std::string& leftBoundary,
    const std::string& rightBoundary
)
{
    std::ostringstream query;
    query << SELECT_COMMIT_EVENT_FIELDS;
    query << " WHERE " << whereClause;
    query << " AND " << sql::col::EVENT_ID << " > " << leftBoundary;   // exclusive
    query << " AND " << sql::col::EVENT_ID << " < " << rightBoundary;  // exclusive
    query << " LIMIT asked + 1";
    return query.str();
}

// SQL predicate over `commits_count` and `asked`. When it is true, the binary search
// discards the left half of [lft, rgt] by moving lft past the midpoint.
// - Right fixed edge: the counted range is (midpoint, fixed_boundary), which shrinks
//   as the midpoint moves right. The left half is dropped when too many commits
//   were found.
// - Otherwise: the counted range is (fixed_boundary, midpoint), which grows as the
//   midpoint moves right. The left half is dropped when too few commits were found.
std::string
dropLeftSearchIntervalPartCondition(
    IntervalEdge fixedEdge
)
{
    const bool fixedOnRight = (fixedEdge == IntervalEdge::Right);
    return fixedOnRight ? "commits_count > asked" : "commits_count < asked";
}

// SQL predicate that, when true, makes the binary search discard the right half of
// [lft, rgt] by moving rgt down to the midpoint. It is the mirror image of the
// left-half condition: the same predicate computed for the opposite fixed edge.
std::string
dropRightSearchIntervalPartCondition(
    IntervalEdge fixedEdge
)
{
    const IntervalEdge mirrored =
        (fixedEdge == IntervalEdge::Right) ? IntervalEdge::Left : IntervalEdge::Right;
    return dropLeftSearchIntervalPartCondition(mirrored);
}

// Returns the two SELECT-list items that compute the next `lft` and `rgt` of the
// recursive CTE for one binary-search step.
// - lft becomes midpoint + 1 when the drop-left condition holds; otherwise it is kept.
// - rgt becomes midpoint when the drop-right condition holds; otherwise it is kept.
// The two conditions compare commits_count with asked using opposite strict operators.
// When commits_count == asked, neither bound moves.
std::string wantedBoundaryBinarySearch(IntervalEdge fixedEdge)
{
    // The +1 guarantees lft strictly increases, so the search cannot stall when lft
    // and rgt are adjacent (their midpoint would otherwise equal lft).
    const std::string nextLeft =
        " CASE "
        "     WHEN " + dropLeftSearchIntervalPartCondition(fixedEdge) +
        "         THEN ((lft + rgt) / 2) + 1 "
        "     ELSE "
        "         lft "
        " END                             AS lft, ";

    const std::string nextRight =
        " CASE "
        "     WHEN " + dropRightSearchIntervalPartCondition(fixedEdge) +
        "         THEN (lft + rgt) / 2 "
        "     ELSE "
        "         rgt "
        " END                             AS rgt ";

    return nextLeft + nextRight;
}

// Builds the complete SQL query that returns one page of commit events for the feed.
//
// The page boundary is found by a binary search over event_id, implemented as the
// recursive CTE `t(iteration, fixed_boundary, lft, rgt, prev_count, asked)`:
//   * the seed row comes from initialValues(); `fixed_boundary` never changes;
//   * each step counts (capped at asked + 1) matching commits between fixed_boundary
//     and the midpoint (lft + rgt) / 2, then narrows [lft, rgt] through
//     wantedBoundaryBinarySearch();
//   * recursion stops once lft >= rgt, or once the previous step counted exactly `asked`.
// The outer SELECT takes the row with the highest `iteration` and re-runs the commit
// subquery for that window. Results are ordered by event_id: DESC when the right edge
// is fixed, ASC otherwise.
//
// skippedUids    - when non-empty, emitted as a `skipped_uids(uid)` CTE ahead of `t`
//                  (RegionFeed::whereClause() references that CTE).
// whereClauseStr - commit-event filter, embedded verbatim in both subqueries.
// minEventId / maxEventId - optional search bounds, forwarded to initialValues().
// limit          - number of commits to look for; becomes the `asked` column.
// fixedEdge      - which end of the search interval stays fixed.
std::string recursiveQuery(
    const TUids& skippedUids,
    const std::string& whereClauseStr,
    std::optional<TId> minEventId,
    std::optional<TId> maxEventId,
    size_t limit,
    IntervalEdge fixedEdge
)
{
    const std::string order =
        (fixedEdge == IntervalEdge::Right ? " DESC " : " ASC ");

    const auto [left, right] = eventIdBoundaries(fixedEdge);
    // The same subquery text is used for counting inside the CTE and for the final fetch.
    const auto commitEventSubqueryStr = commitEventSubquery(whereClauseStr, left, right);
    const auto initialValuesStr = initialValues(minEventId, maxEventId, limit, fixedEdge);

    std::ostringstream os;
    os << " WITH RECURSIVE ";
    if (!skippedUids.empty()) {
        os << "skipped_uids(uid) AS (VALUES (" << common::join(skippedUids, "),(") << ")),";
    }
    os << " t(iteration, fixed_boundary, lft, rgt, prev_count, asked) AS ( "
          " VALUES ( "
          << initialValuesStr
          << " ) "
          // NOTE(review): every row has a distinct `iteration`, so UNION never removes
          // anything here; UNION ALL would avoid the duplicate-elimination work.
          " UNION "
          // Recursive step: count commits in the current window and narrow [lft, rgt].
          " SELECT "
          "     iteration + 1                   AS iteration, "
          "     fixed_boundary                  AS fixed_boundary, "
          << wantedBoundaryBinarySearch(fixedEdge) << ", "
          "     commits_count                   AS prev_count, "
          "     asked                           AS asked "
          "     FROM "
          "         t, "
          "         LATERAL ( "
          "             SELECT "
          "                 COUNT(1)        AS commits_count "
          "             FROM "
          "                 (" << commitEventSubqueryStr << ") AS commit_event_subquery "
          "         ) AS t2 "
          // Stop when the window is empty or the previous step found exactly `asked`.
          "     WHERE "
          "         lft < rgt "
          "         AND "
          "         prev_count <> asked "
          " ) "
          // Final fetch: re-run the commit subquery for the last iteration's window.
          " SELECT "
          "     commit_event_subquery.* "
          " FROM "
          "     ( "
          "         SELECT    fixed_boundary, lft, rgt, asked "
          "         FROM      t "
          "         ORDER BY  iteration DESC "
          "         LIMIT     1 "
          "     ) AS last_iteration, "
          "     LATERAL "
          "         (" << commitEventSubqueryStr << ")      AS commit_event_subquery "
          " ORDER BY event_id " << order << " ;";

    return os.str();
}

} // namespace

// Returns the newest page of up to `limit` region events.
// No explicit bounds are given, so the search starts at [0, MAX(event_id)+1]. The
// right edge is fixed, so rows arrive ordered by event_id DESC.
std::pair<Events, HasMore>
RegionFeed::eventsHead(size_t limit) const
{
    checkLimit(limit);

    // limit + 1: createEvents(...) needs one additional commit
    // (presumably to decide HasMore).
    const auto query = recursiveQuery(
        skippedUids_,
        whereClause(),
        std::nullopt,  // minEventId
        std::nullopt,  // maxEventId
        limit + 1,
        IntervalEdge::Right);

    return createEvents(work_.exec(query), limit, PushOrder::Back);
}

// Returns up to `limit` region events with event_id strictly greater than `eventId`.
// The left edge is fixed at `eventId` and the search extends rightwards, so rows
// arrive ordered by event_id ASC.
std::pair<Events, HasMore>
RegionFeed::eventsAfter(
    TId eventId,
    size_t limit
) const
{
    checkLimit(limit);

    // limit + 1: createEvents(...) needs one additional commit.
    const auto query = recursiveQuery(
        skippedUids_,
        whereClause(),
        eventId,       // minEventId
        std::nullopt,  // maxEventId
        limit + 1,
        IntervalEdge::Left);

    return createEvents(work_.exec(query), limit, PushOrder::Front);
}

// Returns up to `limit` region events with event_id strictly less than `eventId`.
// Asserts that `eventId` is non-zero. The right edge is fixed at `eventId`, so rows
// arrive ordered by event_id DESC.
std::pair<Events, HasMore>
RegionFeed::eventsBefore(
    TId eventId,
    size_t limit
) const
{
    checkLimit(limit);
    ASSERT(eventId);

    // limit + 1: createEvents(...) needs one additional commit.
    const auto query = recursiveQuery(
        skippedUids_,
        whereClause(),
        std::nullopt,  // minEventId
        eventId,       // maxEventId
        limit + 1,
        IntervalEdge::Right);

    return createEvents(work_.exec(query), limit, PushOrder::Back);
}

// Builds the commit-event filter shared by every feed query.
// It combines:
// - a spatial predicate against the feed's region geometry;
// - the branch, event-type, non-null-bounds and category filters;
// - an author exclusion when skippedUids_ is non-empty. That exclusion reads the
//   `skipped_uids` CTE, which recursiveQuery emits under the same condition.
std::string
RegionFeed::whereClause() const
{
    std::ostringstream where;

    // Spatial predicate: either the event bounds are covered by the region, or they
    // merely intersect it by bounding box (&&).
    if (boundsPredicate_ != EventBoundsPredicate::CoveredByRegion) {
        where << " (" << sql::col::BOUNDS_GEOM << " && ";
    } else {
        where << " ST_CoveredBy(" << sql::col::BOUNDS_GEOM << ",";
    }
    where << "    ST_GeomFromWKB(" << work_.quote_raw(regionGeometryMercatorWkb_)
          << "," << sql::value::MERCATOR_SRID << "))";

    // Only events of this feed's branch.
    where << " AND " << sql::col::BRANCH_ID << " = " << branchId_;

    // Only edit events.
    where << " AND " << sql::col::TYPE << "='" << sql::value::EVENT_TYPE_EDIT << "'";

    // Helps the planner choose the intended index
    // (see https://a.yandex-team.ru/review/1207648).
    where << " AND " << sql::col::BOUNDS_GEOM << " IS NOT NULL ";

    // Category include/exclude filter, if one was provided.
    where << " AND " << categoriesWhereCondition(work_, categoryIdsFilter_);

    if (!skippedUids_.empty()) {
        where << " AND " << sql::col::CREATED_BY << " NOT IN (SELECT uid FROM skipped_uids)";
    }

    return where.str();
}

} // namespace maps::wiki::social
