#pragma once

#include <maps/wikimap/mapspro/libs/masstransit/masstransit.h>

#include <yandex/maps/wiki/routing/route.h>
#include <maps/libs/geolib/include/closest_point.h>
#include <maps/libs/geolib/include/serialization.h>

#include <optional>

namespace maps::wiki::masstransit {

/// Tolerance handed to routing::restore() when a thread is routed over its
/// elements (see routeThreadGeometry). Judging by the name this is the maximum
/// stop-to-element snapping distance, in meters — confirm against routing docs.
constexpr double STOP_SNAP_TOLERANCE_METERS{50.0};

/**
 * Query template selecting the road elements (rd_el) bound to features via
 * ft_rd_el, one row per (feature, road element) pair.
 *
 * Columns are aliased to the generic names id / el_id / f_jc_id / t_jc_id so
 * that the rows share a layout with TRAM_ROWS_SQL_TEMPLATE.
 * `%1%` is a placeholder for the set of feature ids (looks like a
 * boost::format-style placeholder — confirm against the caller).
 */
const std::string BUS_ROWS_SQL_TEMPLATE =
    // feature id and road element id
    "SELECT ft_rd_el.ft_id AS id, rd_el.rd_el_id AS el_id, "
    // junctions at both ends and their z-levels
    "rd_el.f_rd_jc_id AS f_jc_id, rd_el.t_rd_jc_id AS t_jc_id, "
    "rd_el.f_zlev, rd_el.t_zlev, "
    // traffic attributes
    "rd_el.oneway, rd_el.back_bus, rd_el.access_id, "
    // geometry as WKT
    "ST_AsText(rd_el.shape) AS shape "
    "FROM ft_rd_el JOIN rd_el USING (rd_el_id) "
    "WHERE ft_id IN %1%";

/**
 * Query template selecting the edges bound to features via ft_edge, one row
 * per (feature, edge) pair, shaped like the rows of BUS_ROWS_SQL_TEMPLATE.
 *
 * Edges carry no road attributes, so constants are substituted for them.
 * `%1%` is a placeholder for the set of feature ids (looks like a
 * boost::format-style placeholder — confirm against the caller).
 */
const std::string TRAM_ROWS_SQL_TEMPLATE =
    // feature id and edge id
    "SELECT ft_edge.ft_id AS id, edge.edge_id AS el_id, "
    // nodes at both ends and their z-levels
    "edge.f_node_id AS f_jc_id, edge.t_node_id AS t_jc_id, "
    "edge.f_zlev, edge.t_zlev, "
    // 'B' -> both ways, 0 -> no back bus, 31 -> all types of transport
    "'B' AS oneway, 0 AS back_bus, 31 AS access_id, "
    // geometry as WKT
    "ST_AsText(edge.shape) AS shape "
    "FROM ft_edge JOIN edge USING (edge_id) "
    "WHERE ft_id IN %1%";

/**
 * Query template producing one row per feature: the shapes of all edges bound
 * to the feature via ft_edge are collected (ST_Collect), merged
 * (ST_LineMerge) and returned as WKT. `GROUP BY 1` groups by the first
 * selected column, i.e. the feature id.
 *
 * `%1%` is a placeholder for the set of feature ids (looks like a
 * boost::format-style placeholder — confirm against the caller).
 */
const std::string METRO_ROWS_SQL_TEMPLATE =
    "SELECT "
        "ft_edge.ft_id AS id, "
        "ST_AsText(ST_LineMerge(ST_Collect(edge.shape))) AS shape "
    "FROM ft_edge "
    "JOIN edge USING (edge_id) "
    // The trailing space keeps the substituted id set apart from GROUP BY;
    // without it the query only parses if the substitution ends with ')'.
    "WHERE ft_id IN %1% "
    "GROUP BY 1;";

/**
 * Geometry of a single thread together with the raw elements it is
 * restored from.
 */
struct ThreadGeometry : public Object
{
    /// Elements loaded for the thread and an index over them.
    struct YmapsdfLoadData {
        /// Loaded elements, in load order.
        Elements elements;
        /// Element id -> position of that element inside `elements`.
        std::map<DBID, size_t> idToElement;
    };

