#include "static_graph.h"

#include "route_impl.h"

#include <yandex/maps/wiki/routing/exception.h>

#include <maps/libs/geolib/include/bounding_box.h>
#include <maps/libs/geolib/include/closest_point.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/geolib/include/static_geometry_searcher.h>
#include <maps/libs/geolib/include/vector.h>

#include <maps/libs/common/include/exception.h>
#include <yandex/maps/wiki/common/geom_utils.h>

#include <boost/optional.hpp>

#include <algorithm>
#include <functional>
#include <utility>


namespace maps {
namespace wiki {
namespace routing {

namespace  {

typedef std::vector<StopSnap> StopSnaps;
typedef std::map<DirectedElementID, StopSnaps> DirectedElementIdToStopSnaps;

// Snaps every stop onto the closest element geometry found within
// `stopSnapToleranceMeters` of it, picking the travel direction from the side
// of the element the stop lies on.
//
// Only the first occurrence of each stop id is processed. Returns the snaps
// grouped by directed element, each group sorted by locationOnElement().
// Throws ImpossibleSnapStopError with the ids of every stop that found no
// suitable element.
DirectedElementIdToStopSnaps snapStops(
        const Elements& elements,
        const Stops& stops,
        double stopSnapToleranceMeters)
{
    // Spatial index over element geometries; the payload is the element's
    // position inside `elements`.
    geolib3::StaticGeometrySearcher<Polyline, size_t> searcher;
    for (size_t elementIdx = 0; elementIdx < elements.size(); ++elementIdx) {
        searcher.insert(std::addressof(elements[elementIdx].geom()), elementIdx);
    }
    searcher.build();

    DirectedElementIdToStopSnaps snapsByElement;
    IdSet snappedIds;
    IdSet failedIds;

    for (const auto& stop: stops) {
        const bool alreadyProcessed =
            snappedIds.count(stop.id()) || failedIds.count(stop.id());
        if (alreadyProcessed) {
            continue;
        }

        const auto& stopPoint = stop.geom();
        const double toleranceMercator =
            geolib3::toMercatorUnits(stopSnapToleranceMeters, stopPoint);
        const double boxSide = 2 * toleranceMercator;
        const geolib3::BoundingBox searchBox{stopPoint, boxSide, boxSide};
        const auto candidates = searcher.find(searchBox);

        // Best candidate so far; on equal distance the later candidate wins.
        boost::optional<DirectedElementID> bestId;
        const Polyline* bestGeom = nullptr;
        Point bestPoint;
        double bestDistance = toleranceMercator;

        for (auto candidate = candidates.first; candidate != candidates.second; ++candidate) {
            const auto& element = elements[candidate->value()];
            const auto& geom = element.geom();

            const Point projection = geolib3::closestPoint(geom, stopPoint);
            const double distance = geolib3::distance(stopPoint, projection);
            if (distance > bestDistance) {
                continue;
            }

            const auto segmentIdx = geom.segmentIndex(projection);
            ASSERT(segmentIdx);
            const auto segment = geom.segmentAt(*segmentIdx);

            const geolib3::Vector2 alongRoad = segment.vector();
            const geolib3::Vector2 towardStop = stopPoint - segment.start();

            // A non-positive cross product (stop to the right of, or on, the
            // segment under the usual counter-clockwise-positive convention)
            // selects Forward, otherwise Backward. The candidate is dropped when
            // the element does not allow the selected direction.
            const Direction direction =
                geolib3::crossProduct(alongRoad, towardStop) <= 0
                    ? Direction::Forward
                    : Direction::Backward;
            if (!isSet(direction, element.direction())) {
                continue;
            }

            bestId = DirectedElementID{element.id(), direction};
            bestGeom = &geom;
            bestPoint = projection;
            bestDistance = distance;
        }

        if (!bestId) {
            failedIds.insert(stop.id());
            continue;
        }

        ASSERT(bestGeom);
        // Location is measured from where traffic enters the element, as a
        // fraction of the element length.
        const bool isForward = bestId->direction() == Direction::Forward;
        const double distanceAlong = isForward
            ? geolib3::distanceAlongFromStart(*bestGeom, bestPoint)
            : geolib3::distanceAlongToEnd(*bestGeom, bestPoint);
        const double location = distanceAlong / geolib3::length(*bestGeom);

        snappedIds.insert(stop.id());
        snapsByElement[*bestId].push_back(
            PImplFactory::create<StopSnap>(stop.id(), location, bestPoint)
        );
    }

    if (!failedIds.empty()) {
        throw ImpossibleSnapStopError(std::move(failedIds));
    }

    const auto byLocationOnElement = [](const StopSnap& lhs, const StopSnap& rhs) {
        return lhs.locationOnElement() < rhs.locationOnElement();
    };
    for (auto& entry: snapsByElement) {
        std::sort(entry.second.begin(), entry.second.end(), byLocationOnElement);
    }

    return snapsByElement;
}

typedef std::vector<DirectedElementID> DirectedElementIds;
typedef std::map<ElementEnd, DirectedElementIds> IncidenceMap;

struct Incidence {
    IncidenceMap in;
    IncidenceMap out;
};

// Builds the in/out incidence of every element end. A Forward traversal
// departs from the element's start and arrives at its end; a Backward
// traversal goes the other way. Only directions allowed on the element are
// recorded.
Incidence computeIncidence(const Elements& elements)
{
    Incidence result;

    const auto addTraversal = [&result](
            const ElementEnd& from,
            const ElementEnd& to,
            ID elementId,
            Direction direction) {
        result.out[from].emplace_back(elementId, direction);
        result.in[to].emplace_back(elementId, direction);
    };

    for (const auto& element: elements) {
        if (isSet(Direction::Forward, element.direction())) {
            addTraversal(element.start(), element.end(), element.id(), Direction::Forward);
        }
        if (isSet(Direction::Backward, element.direction())) {
            addTraversal(element.end(), element.start(), element.id(), Direction::Backward);
        }
    }

    return result;
}

// Direction in which `element` must be travelled to arrive at `junctionId`:
// Forward when the junction is the element's end, Backward otherwise.
Direction directionToJunction(const Element& element, ID junctionId)
{
    if (element.end().junctionId() == junctionId) {
        return Direction::Forward;
    }
    return Direction::Backward;
}

// Junction at the other end of `element`: its start when `junctionId` is the
// element's end, otherwise its end.
ID oppositeJunciton(const Element& element, ID junctionId)
{
    const auto startJunctionId = element.start().junctionId();
    const auto endJunctionId = element.end().junctionId();
    if (endJunctionId == junctionId) {
        return startJunctionId;
    }
    return endJunctionId;
}

bool endsWith(const Path& path, const Path& suffix)
{
    return (suffix.size() <= path.size()) && std::equal(
        suffix.rbegin(), suffix.rend(),
        path.rbegin()
    );
}

// Two optional snaps are equal when both are empty, or both are set and refer
// to the same stop id.
bool equalStopId(const OptionalStopSnap& lhs, const OptionalStopSnap& rhs)
{
    if (!lhs || !rhs) {
        return !lhs && !rhs;
    }
    return lhs->stopId() == rhs->stopId();
}

// Strict weak ordering on optional stop snaps, keyed by stop id.
//
// An empty snap orders before every non-empty one, and two empty snaps are
// equivalent. The induced equivalence therefore coincides with equalStopId().
// Previously an empty snap compared "equivalent" to every non-empty snap, which
// made equivalence non-transitive (none ~ A, none ~ B, yet A < B) and broke the
// Compare requirements of ordered containers keyed through operator<.
bool lessStopId(const OptionalStopSnap& lhs, const OptionalStopSnap& rhs)
{
    if (!rhs) {
        return false;
    }
    if (!lhs) {
        return true;
    }
    return lhs->stopId() < rhs->stopId();
}

// Orders trace points by directed element id, then by snapped stop id.
bool less(const TracePoint& lhs, const TracePoint& rhs)
{
    const auto& lhsElementId = lhs.directedElementId();
    const auto& rhsElementId = rhs.directedElementId();

    if (lhsElementId < rhsElementId) {
        return true;
    }
    return lhsElementId == rhsElementId
        && lessStopId(lhs.stopSnap(), rhs.stopSnap());
}

// Trace points are equal when they share the directed element and the stop.
bool equal(const TracePoint& lhs, const TracePoint& rhs)
{
    if (!(lhs.directedElementId() == rhs.directedElementId())) {
        return false;
    }
    return equalStopId(lhs.stopSnap(), rhs.stopSnap());
}

} // namespace

// A graph node: a trace point plus the partially matched condition path that
// led to it. A non-empty condition path must end on the trace point's
// directed element; otherwise REQUIRE fails with "Invalid condition path".
CompoundNodeID::CompoundNodeID(const TracePoint& tracePoint, const Path& conditionPath)
    : tracePoint(tracePoint)
    , conditionPath(conditionPath)
{
    if (conditionPath.empty()) {
        return;
    }

    REQUIRE(
        conditionPath.back() == tracePoint.directedElementId(),
        "Invalid condition path"
    );
}

// Lexicographic order: trace point first, condition path second.
bool operator<(const CompoundNodeID& lhs, const CompoundNodeID& rhs)
{
    if (less(lhs.tracePoint, rhs.tracePoint)) {
        return true;
    }
    if (!equal(lhs.tracePoint, rhs.tracePoint)) {
        return false;
    }
    return lhs.conditionPath < rhs.conditionPath;
}

// Pairs a condition id with an index into that condition's directed-element path.
ConditionPosition::ConditionPosition(ID conditionId, size_t pathPosition)
    : conditionId(conditionId)
    , pathPosition(pathPosition)
{
}

// Indexes the turn conditions whose elements are all present in `elements`.
//
// Prohibited: the condition is turned into a directed-element path. The path
// starts with the `from` element travelled toward the via junction, then
// follows each `to` element away from the junction reached so far. Every
// (directed element, position) along that path is indexed as well.
// Uturn: the `from` element's direction toward the via junction is marked as
// allowing a turnabout.
// Conditions of other types are ignored.
ConditionIndex::ConditionIndex(const Elements& elements, const Conditions& conditions)
{
    // On duplicate ids the first element wins (std::map::insert keeps it).
    std::map<ID, const Element*> elementById;
    for (const auto& element: elements) {
        elementById.insert({element.id(), &element});
    }

    const auto isKnownElement = [&elementById](const ID elementId) {
        return elementById.count(elementId) != 0;
    };

    for (const auto& cond: conditions) {
        const auto& toIds = cond.toElementIds();
        if (!isKnownElement(cond.fromElementId())
                || !std::all_of(toIds.begin(), toIds.end(), isKnownElement)) {
            continue;
        }

        if (cond.type() == common::ConditionType::Prohibited) {
            auto& forbiddenPath = idToForbiddenConditionPath_[cond.id()];
            ID junctionId = cond.viaJunctionId();

            // Appends `element`, travelled toward the current `junctionId`, as
            // path step number `pos`.
            const auto appendStep = [&](const Element& element, size_t pos) {
                const DirectedElementID step{
                    element.id(),
                    directionToJunction(element, junctionId)
                };
                directedElementIdToForbiddenConditionPositions_[step].emplace_back(
                    cond.id(),
                    pos
                );
                forbiddenPath.push_back(step);
            };

            appendStep(*elementById.at(cond.fromElementId()), 0);
            for (size_t i = 0; i < toIds.size(); ++i) {
                const Element& element = *elementById.at(toIds.at(i));
                junctionId = oppositeJunciton(element, junctionId);
                appendStep(element, i + 1);
            }
        } else if (cond.type() == common::ConditionType::Uturn) {
            const Element& element = *elementById.at(cond.fromElementId());
            isTurnaboutPossible_.insert(
                {element.id(), directionToJunction(element, cond.viaJunctionId())}
            );
        }
    }
}

// True if a Uturn condition allows turning around at the end of `id`.
bool ConditionIndex::isTurnaboutPossible(const DirectedElementID& id) const
{
    return isTurnaboutPossible_.count(id) != 0;
}

bool ConditionIndex::isForbiddenPath(const Path& path) const
{
    if (path.empty()) {
        return false;
    }

    const auto it = directedElementIdToForbiddenConditionPositions_.find(path.back());
    if (it == directedElementIdToForbiddenConditionPositions_.end()) {
        return false;
    }

    const auto& positions = it->second;
    return std::any_of(
        positions.begin(), positions.end(),
        [&](const ConditionPosition& position) {
            return endsWith(path, idToForbiddenConditionPath_.at(position.conditionId));
        }
    );
}

// Returns the longest suffix of `path` that is also a prefix of some
// prohibited condition path, or an empty path if there is none.
//
// Only conditions passing through path.back() are considered. On equal
// length, the first matching candidate is kept. An empty `path` yields an
// empty result; the former implementation dereferenced path.back() on it,
// which is UB.
Path ConditionIndex::longestSuffixConditionPrefix(const Path& path) const
{
    Path result;
    if (path.empty()) {
        return result;
    }

    const auto it = directedElementIdToForbiddenConditionPositions_.find(path.back());
    if (it == directedElementIdToForbiddenConditionPositions_.end()) {
        return result;
    }

    for (const auto& position: it->second) {
        const size_t prefixLength = position.pathPosition + 1;
        // Skip candidates that cannot beat the current best or cannot fit in
        // `path`, without materialising the prefix.
        if (prefixLength <= result.size() || prefixLength > path.size()) {
            continue;
        }

        const auto& forbiddenConditionPath = idToForbiddenConditionPath_.at(
            position.conditionId
        );
        const auto prefixBegin = forbiddenConditionPath.begin();
        const auto prefixEnd = prefixBegin + prefixLength;
        if (std::equal(prefixBegin, prefixEnd, path.end() - prefixLength)) {
            result = Path(prefixBegin, prefixEnd);
        }
    }

    return result;
}

// Builds the precomputed part of the routing graph.
//
// Stops are snapped onto elements first; snapStops throws
// ImpossibleSnapStopError if any stop cannot be snapped. Every directed
// element then yields a chain of trace points: one per stop snapped onto it,
// ordered by location on the element, or a single stop-less point when it
// carries no stop. Consecutive points of a chain are linked. The last point of
// each directed element arriving at an element end is linked to the first
// point of every directed element leaving that end, except for:
//   * transitions onto the same element (turnabouts) that
//     ConditionIndex::isTurnaboutPossible does not allow;
//   * transitions that complete a prohibited condition path.
// A transition entering a partially matched condition path targets a compound
// node carrying that path. Out-edges of such nodes are computed on demand by
// outEdges().
StaticGraph::StaticGraph(
        const Elements& elements,
        const Stops& stops,
        const Conditions& conditions,
        double stopSnapToleranceMeters)
    : conditionIndex_(elements, conditions)
{
    const Incidence incidence = computeIncidence(elements);

    const auto directedElementIdToStopSnaps = snapStops(elements, stops, stopSnapToleranceMeters);
    for (const auto& pair: directedElementIdToStopSnaps) {
        const auto& directedElementId = pair.first;
        const auto& stopSnaps = pair.second;

        for (const auto& stopSnap: stopSnaps) {
            const auto point = PImplFactory::create<TracePoint>(
                directedElementId,
                stopSnap
            );
            stopIdToTracePoint_.insert({stopSnap.stopId(), point});
        }
    }

    // Trace points lying on a directed element, in snap order; never empty.
    auto tracePointSequence = [&](const DirectedElementID& directedElementId) {
        std::vector<TracePoint> tracePoints;

        const auto it = directedElementIdToStopSnaps.find(directedElementId);
        if (it != directedElementIdToStopSnaps.end()) {
            for (const auto& stopSnap: it->second) {
                tracePoints.push_back(
                    PImplFactory::create<TracePoint>(directedElementId, stopSnap)
                );
            }
        } else {
            tracePoints.push_back(
                PImplFactory::create<TracePoint>(directedElementId)
            );
        }

        return tracePoints;
    };

    for (const auto& pair: incidence.in) {
        const auto& viaJcId = pair.first;
        const auto outIt = incidence.out.find(viaJcId);
        for (const auto& fromElId: pair.second) {
            const auto fromTracePointSequence = tracePointSequence(fromElId);

            // Link consecutive trace points along the element.
            for (size_t i = 0; i + 1 < fromTracePointSequence.size(); ++i) {
                const auto startNodeId = getNodeId(fromTracePointSequence[i]);
                const auto endNodeId = getNodeId(fromTracePointSequence[i + 1]);
                outEdgesByNodeId_[startNodeId] = graph::Edges{{startNodeId, endNodeId}};
            }

            const auto nodeId = getNodeId(fromTracePointSequence.back());
            if (outIt == incidence.out.end()) {
                // Nothing leaves this element end.
                outEdgesByNodeId_[nodeId] = graph::Edges();
                continue;
            }

            graph::Edges edgeList;
            for (const auto& toElId: outIt->second) {
                if (fromElId.id() == toElId.id()
                        && !conditionIndex_.isTurnaboutPossible(fromElId)) {
                    continue;
                }

                const auto conditionPath = conditionIndex_.longestSuffixConditionPrefix(
                    {fromElId, toElId}
                );
                if (conditionIndex_.isForbiddenPath(conditionPath)) {
                    continue;
                }

                // The target node is registered only for transitions that are
                // kept, so discarded ones leave no orphan nodes in nodeIdMap_.
                const auto toTracePointSequence = tracePointSequence(toElId);
                edgeList.emplace_back(
                    nodeId,
                    getNodeId(toTracePointSequence.front(), conditionPath)
                );
            }

            outEdgesByNodeId_[nodeId] = std::move(edgeList);
        }
    }
}

// Trace point of the snapped stop `stopId`; REQUIRE fails for unknown stops.
const TracePoint& StaticGraph::getTracePointByStopId(ID stopId) const
{
    const auto found = stopIdToTracePoint_.find(stopId);
    const bool known = found != stopIdToTracePoint_.end();
    REQUIRE(known, "There is no stop with id " << stopId);
    return found->second;
}

// Trace point part of the compound node registered under `nodeId`.
const TracePoint& StaticGraph::getTracePointByNodeId(graph::NodeID nodeId) const
{
    const CompoundNodeID& compoundId = getCompoundNodeId(nodeId);
    return compoundId.tracePoint;
}

// Node id of `tracePoint` with an empty condition path (registered on demand
// through nodeIdMap_).
graph::NodeID StaticGraph::getNodeId(const TracePoint& tracePoint)
{
    return getNodeId(tracePoint, Path{});
}

// Node id of the (tracePoint, conditionPath) pair, resolved through
// nodeIdMap_. The CompoundNodeID constructor validates the path.
graph::NodeID StaticGraph::getNodeId(
        const TracePoint& tracePoint,
        const Path& conditionPath)
{
    const CompoundNodeID compoundId{tracePoint, conditionPath};
    return nodeIdMap_(compoundId);
}

// Reverse lookup: the compound node registered under `nodeId`.
const CompoundNodeID& StaticGraph::getCompoundNodeId(graph::NodeID nodeId) const
{
    const CompoundNodeID& compoundId = nodeIdMap_[nodeId];
    return compoundId;
}

// Returns the out-edges of `nodeId`, computing and caching them on first use.
//
// Nodes created by the constructor already have cached edges. Any other node
// must carry a non-empty condition path (ASSERTed). Its edges are those of the
// same trace point without a path, re-targeted as follows. Moving along the
// same directed element keeps the path. Moving to another element appends it
// and reduces the result to its longest suffix that is a prefix of some
// prohibited condition. Edges whose resulting path is forbidden are dropped.
// Edge weights are copied from the path-less edges.
graph::Edges StaticGraph::outEdges(graph::NodeID nodeId)
{
    const auto it = outEdgesByNodeId_.find(nodeId);
    if (it != outEdgesByNodeId_.end()) {
        return it->second;
    }

    // Copy rather than bind a reference: nodeIdMap_(...) below may register
    // new compound nodes. If nodeIdMap_ keeps them in a reallocating container,
    // a reference taken here would dangle on later iterations.
    // NOTE(review): confirm nodeIdMap_'s storage; the copy is safe either way.
    const CompoundNodeID fromNodeId = getCompoundNodeId(nodeId);
    ASSERT(!fromNodeId.conditionPath.empty());

    graph::Edges edgeList;
    const CompoundNodeID fromNodeIdWithoutPath{fromNodeId.tracePoint, {}};
    for (const auto& edge: outEdges(nodeIdMap_(fromNodeIdWithoutPath))) {
        // Copied for the same reason as fromNodeId.
        const TracePoint toTracePoint = getCompoundNodeId(edge.endNodeId()).tracePoint;

        auto conditionPath = fromNodeId.conditionPath;
        if (conditionPath.back() != toTracePoint.directedElementId()) {
            conditionPath.push_back(toTracePoint.directedElementId());
            conditionPath = conditionIndex_.longestSuffixConditionPrefix(conditionPath);
        }

        if (!conditionIndex_.isForbiddenPath(conditionPath)) {
            edgeList.emplace_back(
                nodeId,
                nodeIdMap_(CompoundNodeID{toTracePoint, conditionPath}),
                edge.weight()
            );
        }
    }

    return outEdgesByNodeId_[nodeId] = std::move(edgeList);
}

} // namespace routing
} // namespace wiki
} // namespace maps
