#include "connected_components.h"

#include <stack>

namespace maps::mrc::gen_targets {

namespace {

// Phase of an entry that lies on the explicit stack of dfsVisitor.
enum class VisitDirection {
    // The entry was pushed and its edge has not been expanded from this
    // entry yet (every entry is created in this state).
    DirectVisit,
    // The children of the edge were already pushed on top of this entry;
    // when it surfaces again the after-children action is executed.
    ReturnVisit,
};

// Iterative depth-first traversal driven by an explicit stack. It follows the
// scheme of this recursive routine invoked as dfsVisit(startEdgeId):
//
// void dfsVisit(EdgeId edgeId) {
//     if (isVisited(edgeId)) {
//         return;
//     }
//     // direct visit of the edge (moving up the dfs tree)
//     actionBeforeChildrenWillBeVisited(edgeId);
//     for (childEdgeId : children(edgeId)) {
//         dfsVisit(childEdgeId);
//     }
//     // return visit of the edge (moving down the dfs tree)
//     actionAfterChildrenWereVisited(edgeId);
// }
//
// Because children are taken from a stack, they are expanded in the reverse
// of the order in which children(edgeId) yields them.
//
// startEdgeId - edge the traversal starts from.
// children    - callable returning an iterable of EdgeId for a given edge.
// isVisited   - returns true for edges that must not be expanded. It is
//               expected to return true for an edge after
//               actionBeforeChildrenWillBeVisited was called for it,
//               otherwise an edge reachable twice would be expanded twice.
// actionBeforeChildrenWillBeVisited - called once an edge is accepted, before
//               any of its children is pushed.
// actionAfterChildrenWereVisited - called when everything pushed above the
//               edge has been processed; does nothing by default.
template<typename ChildrenF>
void dfsVisitor(
    EdgeId startEdgeId,
    ChildrenF children,
    std::function<bool(EdgeId)> isVisited,
    std::function<void(EdgeId)> actionBeforeChildrenWillBeVisited,
    std::function<void(EdgeId)> actionAfterChildrenWereVisited = [](EdgeId){})
{
    using Frame = std::pair<EdgeId, VisitDirection>;

    std::stack<Frame> pending;
    pending.push(Frame{startEdgeId, VisitDirection::DirectVisit});

    while (!pending.empty()) {
        // Take the frame off the stack first; it is pushed back explicitly
        // if a return visit is still owed for it.
        const Frame frame = pending.top();
        pending.pop();
        const EdgeId edgeId = frame.first;

        if (frame.second == VisitDirection::ReturnVisit) {
            // moving down the dfs tree
            actionAfterChildrenWereVisited(edgeId);
        } else if (!isVisited(edgeId)) {
            // moving up the dfs tree
            actionBeforeChildrenWillBeVisited(edgeId);
            // The return frame goes below the children, so it surfaces only
            // after all of them have been handled.
            pending.push(Frame{edgeId, VisitDirection::ReturnVisit});
            for (EdgeId childEdgeId : children(edgeId)) {
                pending.push(Frame{childEdgeId, VisitDirection::DirectVisit});
            }
        }
    }
}

// First pass of Kosaraju's algorithm.
// Runs a DFS over the reversed graph (the children of an edge are its
// inEdges), restricted to the edges accepted by `filter`, and returns the
// edges in the order in which they disappear from the DFS path stack, i.e.
// in the order of their return visits.
//
// roadNetwork - graph whose edges are traversed.
// filter      - predicate; edges for which it returns false are never entered
//               and do not appear in the result.
std::vector<EdgeId> getReverseDfsOrder(const RoadNetworkData& roadNetwork,
                                       std::function<bool(EdgeId)> filter)
{
    std::unordered_set<EdgeId> visitedEdges;
    std::vector<EdgeId> kosarajuOrder;
    // NOTE(review): the capacity is taken from getEdges() while the loop below
    // walks edges(); presumably both expose the same collection - confirm.
    kosarajuOrder.reserve(roadNetwork.getEdges().size());

    const auto getInEdges = [&](EdgeId edgeId) {
        return roadNetwork.edge(edgeId).inEdges;
    };
    // Filtered-out edges are reported as visited so that the DFS skips them.
    const auto isVisited = [&] (EdgeId edgeId) {
        return visitedEdges.count(edgeId) || !filter(edgeId);
    };
    const auto actionBeforeChildrenWillBeVisited = [&](EdgeId edgeId) {
        visitedEdges.insert(edgeId);
    };
    const auto actionAfterChildrenWereVisited = [&](EdgeId edgeId) {
        kosarajuOrder.push_back(edgeId);
    };

    for (const auto& edge : roadNetwork.edges()) {
        // Same pre-check as in the component functions of this file: a root
        // that is already visited or filtered out would make dfsVisitor
        // return immediately, so do not set up a traversal for it at all.
        if (isVisited(edge.first)) {
            continue;
        }
        dfsVisitor(edge.first, getInEdges, isVisited,
                   actionBeforeChildrenWillBeVisited,
                   actionAfterChildrenWereVisited);
    }

    return kosarajuOrder;
}
} // namespace

// Kosaraju's algorithm
std::unordered_map<EdgeId, ComponentId>
getStronglyConnectedComponents(const RoadNetworkData& roadNetwork,
                               std::function<bool(EdgeId)> filter)
{
    const auto reverseDfsOrder = getReverseDfsOrder(roadNetwork, filter);
    const auto getOutEdges
        = [&](EdgeId edgeId) { return roadNetwork.edge(edgeId).outEdges; };

    std::unordered_map<EdgeId, ComponentId> edgeIdToComponentId;
    size_t numberOfComponents = 0;

    const auto isVisited = [&] (EdgeId edgeId) {
        return edgeIdToComponentId.count(edgeId) || !filter(edgeId);
    };
    const auto visitEdge = [&](EdgeId edgeId) {
        edgeIdToComponentId[edgeId] = numberOfComponents;
    };

    for (auto edgeIt = reverseDfsOrder.rbegin();
         edgeIt != reverseDfsOrder.rend();
         ++edgeIt) {
        if (edgeIdToComponentId.count(*edgeIt)) {
            continue;
        }
        // Run direct DFS and assing current component ID to all eges that are
        // reachable from the current one.
        dfsVisitor(*edgeIt, getOutEdges, isVisited, visitEdge);
        numberOfComponents++;
    }
    return edgeIdToComponentId;
}

// finds connected conponent using dfs where children of each edge are
// input and output edges.
std::unordered_map<EdgeId, ComponentId>
getWeaklyConnectedComponents(const RoadNetworkData& roadNetwork,
                             std::function<bool(EdgeId)> filter)
{
    const auto getInAndOutEdges = [&](EdgeId edgeId) {
        const Edge& edge = roadNetwork.edge(edgeId);
        std::vector<EdgeId> edges = edge.outEdges;
        edges.insert(edges.end(), edge.inEdges.begin(), edge.inEdges.end());
        return edges;
    };

    std::unordered_map<EdgeId, ComponentId> edgeIdToComponentId;
    size_t numberOfComponents = 0;

    const auto isVisited = [&] (EdgeId edgeId) {
        return !filter(edgeId) || edgeIdToComponentId.count(edgeId);
    };
    const auto visitEdge = [&](EdgeId edgeId) {
        edgeIdToComponentId[edgeId] = numberOfComponents;
        return true;
    };

    for (const auto& edge : roadNetwork.edges()) {
        if (!filter(edge.first) || edgeIdToComponentId.count(edge.first)) {
            continue;
        }
        dfsVisitor(edge.first, getInAndOutEdges, isVisited, visitEdge);
        numberOfComponents++;
    }
    return edgeIdToComponentId;
}

} // namespace maps::mrc::gen_targets