    /// Creates an empty geometry for the thread with the given id.
    explicit ThreadGeometry(DBID threadId)
        : threadId(threadId)
    {}

    /// Id of the thread this geometry belongs to.
    DBID threadId;
    /// Resulting thread line (filled in by routeThreadGeometry).
    Polyline polyline;
    /// Source data the line is routed over (filled in by addRoutedElement).
    YmapsdfLoadData ymapsdf;
};

/**
 * Builds the list of (stop id, stop point) pairs used to route a thread.
 *
 * If `threadId` has an entry in `masstransit.threadStops`, all stops of that
 * thread are returned in the order they appear in `stopsInThread`.
 * Otherwise `threadId` is looked up in `masstransit.connectors` (a connector
 * without own stops) and exactly two stops are returned:
 *   1. the last stop of the connector's source thread;
 *   2. the first stop of the connector's destination thread.
 *
 * NOTE(review): the connector branch calls back()/front() without checking
 * that the source/destination threads have stops — presumably guaranteed by
 * the loader; confirm.
 */
inline StopPoints
collectStopPoints(
    const Masstransit& masstransit,
    DBID threadId)
{
    StopPoints collected;

    const auto appendStop = [&masstransit, &collected](DBID stopId) {
        const auto& stop = masstransit.stops[stopId];
        collected.emplace_back(stop.stopId, stop.point);
    };

    if (!masstransit.threadStops.count(threadId)) {
        // Connector without own stops: borrow the adjacent stops of the
        // threads it links.
        const auto& connector = masstransit.connectors[threadId];
        const auto& srcStops =
            masstransit.threadStops[connector.srcThreadId].stopsInThread;
        const auto& dstStops =
            masstransit.threadStops[connector.dstThreadId].stopsInThread;
        appendStop(srcStops.back().stopId);
        appendStop(dstStops.front().stopId);
        return collected;
    }

    for (const auto& stopInThread : masstransit.threadStops[threadId].stopsInThread) {
        appendStop(stopInThread.stopId);
    }
    return collected;
}

/**
 * Assembles the thread polyline from a successful routing result.
 *
 * Walks the trace, and for every change of directed element appends that
 * element's geometry (reversed for Direction::Backward) to the line, merging
 * coinciding end points. The assembled line is then cut with partition()
 * between the snap point of the first trace point and the snap point of the
 * last one.
 *
 * @param result        routing result whose trace is followed
 * @param elements      elements the routing was performed on
 * @param idToElement   element id -> index in `elements`
 * @throws std::out_of_range if a traced element id is missing in `idToElement`
 *
 * NOTE(review): assumes the trace is non-empty and that its first and last
 * points carry a stop snap; neither is checked here.
 */
inline Polyline
restoreThreadGeometry(
    const routing::RestoreResult& result,
    const Elements& elements,
    const std::map<DBID, size_t>& idToElement)
{
    const auto& trace = result.trace();

    Polyline merged;
    std::optional<routing::DirectedElementID> lastAppended;

    for (const auto& tracePoint : trace) {
        const auto& directedId = tracePoint.directedElementId();
        if (lastAppended == directedId) {
            // Consecutive trace points on the same directed element
            // contribute its geometry only once.
            continue;
        }
        lastAppended = directedId;

        auto piece = elements[idToElement.at(directedId.id())].geom();
        if (directedId.direction() == Direction::Backward) {
            piece.reverse();
        }
        merged.extend(piece, geolib3::MergeEqualPoints);
    }

    const auto& start = trace.front().stopSnap()->point();
    const auto& end = trace.back().stopSnap()->point();

    return partition(merged, start, end);
}

/**
 * Reverses `polyline` in place when, as measured by distanceAlongFromStart()
 * and distanceAlongToEnd(), `stopPoint` lies farther from the start of the
 * line than from its end. Otherwise the line is left untouched.
 */
inline void
adjustDirection(Polyline& polyline, const Point& stopPoint)
{
    const auto fromStart = distanceAlongFromStart(polyline, stopPoint);
    const auto toEnd = distanceAlongToEnd(polyline, stopPoint);

    if (fromStart > toEnd) {
        polyline.reverse();
    }
}

/**
 * Re-cuts a closed line at `startStop`.
 *
 * The result is the part produced by partitionToEnd() followed by the part
 * produced by partitionFromStart(), joined with merging of equal points —
 * judging by the helper names, the line then begins and ends at `startStop`.
 *
 * @param circular   closed polyline (first point equals last point),
 *                   modified in place
 * @param startStop  point at which the line is re-cut
 * @throws via REQUIRE if the line is empty or not closed
 */
inline void
adjustSplitPoint(Polyline& circular, const Point& startStop)
{
    // Check emptiness first: front()/back() on an empty point sequence would
    // be undefined (assuming points() is a standard sequence container).
    REQUIRE(!circular.points().empty()
            && circular.points().front() == circular.points().back(),
        "Non-closed circular line");

    // Take the leading part before `circular` is overwritten.
    auto leadingPart = partitionFromStart(circular, startStop);
    circular = partitionToEnd(circular, startStop);
    circular.extend(leadingPart, geolib3::MergeEqualPoints);
}

/**
 * Converts one result row of the bus/tram rows query into a routing element
 * and appends it to `threadGeom.ymapsdf`.
 *
 * The element's direction is taken from the ONEWAY column; when BACK_BUS is
 * set it becomes Direction::Both if the ACCESS_ID mask includes
 * value::ACCESS_BUS, and the reverse of the ONEWAY direction otherwise.
 *
 * `idToElement` is filled with emplace(), so if the same element id is added
 * twice the index keeps pointing at the first occurrence.
 *
 * @param threadGeom  object exposing `ymapsdf.elements` and `ymapsdf.idToElement`
 * @param tuple       database row with the columns named in `ymapsdf::*`
 */
template<typename Geometry>
void
addRoutedElement(
    Geometry& threadGeom,
    const pqxx::row& tuple)
{
    // Geometry arrives as WKT text.
    const auto shape = geolib3::WKT::read<Polyline>(
        getAttr<std::string>(tuple, ymapsdf::SHAPE));

    const auto elId = getAttr<DBID>(tuple, ymapsdf::EL_ID);
    const auto fromJcId = getAttr<DBID>(tuple, ymapsdf::F_JC_ID);
    const auto toJcId = getAttr<DBID>(tuple, ymapsdf::T_JC_ID);
    const auto fromZLev = getAttr<int>(tuple, ymapsdf::F_ZLEV);
    const auto toZLev = getAttr<int>(tuple, ymapsdf::T_ZLEV);
    const auto hasBackBus = getAttr<bool>(tuple, ymapsdf::BACK_BUS);
    const auto busAccess = getAttr<DBID>(tuple, ymapsdf::ACCESS_ID) & value::ACCESS_BUS;

    auto direction = getAttr<Direction>(tuple, ymapsdf::ONEWAY);
    if (hasBackBus) {
        direction = busAccess ? Direction::Both : reverse(direction);
    }

    ElementEnd fromEnd(fromJcId, fromZLev);
    ElementEnd toEnd(toJcId, toZLev);

    auto& loaded = threadGeom.ymapsdf;
    // Position the new element is about to take in `elements`.
    const auto position = loaded.elements.size();
    loaded.elements.emplace_back(elId, direction, shape, fromEnd, toEnd);
    loaded.idToElement.emplace(elId, position);
}

/// Selects what routeThreadGeometry() passes to the router as conditions:
/// `Yes` -> roadConditions(masstransit), `No` -> a default-constructed
/// Conditions object.
enum class UseConditions { Yes, No };

/**
 * Routes a thread over its loaded elements and stores the resulting line in
 * `threadGeom.polyline`.
 *
 * Stops come from collectStopPoints(); the router is run with
 * STOP_SNAP_TOLERANCE_METERS and, depending on `useConditions`, either with
 * roadConditions(masstransit) or with empty Conditions.
 *
 * On failure every "no path" and "forward ambiguous path" error of the
 * result is reported through a DataErrorMessage built on masstransit.log(),
 * and a RuntimeError is thrown; `threadGeom.polyline` is left untouched.
 *
 * @param threadGeom     object exposing `threadId`, `polyline` and `ymapsdf`
 * @param masstransit    masstransit data (stops, thread stops, connectors, log)
 * @param useConditions  whether road conditions take part in routing
 * @throws RuntimeError when the router reports a non-OK result
 */
template<typename Geometry>
void
routeThreadGeometry(
    Geometry& threadGeom,
    Masstransit& masstransit,
    UseConditions useConditions)
{
    const auto threadId = threadGeom.threadId;

    const auto reportDataError = [&masstransit]() -> DataErrorMessage {
        return DataErrorMessage(masstransit.log());
    };

    const auto result = routing::restore(
        threadGeom.ymapsdf.elements,
        collectStopPoints(masstransit, threadId),
        useConditions == UseConditions::Yes ? roadConditions(masstransit) : Conditions(),
        STOP_SNAP_TOLERANCE_METERS);

    if (result.isOk()) {
        threadGeom.polyline = restoreThreadGeometry(
            result,
            threadGeom.ymapsdf.elements,
            threadGeom.ymapsdf.idToElement);
        return;
    }

    // Routing failed: report every individual problem, then abort.
    reportDataError() << "Object id=" << threadId << " geometry error";

    for (const auto& error : result.noPathErrors()) {
        reportDataError()
            << "No path found from stop " << error.fromStopId()
            << " to stop " << error.toStopId();
    }

    for (const auto& error : result.forwardAmbiguousPathErrors()) {
        reportDataError()
            << "Ambiguity found between stops " << error.fromStopId()
            << " and " << error.toStopId()
            << " at road element " << error.elementId();
    }

    throw RuntimeError() << "Object id=" << threadId << " geometry error";
}

template<typename ThreadGeometry>
void
adjustLinearThread(
    Masstransit& masstransit,
    ThreadGeometry& threadGeom)
{
    const auto& threadStops = masstransit.threadStops[threadGeom.threadId];
    const auto& startStop = masstransit.stops[threadStops.stopsInThread.front().stopId];
    const auto& endStop = masstransit.stops[threadStops.stopsInThread.back().stopId];

    adjustDirection(
        threadGeom.polyline,
        closestPoint(threadGeom.polyline, startStop.point));

    const auto& start = closestPoint(threadGeom.polyline, startStop.point);
    const auto& end = closestPoint(threadGeom.polyline, endStop.point);
    threadGeom.polyline = partition(threadGeom.polyline, start, end);
}

template<typename ThreadGeometry>
void
adjustCircularThread(
    Masstransit& masstransit,
    ThreadGeometry& threadGeom)
{
    const auto& threadStops = masstransit.threadStops[threadGeom.threadId];
    const auto& startStop = masstransit.stops[threadStops.stopsInThread.front().stopId];
    const auto& secondStop = masstransit.stops[threadStops.stopsInThread[1].stopId];

    adjustSplitPoint(
        threadGeom.polyline,
        closestPoint(threadGeom.polyline, startStop.point));

    adjustDirection(
        threadGeom.polyline,
        closestPoint(threadGeom.polyline, secondStop.point));
}

} // namespace maps::wiki::masstransit
