#include <yandex/maps/wiki/routing/exception.h>
#include <yandex/maps/wiki/routing/route.h>

#include "route_impl.h"
#include "static_graph.h"

#include <yandex/maps/wiki/graph/graph.h>
#include <yandex/maps/wiki/graph/traversal.h>

#include <maps/libs/common/include/exception.h>

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <queue>


namespace maps {
namespace wiki {
namespace routing {

COPYABLE_PIMPL_DEFINITIONS(StopSnap)

/// Returns the stop id stored in the implementation.
ID StopSnap::stopId() const
{
    return impl_->stopId;
}

/// Returns the stored location on the snapped element
/// (NOTE(review): units/range are not visible here — confirm in the header).
double StopSnap::locationOnElement() const
{
    return impl_->locationOnElement;
}

/// Returns the stored snap point by const reference (no copy).
const Point& StopSnap::point() const
{
    return impl_->point;
}


COPYABLE_PIMPL_DEFINITIONS(TracePoint)

/// Returns the directed element this trace point refers to.
const DirectedElementID& TracePoint::directedElementId() const
{
    return impl_->directedElementId;
}

/// Returns the stop snapped to this trace point, if any (empty otherwise).
const OptionalStopSnap& TracePoint::stopSnap() const
{
    return impl_->stopSnap;
}


COPYABLE_PIMPL_DEFINITIONS(NoPathError)

/// Returns the id of the stop the failed segment starts at.
ID NoPathError::fromStopId() const
{
    return impl_->fromStopId;
}

/// Returns the id of the stop the failed segment should have reached.
ID NoPathError::toStopId() const
{
    return impl_->toStopId;
}


COPYABLE_PIMPL_DEFINITIONS(AmbiguousPathError)

/// Returns the id of the stop the ambiguous segment starts at.
ID AmbiguousPathError::fromStopId() const
{
    return impl_->fromStopId;
}

/// Returns the id of the stop the ambiguous segment ends at.
ID AmbiguousPathError::toStopId() const
{
    return impl_->toStopId;
}

/// Returns the id of the element at which the ambiguity was detected.
ID AmbiguousPathError::elementId() const
{
    return impl_->elementId;
}


MOVABLE_PIMPL_DEFINITIONS(RestoreResult)

/// Returns the restored trace.
const Trace& RestoreResult::trace() const
{
    return impl_->trace;
}

/// Returns ids of input elements that the restored trace does not use.
const IdSet& RestoreResult::unusedElementIds() const
{
    return impl_->unusedElementIds;
}

/// Returns ambiguity errors of the "forward" kind.
const AmbiguousPathErrors& RestoreResult::forwardAmbiguousPathErrors() const
{
    return impl_->forwardAmbiguousPathErrors;
}

/// Returns ambiguity errors of the "backward" kind.
const AmbiguousPathErrors& RestoreResult::backwardAmbiguousPathErrors() const
{
    return impl_->backwardAmbiguousPathErrors;
}

/// Returns the segments between consecutive stops for which no path was found.
const NoPathErrors& RestoreResult::noPathErrors() const
{
    return impl_->noPathErrors;
}

/// Convenience overload of restore() without conditions.
/// Forwards to the four-argument overload with a value-initialized Conditions.
RestoreResult restore(
        const Elements& elements,
        const Stops& stops,
        double stopSnapToleranceMeters)
{
    const Conditions noConditions{};
    return restore(elements, stops, noConditions, stopSnapToleranceMeters);
}

/// Returns true when restoration produced neither NoPathErrors nor forward
/// AmbiguousPathErrors.
/// Backward ambiguous path errors are deliberately not consulted here
/// (NOTE(review): presumably treated as warnings — confirm with the API docs).
bool RestoreResult::isOk() const
{
    if (!noPathErrors().empty()) {
        return false;
    }
    return forwardAmbiguousPathErrors().empty();
}

namespace {

// Node -> node map; used for SearchResult::prevNodeMap (node -> BFS predecessor).
using NodeIdMap = std::map<graph::NodeID, graph::NodeID>;
// Path node -> its position counted from the end of the restored path
// (the destination node gets 0).
using NodeIdToReverseOrder = std::map<graph::NodeID, size_t>;

/// Returns the set of ids of all given elements.
/// restore() starts from this set and removes ids as the trace uses them.
IdSet collectElementIds(const Elements& elements)
{
    IdSet ids;
    for (const auto& elem: elements) {
        const auto elemId = elem.id();
        ids.insert(elemId);
    }
    return ids;
}


/**
 * Walks the recorded search in-edges backwards from `nodeId` and returns the
 * path nodes at which that walk arrives.
 *
 * The walk looks at every in-edge of the node being processed:
 *  - if the edge starts at a path node, that start node is added to the result
 *    and the walk does not continue past it;
 *  - otherwise the start node is enqueued (at most once), provided that
 *    in-edges were recorded for it.
 * If `nodeId` is itself a path node reachable through such a chain, it can
 * appear in the result.
 *
 * @param nodeIdToVisitedInEdges  in-edges recorded during the search. It must
 *        contain an entry for `nodeId`, because it is looked up with `at`.
 * @param pathNodeIdToReverseOrder  nodes of the restored path; only the keys
 *        are used here.
 * @param nodeId  node to start the backward walk from.
 * @return path nodes reached by the backward walk.
 */
graph::NodeIdSet collectAllPathStartNodeIds(
        const NodeIdToEdges& nodeIdToVisitedInEdges,
        const NodeIdToReverseOrder& pathNodeIdToReverseOrder,
        graph::NodeID nodeId)
{
    graph::NodeIdSet result;

    std::queue<graph::NodeID> nodeIdQueue;
    nodeIdQueue.push(nodeId);

    // Non-path nodes that were already enqueued; keeps the walk finite on cycles.
    graph::NodeIdSet processedNodeIds;
    while (!nodeIdQueue.empty()) {
        // Distinct name so that the parameter `nodeId` is not shadowed.
        const graph::NodeID currentNodeId = nodeIdQueue.front();
        nodeIdQueue.pop();

        for (const auto& edge: nodeIdToVisitedInEdges.at(currentNodeId)) {
            const graph::NodeID startNodeId = edge.startNodeId();

            if (pathNodeIdToReverseOrder.count(startNodeId)) {
                result.insert(startNodeId);
                continue;
            }

            if (!processedNodeIds.count(startNodeId) && nodeIdToVisitedInEdges.count(startNodeId)) {
                nodeIdQueue.push(startNodeId);
                processedNodeIds.insert(startNodeId);
            }
        }
    }

    return result;
}

/// Returns true when every edge in `edges` starts at a path node, i.e. at a
/// node that has an entry in `pathNodeIdToReverseOrder`.
/// Returns true for an empty `edges`.
bool allStartNodesInPath(
        const NodeIdToReverseOrder& pathNodeIdToReverseOrder,
        const graph::Edges& edges)
{
    return std::all_of(
        edges.begin(), edges.end(),
        [&pathNodeIdToReverseOrder](const graph::Edges::value_type& edge) {
            return pathNodeIdToReverseOrder.count(edge.startNodeId()) != 0;
        }
    );
}

struct NodeIdsWithAmbiguousPath {
    IdSet forward;
    IdSet backward;
};

/**
 * Finds path nodes that the search also reached through edges coming from
 * outside the path.
 *
 * Path nodes are skipped when any of the following holds:
 *  - no in-edges were recorded for them (as for the search start node);
 *  - fewer than two in-edges were recorded;
 *  - all recorded in-edges start on the path.
 *
 * For every other path node, collectAllPathStartNodeIds() yields the path
 * nodes where the alternative branches re-enter the path ("entry nodes").
 * Reverse order counts from the end of the path (destination = 0):
 *  - an entry at or after the node (entry order <= node order) marks the node
 *    as backward-ambiguous;
 *  - an entry at least two positions before the node (entry order > node
 *    order + 1) marks it as forward-ambiguous;
 *  - an entry at the immediate predecessor (order + 1) marks neither.
 */
NodeIdsWithAmbiguousPath findNodeWithAmbiguousPath(
        const NodeIdToReverseOrder& pathNodeIdToReverseOrder,
        const NodeIdToEdges& nodeIdToVisitedInEdges)
{
    NodeIdsWithAmbiguousPath ambiguous;

    for (const auto& pathEntry: pathNodeIdToReverseOrder) {
        const graph::NodeID pathNodeId = pathEntry.first;
        const size_t nodeReverseOrder = pathEntry.second;

        const auto inEdgesIt = nodeIdToVisitedInEdges.find(pathNodeId);
        const bool hasRecordedInEdges = inEdgesIt != nodeIdToVisitedInEdges.end();
        if (!hasRecordedInEdges) {
            continue;
        }

        const graph::Edges& inEdges = inEdgesIt->second;
        const bool reachedOnlyFromPath = inEdges.size() < 2
            || allStartNodesInPath(pathNodeIdToReverseOrder, inEdges);
        if (reachedOnlyFromPath) {
            continue;
        }

        const graph::NodeIdSet entryNodeIds = collectAllPathStartNodeIds(
            nodeIdToVisitedInEdges,
            pathNodeIdToReverseOrder,
            pathNodeId
        );

        for (const graph::NodeID entryNodeId: entryNodeIds) {
            const size_t entryReverseOrder = pathNodeIdToReverseOrder.at(entryNodeId);
            if (entryReverseOrder <= nodeReverseOrder) {
                ambiguous.backward.insert(pathNodeId);
            }
            if (entryReverseOrder > nodeReverseOrder + 1) {
                ambiguous.forward.insert(pathNodeId);
            }
        }
    }

    return ambiguous;
}


// Output of search(): everything the BFS from one stop recorded.
// NOTE: field order matters — search() builds this with aggregate initialization.
struct SearchResult {
    // Reached node -> node it was first reached from (BFS predecessor).
    NodeIdMap prevNodeMap;
    // Node -> every edge ending at it that was returned while expanding nodes,
    // including edges to nodes that were already reached.
    NodeIdToEdges nodeIdToVisitedInEdges;
    // Node snapped to the destination stop, if the search reached it.
    boost::optional<graph::NodeID> toNodeId;
};

/**
 * Runs graph::breadthFirstSearch over `graph` from `fromNodeId`, looking for
 * the node snapped to stop `toStopId`.
 *
 * - Every edge returned by graph.outEdges() is recorded as a visited in-edge
 *   of its end node, so later code can see all ways a node was reached.
 * - Each traversal step stores the predecessor of the reached node; emplace
 *   keeps the first one.
 * - Nodes carrying a stop snap, other than `fromNodeId`, are reported as lock
 *   nodes (NOTE(review): presumably the traversal does not expand past them —
 *   confirm in graph/traversal.h).
 * - It is asserted that at most one reached node is snapped to `toStopId`.
 *
 * The result is filled in place and returned by value, so the maps are not
 * copied on return.
 */
SearchResult search(StaticGraph& graph, graph::NodeID fromNodeId, ID toStopId)
{
    SearchResult result;

    graph::breadthFirstSearch(
        /* outEdges = */ [&](graph::NodeID nodeId) -> graph::Edges {
            auto edges = graph.outEdges(nodeId);
            for (const auto& edge: edges) {
                result.nodeIdToVisitedInEdges[edge.endNodeId()].push_back(edge);
            }
            return edges;
        },
        /* step = */ [&](graph::NodeID from, graph::NodeID to, size_t /*distance*/) {
            result.prevNodeMap.emplace(to, from);

            const auto& tracePoint = graph.getTracePointByNodeId(to);
            if (tracePoint.stopSnap() && tracePoint.stopSnap()->stopId() == toStopId) {
                ASSERT(!result.toNodeId);
                result.toNodeId = to;
            }
        },
        fromNodeId,
        /* isLockNode = */ [&](const graph::NodeID nodeId) {
            const auto& tracePoint = graph.getTracePointByNodeId(nodeId);
            return tracePoint.stopSnap() && nodeId != fromNodeId;
        }
    );

    return result;
}

/// Appends one AmbiguousPathError to `errors` for every node in `nodeIds`.
/// Each error carries the segment's stop ids and the id of the node's element.
/// Errors are appended in the iteration order of `nodeIds`.
void addAmbiguousPathErrors(
        AmbiguousPathErrors& errors,
        const StaticGraph& graph,
        ID fromStopId,
        ID toStopId,
        const graph::NodeIdSet& nodeIds)
{
    for (const graph::NodeID ambiguousNodeId: nodeIds) {
        const TracePoint& point = graph.getTracePointByNodeId(ambiguousNodeId);
        const auto elementId = point.directedElementId().id();

        auto error = PImplFactory::create<AmbiguousPathError>(fromStopId, toStopId, elementId);
        errors.push_back(std::move(error));
    }
}

} // namespace


/**
 * Restores a trace through `elements` that visits `stops` in order.
 *
 * Consecutive stops are connected one segment at a time by search() over a
 * StaticGraph built from the inputs. Consecutive stops with equal ids are
 * skipped.
 *
 * When no path is found for a segment:
 *  - a NoPathError is recorded;
 *  - the next segment starts from the node of this segment's destination stop,
 *    so the trace can be discontinuous there.
 *
 * When a path is found:
 *  - it is appended to the trace; the start point is added only when the trace
 *    does not already end at it;
 *  - elements of its trace points are removed from the unused-element set;
 *  - ambiguous path nodes are reported as forward/backward AmbiguousPathErrors.
 *
 * @throws BadParam if `elements` is empty, fewer than two stops are given,
 *         or `stopSnapToleranceMeters` is not positive.
 */
RestoreResult restore(
        const Elements& elements,
        const Stops& stops,
        const Conditions& conditions,
        double stopSnapToleranceMeters)
{
    if (elements.empty()) {
        throw BadParam("Elements' list is empty");
    }

    if (stops.size() < 2) {
        throw BadParam("Need at least two stops");
    }

    if (stopSnapToleranceMeters <= 0) {
        throw BadParam("Stop snap tolerance must be greater than 0");
    }

    StaticGraph graph{elements, stops, conditions, stopSnapToleranceMeters};

    auto getNodeId = [&graph](ID stopId) {
        return graph.getNodeId(graph.getTracePointByStopId(stopId));
    };

    Trace trace;
    IdSet unusedElementIds = collectElementIds(elements);
    AmbiguousPathErrors forwardAmbiguousPathErrors;
    AmbiguousPathErrors backwardAmbiguousPathErrors;
    NoPathErrors noPathErrors;

    graph::NodeID fromNodeId = getNodeId(stops.front().id());
    // True when the trace does not yet end at fromNodeId's trace point.
    // This holds at the very beginning and right after a segment without a path.
    bool needAddStartPoint = true;
    for (size_t i = 0; i + 1 < stops.size(); ++i) {
        const auto& fromStop = stops[i];
        const auto& toStop = stops[i + 1];

        if (fromStop.id() == toStop.id()) {
            continue;
        }

        const auto result = search(graph, fromNodeId, toStop.id());
        if (!result.toNodeId) {
            noPathErrors.push_back(
                PImplFactory::create<NoPathError>(fromStop.id(), toStop.id())
            );

            fromNodeId = getNodeId(toStop.id());
            needAddStartPoint = true;
            continue;
        }

        const TracePoint& fromTracePoint = graph.getTracePointByNodeId(fromNodeId);
        unusedElementIds.erase(fromTracePoint.directedElementId().id());

        // Walk the BFS predecessors from the destination back to the start.
        // Each node's reverse order is its distance from the destination (0).
        NodeIdToReverseOrder pathNodeIdToReverseOrder;
        Trace subtraceWithoutStartPoint;
        for (graph::NodeID nodeId = *result.toNodeId; nodeId != fromNodeId; ) {
            const TracePoint& tracePoint = graph.getTracePointByNodeId(nodeId);

            subtraceWithoutStartPoint.push_back(tracePoint);
            unusedElementIds.erase(tracePoint.directedElementId().id());
            pathNodeIdToReverseOrder.emplace(nodeId, pathNodeIdToReverseOrder.size());

            nodeId = result.prevNodeMap.at(nodeId);
        }
        pathNodeIdToReverseOrder.emplace(fromNodeId, pathNodeIdToReverseOrder.size());
        std::reverse(subtraceWithoutStartPoint.begin(), subtraceWithoutStartPoint.end());

        const auto nodeIdsWithAmbiguousPath = findNodeWithAmbiguousPath(
            pathNodeIdToReverseOrder,
            result.nodeIdToVisitedInEdges
        );

        addAmbiguousPathErrors(
            forwardAmbiguousPathErrors,
            graph,
            fromStop.id(),
            toStop.id(),
            nodeIdsWithAmbiguousPath.forward
        );

        addAmbiguousPathErrors(
            backwardAmbiguousPathErrors,
            graph,
            fromStop.id(),
            toStop.id(),
            nodeIdsWithAmbiguousPath.backward
        );

        if (needAddStartPoint) {
            trace.push_back(fromTracePoint);
        }
        needAddStartPoint = false;

        trace.insert(
            trace.end(),
            subtraceWithoutStartPoint.begin(), subtraceWithoutStartPoint.end()
        );

        fromNodeId = *result.toNodeId;
    }

    return PImplFactory::create<RestoreResult>(
        std::move(trace),
        std::move(unusedElementIds),
        std::move(forwardAmbiguousPathErrors),
        std::move(backwardAmbiguousPathErrors),
        std::move(noPathErrors)
    );
}

} // namespace routing
} // namespace wiki
} // namespace maps


