#include <maps/wikimap/mapspro/tools/ymapsdf-conversion/rdel-merger/merger.h>

#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>

#include <yandex/maps/wiki/common/pg_utils.h>
#include <yandex/maps/wiki/common/string_utils.h>

#include <boost/lexical_cast.hpp>


namespace maps {
namespace wiki {
namespace rdel_merger {

namespace {
// Shared empty id set: roadIds()/condIds() return a const reference to it
// when an element has no entry in their lookup map.
const Ids EMPTY_IDS{};
} // namespace

// Builds a merger over the database described by connStr / schema and
// preloads the in-memory lookups used by every merge pass:
//   - element -> roads (rd_rd_el),
//   - element -> conditions (cond_rd_seq joined with cond),
//   - the set of protected elements (f_zlev != t_zlev).
//
// maxVertexes / maxLength bound the merged element: doWork() skips a
// candidate whose combined vertex count or ST_Length (after transform to
// SRID 3395) would exceed them.
Merger::Merger(
        const std::string& connStr,
        const std::string& schema,
        size_t maxVertexes,
        double maxLength)
    : db_(connStr, schema)
    , maxVertexes_(maxVertexes)
    , maxLength_(maxLength)
{
    loadRoads();
    loadConds();
    loadElements();
}

// Returns the ids of the roads the element belongs to, or an empty set
// when the element is not linked to any road.
const Ids& Merger::roadIds(Id rdElId) const
{
    if (const auto found = rdElToRoads_.find(rdElId); found != rdElToRoads_.end()) {
        return found->second;
    }
    return EMPTY_IDS;
}

// Returns the ids of the conditions that reference the element, or an
// empty set when no condition refers to it.
const Ids& Merger::condIds(Id rdElId) const
{
    if (const auto found = rdElToConds_.find(rdElId); found != rdElToConds_.end()) {
        return found->second;
    }
    return EMPTY_IDS;
}

// Entry point.
//
// First verifies that no element's geometry disagrees with its end
// junctions; if one does, fails via REQUIRE, reporting that element.
// Then repeatedly runs merge passes (doWork) until a pass merges nothing.
// After each pass it logs that pass's count and the running total.
void Merger::run()
{
    INFO() << "Checking fatal errors...";
    const auto errorRdElId = findErrorRdEl();
    REQUIRE(!errorRdElId, "FATAL error detected, rd_el: " << errorRdElId);

    size_t allProcessed = 0;
    size_t processed = 0;
    // A merge can create new candidates (the merged element meets a neighbour
    // at a degree-2 junction), so keep passing until a pass does nothing.
    do {
        processed = doWork();
        allProcessed += processed;
        INFO() << "Processed: " << processed << " / " << allProcessed;
    } while (processed > 0);
}

void Merger::loadRoads()
{
    auto query = Query()
        .select("rd_id, rd_el_id")
        .from("rd_rd_el");

    for (const auto& row : db_.read(query, "rd_rd_el")) {
        auto rdId = row[0].as<Id>();
        auto rdElId = row[1].as<Id>();
        rdElToRoads_[rdElId].emplace(rdId);
    }
}

void Merger::loadConds()
{
    auto query = Query()
        .select("cond_id, rd_el_id")
        .from("cond_rd_seq JOIN cond USING(cond_seq_id)");

    for (const auto& row : db_.read(query, "conds")) {
        auto condId = row[0].as<Id>();
        auto rdElId = row[1].as<Id>();
        rdElToConds_[rdElId].emplace(condId);
    }
}

void Merger::loadElements()
{
    auto query = Query()
        .select("rd_el_id")
        .from("rd_el")
        .where("f_zlev != t_zlev");

    for (const auto& row : db_.read(query, "protected elements")) {
        auto rdElId = row[0].as<Id>();
        protectedRdElIds_.emplace(rdElId);
    }
}

// Looks for an element whose geometry is inconsistent with its junctions:
// the shape does not start at the "from" junction or does not end at the
// "to" junction.
//
// rdElId == 0 searches the whole table; any other value restricts the check
// to that element. Returns the id of one offending element, or 0 if none.
Id Merger::findErrorRdEl(Id rdElId)
{
    const std::string mismatch =
        "NOT ST_Equals(ST_StartPoint(l.shape),j1.shape) OR NOT ST_Equals(ST_EndPoint(l.shape),j2.shape)";
    const std::string filter = rdElId
        ? "(" + mismatch + ") AND rd_el_id=" + std::to_string(rdElId)
        : mismatch;

    auto offenders = db_.exec(Query()
        .select("rd_el_id")
        .from("rd_el l JOIN rd_jc j1 ON f_rd_jc_id=j1.rd_jc_id JOIN rd_jc j2 ON t_rd_jc_id=j2.rd_jc_id")
        .where(filter)
        .limit(1));

    if (offenders.empty()) {
        return 0;
    }
    return offenders[0][0].as<Id>();
}

// Repairs an element whose geometry runs opposite to its junctions by
// swapping f_rd_jc_id and t_rd_jc_id. It then re-checks that element and
// fails via REQUIRE if it is still inconsistent.
void Merger::fixRdElDirection(Id rdElId)
{
    INFO() << "FIX " << rdElId;

    // SQL UPDATE evaluates every SET right-hand side against the pre-update
    // row, so this single statement swaps the two columns.
    db_.exec(Query()
        .update("rd_el")
        .set("f_rd_jc_id=t_rd_jc_id, t_rd_jc_id=f_rd_jc_id")
        .where(Query() << "rd_el_id=" << rdElId));

    const auto stillBrokenRdElId = findErrorRdEl(rdElId);
    REQUIRE(!stillBrokenRdElId, "FATAL error detected after fix, rd_el: " << stillBrokenRdElId);
}

// Performs one merge pass over the current candidate junctions.
//
// For each candidate middle junction (see loadRdJcToRdEls) the two adjacent
// elements are merged:
//   - one element survives and receives the outer junctions and the
//     ST_Linemerge of both shapes;
//   - the other element, its rd_rd_el links and the middle junction are
//     deleted.
// A candidate is skipped when the merged element would exceed maxVertexes_
// or maxLength_, when its outer junctions cannot be determined, or when
// merging would produce a loop (both outer junctions equal).
//
// Returns the number of merges performed in this pass (0 => nothing merged).
size_t Merger::doWork()
{
    // Candidates are computed once per pass; merges below make some entries
    // stale, which is what skippedRdJcs guards against.
    auto rdJcToRdEls = loadRdJcToRdEls();

    size_t processed = 0;
    // Junctions not to be handled again in this pass: already visited middle
    // junctions plus the outer junctions of merged elements. Any candidate at
    // those junctions may reference an element that was just rewritten or
    // deleted.
    Ids skippedRdJcs;
    for (const auto& [midJcId, rdElIds] : rdJcToRdEls) {
        if (skippedRdJcs.count(midJcId)) {
            continue;
        }

        // Fetch both elements with their junctions, direction, length (after
        // transform to SRID 3395) and vertex count.
        auto rdElIdsClause = "rd_el_id IN (" + common::join(rdElIds, ',') + ")";
        auto query = Query()
            .select(
                "rd_el_id, f_rd_jc_id, t_rd_jc_id, oneway, "
                "ST_Length(ST_Transform(shape,3395)) AS len, "
                "ST_NumPoints(shape) AS num")
            .from("rd_el")
            .where(rdElIdsClause)
            .orderBy("rd_el_id");

        // Accumulated size of the would-be merged element.
        double len = 0;
        size_t num = 0;
        // Outer (non-middle) junctions of the would-be merged element; 0 = unknown.
        Id fRdJcId = 0;
        Id tRdJcId = 0;
        // Set for bothway elements that both start at, or both end at, the
        // middle junction (opposite orientation). The direction of the merged
        // geometry may then disagree with the chosen f/t junctions and must be
        // checked after the update.
        bool needCheckAndFix = false;

        auto rows = db_.exec(query);
        ASSERT(rows.size() == 2);
        for (const auto& row : rows) {
            len += row["len"].as<double>(0);
            num += row["num"].as<size_t>(0);

            auto idFirst = row["f_rd_jc_id"].as<Id>(0);
            auto idTail = row["t_rd_jc_id"].as<Id>(0);

            if (row["oneway"].as<std::string>() == "B") { // bothway
                // Direction may be reinterpreted: put the outer end into
                // whichever of f/t is still free.
                if (idFirst != midJcId) {
                    if (fRdJcId) {
                        tRdJcId = idFirst;
                        needCheckAndFix = true;
                    } else {
                        fRdJcId = idFirst;
                    }
                }
                if (idTail != midJcId) {
                    if (tRdJcId) {
                        fRdJcId = idTail;
                        needCheckAndFix = true;
                    } else {
                        tRdJcId = idTail;
                    }
                }
            } else {
                // Directed element: outer ends keep their role. If the two
                // elements point in conflicting directions, one of f/t is
                // overwritten and the other stays 0, so the candidate is
                // skipped below.
                if (idFirst != midJcId) {
                    fRdJcId = idFirst;
                }
                if (idTail != midJcId) {
                    tRdJcId = idTail;
                }
            }
        }

        skippedRdJcs.emplace(midJcId);
        // The shared middle vertex is counted once per element.
        --num;
        if (num > maxVertexes_ || len > maxLength_ || !fRdJcId || !tRdJcId || fRdJcId == tRdJcId) {
            continue;
        }

        skippedRdJcs.emplace(fRdJcId);
        skippedRdJcs.emplace(tRdJcId);

        auto rdElId1 = rdElIds.front();
        auto rdElId2 = rdElIds.back();

        // rdElId1 survives, rdElId2 is deleted. Keep the element that carries
        // conditions; loadRdJcToRdEls guarantees at most one of them does.
        if (!condIds(rdElId2).empty()) {
            ASSERT(condIds(rdElId1).empty());
            std::swap(rdElId1, rdElId2);
        }

        // Rewrite the survivor: new outer junctions and the merged geometry.
        db_.write(Query()
            .update("rd_el")
            .set(Query()
                << "f_rd_jc_id=" << fRdJcId << ','
                << "t_rd_jc_id=" << tRdJcId << ','
                << "shape=(SELECT ST_Linemerge(ST_Collect(shape))"
                            " FROM rd_el WHERE " << rdElIdsClause << ")")
            .where(Query() << "rd_el_id=" << rdElId1));

        // If the geometry now runs opposite to f/t, swap the junctions.
        if (needCheckAndFix && findErrorRdEl(rdElId1) == rdElId1) {
            fixRdElDirection(rdElId1);
        }

        // Remove the absorbed element together with its road links, and the
        // middle junction. Both elements had the same road set (checked in
        // loadRdJcToRdEls), so the survivor's links already cover them.
        if (rdElToRoads_.count(rdElId2)) {
            db_.write(Query() << "DELETE FROM rd_rd_el WHERE rd_el_id=" << rdElId2);
            rdElToRoads_.erase(rdElId2);
        }
        db_.write(Query() << "DELETE FROM rd_el WHERE rd_el_id=" << rdElId2);
        db_.write(Query() << "DELETE FROM rd_jc WHERE rd_jc_id=" << midJcId);
        ++processed;
    }
    return processed;
}

std::unordered_map<Id, std::vector<Id>> Merger::loadRdJcToRdEls()
{
    auto allowed =
        "SELECT rd_jc_id FROM"
        " (SELECT f_rd_jc_id AS rd_jc_id FROM rd_el"
        " UNION ALL "
        "  SELECT t_rd_jc_id AS rd_jc_id FROM rd_el) tmp"
        " GROUP BY 1 HAVING COUNT(*)=2"
        " EXCEPT SELECT rd_jc_id FROM bound_jc";

    std::string cmpFields =
        "fc, fow, speed_cat, speed_limit, f_zlev, t_zlev, oneway, access_id,"
        "back_bus, forward_bus, paved, poor_condition, stairs, sidewalk, struct_type,"
        "ferry, dr, toll, srv_ra, srv_uc, isocode, subcode";

    auto rows = db_.read(Query() <<
        "SELECT rd_jc_id, ARRAY_AGG(rd_el_id) AS rd_el_ids"
        "    FROM ("
        "    SELECT f_rd_jc_id AS rd_jc_id, rd_el_id, " << cmpFields << " FROM rd_el"
        "    UNION ALL"
        "    SELECT t_rd_jc_id AS rd_jc_id, rd_el_id, " << cmpFields << " FROM rd_el) tmp"
        " WHERE rd_jc_id IN (" << allowed << ")"
        " GROUP BY rd_jc_id, " << cmpFields <<
        " HAVING COUNT(DISTINCT rd_el_id)=2;", "candidates");

    std::unordered_map<Id, std::vector<Id>> rdJcToRdEls;
    for (const auto& row : rows) {
        std::vector<Id> ids;
        for (const auto& idStr : common::parseSqlArray(row["rd_el_ids"].as<std::string>())) {
            ids.emplace_back(boost::lexical_cast<Id>(idStr));
        }
        ASSERT(ids.size() == 2);

        auto id1 = ids.front();
        auto id2 = ids.back();
        ASSERT(id1 != id2);
        if (protectedRdElIds_.count(id1) || protectedRdElIds_.count(id2)) {
            continue;
        }

        if (!condIds(id1).empty() && !condIds(id2).empty()) {
            continue;
        }

        if (roadIds(id1) == roadIds(id2)) {
            auto rdJcId = row[0].as<Id>();
            rdJcToRdEls.emplace(rdJcId, std::move(ids));
        }
    }
    INFO() << "After check conds & roads: " << rdJcToRdEls.size();
    return rdJcToRdEls;
}

} // namespace rdel_merger
} // namespace wiki
} // namespace maps
