#include "minimal_overhead.h"

#include "path_search.h"

#include <maps/libs/log8/include/log8.h>

#include <algorithm>
#include <utility>

namespace maps::mrc::gen_targets {
namespace {

/// Total travel time of `path`: the traversal time of every edge plus the
/// maneuver penalty of every transition between two consecutive edges.
Seconds getPathTime(const RoadNetworkData& roadNetwork,
                   const std::vector<EdgeId>& path) {
    Seconds total = 0;
    const EdgeId* previous = nullptr;
    for (const EdgeId& current : path) {
        total += roadNetwork.edge(current).time;
        // The first edge has no predecessor, hence no maneuver penalty.
        if (previous != nullptr) {
            total += roadNetwork.getManeuverPenalty(*previous, current);
        }
        previous = &current;
    }
    return total;
}

} // anonymous namespace

/// Creates the overhead start/end points of `roadNetwork`, pairs them and
/// adds the chosen overhead paths to the network as target edge copies.
/// The network is modified in place and must outlive this object.
OptimalOverheadCreator::OptimalOverheadCreator(ExtendedRoadNetwork& roadNetwork)
    : roadNetwork_(roadNetwork)
{
    auto [overheadStarts, overheadEnds] = createOverheadStartAndEndPoints();
    auto chosenPairs = getOptimalStartAndEndPairs(overheadStarts, overheadEnds);
    addNecessaryEdgesCopies(chosenPairs);
}

std::pair<OverheadStartsSet, OverheadEndsSet>
OptimalOverheadCreator::createOverheadStartAndEndPoints()
{
    OverheadStartsSet overheadStartsSet;
    OverheadEndsSet overheadEndsSet;

    auto deadEndTargetEdges = findDeadEndTargetEdges();
    auto& edgesWithoutOutputTargets = deadEndTargetEdges.first;
    auto& edgesWithoutInputTargets = deadEndTargetEdges.second;

    for (EdgeId edgeId : edgesWithoutOutputTargets) {
        EdgeId newEdgeId = roadNetwork_.generateNewEdgeId();
        createOverheadStartFromEdge(edgeId, newEdgeId);
        overheadStartsSet.insert(newEdgeId);
    }
    for (EdgeId edgeId : edgesWithoutInputTargets) {
        EdgeId newEdgeId = roadNetwork_.generateNewEdgeId();
        createOverheadEndFromEdge(edgeId, newEdgeId);
        overheadEndsSet.insert(newEdgeId);
    }

    auto overheadPoints = handleUnbalancedNodes(edgesWithoutOutputTargets,
                                                edgesWithoutInputTargets);
    overheadStartsSet.insert(overheadPoints.first.begin(),
                             overheadPoints.first.end());
    overheadEndsSet.insert(overheadPoints.second.begin(),
                           overheadPoints.second.end());

    return {overheadStartsSet, overheadEndsSet};
}

/// Finds the target edges that are dead ends of the target subgraph.
/// @return pair of
///   * target edges none of whose outgoing edges is a target, and
///   * target edges none of whose incoming edges is a target.
///   The same edge may appear in both sets.
std::pair<EdgesWithoutOutputTargets, EdgesWithoutInputTargets>
OptimalOverheadCreator::findDeadEndTargetEdges()
{
    EdgesWithoutOutputTargets edgesWithoutOutputTargets;
    EdgesWithoutInputTargets edgesWithoutInputTargets;

    const auto isTarget = [this](EdgeId edgeId) {
        return roadNetwork_.edge(edgeId).isTarget;
    };

    // Bind by reference: a by-value loop variable would copy the whole Edge
    // (adjacency lists, penalties, geometry) on every iteration.
    for (const auto& edgeIt : roadNetwork_.edges()) {
        const Edge& edge = edgeIt.second;
        if (!edge.isTarget) {
            continue;
        }

        if (std::none_of(edge.inEdges.begin(), edge.inEdges.end(), isTarget)) {
            edgesWithoutInputTargets.insert(edge.id);
        }
        if (std::none_of(edge.outEdges.begin(), edge.outEdges.end(), isTarget)) {
            edgesWithoutOutputTargets.insert(edge.id);
        }
    }

    return {std::move(edgesWithoutOutputTargets),
            std::move(edgesWithoutInputTargets)};
}

std::pair<OverheadStartsSet, OverheadEndsSet> OptimalOverheadCreator::handleUnbalancedNodes(
    const EdgesWithoutOutputTargets& edgesWithoutOutputTargets,
    const EdgesWithoutInputTargets& edgesWithoutInputTargets)
{
    OverheadStartsSet overheadStartsSet;
    OverheadEndsSet overheadEndsSet;

    for (const auto& node: roadNetwork_.nodes()) {
        int inTargets = 0;
        int outTargets = 0;
        for (EdgeId outEdgeId : node.second.outEdges) {
            if (roadNetwork_.edge(outEdgeId).isTarget
                && !edgesWithoutInputTargets.count(outEdgeId)) {
                outTargets++;
            }
        }
        for (EdgeId inEdgeId : node.second.inEdges) {
            if (roadNetwork_.edge(inEdgeId).isTarget
                && !edgesWithoutOutputTargets.count(inEdgeId)) {
                inTargets++;
            }
        }

        if (inTargets != outTargets) {
            auto overheadPoints = handleUnbalancedNode(
                node.second, inTargets, outTargets);
            overheadStartsSet.insert(overheadPoints.first.begin(),
                                     overheadPoints.first.end());
            overheadEndsSet.insert(overheadPoints.second.begin(),
                                   overheadPoints.second.end());
        }
    }
    return {overheadStartsSet, overheadEndsSet};
}

std::pair<OverheadStartsSet, OverheadEndsSet> OptimalOverheadCreator::handleUnbalancedNode(
    const Node& node, int nodeInTargets, int nodeOutTargets)
{
    OverheadStartsSet overheadStartsSet;
    OverheadEndsSet overheadEndsSet;

    for (int i = 0; i < nodeInTargets - nodeOutTargets; i++) {
        EdgeId newEdgeId = roadNetwork_.generateNewEdgeId();
        createOverheadStartFromNode(node, newEdgeId);
        overheadStartsSet.insert(newEdgeId);
    }
    for (int i = 0; i < nodeOutTargets - nodeInTargets; i++) {
        EdgeId newEdgeId = roadNetwork_.generateNewEdgeId();
        createOverheadEndFromNode(node, newEdgeId);
        overheadEndsSet.insert(newEdgeId);
    }
    return {overheadStartsSet, overheadEndsSet};
}

/// Adds a copy of edge `sourceEdgeId` under id `newEdgeId`, marked as a
/// target. The copy's original id is recorded via getOriginalId, so a copy
/// of a copy still maps back to the first original edge.
void OptimalOverheadCreator::addEdgeCopy(EdgeId sourceEdgeId, EdgeId newEdgeId) {
    Edge duplicate = roadNetwork_.edge(sourceEdgeId);
    duplicate.isTarget = true;
    duplicate.id = newEdgeId;
    roadNetwork_.addEdge(duplicate);
    originalId_[duplicate.id] = getOriginalId(sourceEdgeId);
}

/// Adds overhead start `newEdgeId` for edge `sourceEdgeId`.
/// The new edge is zero-length, zero-time and non-target. It shares the
/// source edge's geometry and outgoing connections (with their penalties),
/// so anything reachable after the source edge is reachable from it.
void OptimalOverheadCreator::createOverheadStartFromEdge(EdgeId sourceEdgeId,
                                                         EdgeId newEdgeId)
{
    const Edge& source = roadNetwork_.edge(sourceEdgeId);

    Edge start;
    start.id = newEdgeId;
    start.length = 0;
    start.time = 0;
    start.isTarget = false;
    start.geom = source.geom;
    start.outEdges = source.outEdges;
    start.outEdgesPenalties = source.outEdgesPenalties;

    roadNetwork_.addEdge(start);
}

/// Adds overhead end `newEdgeId` for edge `sourceEdgeId`.
/// The new edge is zero-length, zero-time and non-target. It shares the
/// source edge's geometry and incoming connections, so it can be entered
/// from every edge that leads into the source edge.
void OptimalOverheadCreator::createOverheadEndFromEdge(EdgeId sourceEdgeId,
                                                       EdgeId newEdgeId)
{
    const Edge& source = roadNetwork_.edge(sourceEdgeId);

    Edge end;
    end.id = newEdgeId;
    end.length = 0;
    end.time = 0;
    end.isTarget = false;
    end.geom = source.geom;
    end.inEdges = source.inEdges;

    roadNetwork_.addEdge(end);
}

void OptimalOverheadCreator::createOverheadStartFromNode(const Node& node,
                                                         EdgeId newEdgeId)
{
    Edge newEdge;
    newEdge.id = newEdgeId;
    newEdge.length = 0;
    newEdge.time = 0;
    newEdge.isTarget = false;

    // selects output edges, which have several input targets
    for (EdgeId outEdgeId : node.outEdges) {
        const Edge& outEdge = roadNetwork_.edge(outEdgeId);
        int inTargets = 0;
        geolib3::Polyline2 geom;
        for (EdgeId inEdgeId : outEdge.inEdges) {
            if (roadNetwork_.edge(inEdgeId).isTarget) {
                geom = roadNetwork_.edge(inEdgeId).geom;
                inTargets++;
            }
        }
        if (inTargets > 1 || (!outEdge.isTarget && inTargets > 0)) {
            newEdge.outEdges.push_back(outEdgeId);
            newEdge.outEdgesPenalties.push_back(0);
            // geom value is not important
            newEdge.geom = geom;
        }
    }
    roadNetwork_.addEdge(newEdge);
}

void OptimalOverheadCreator::createOverheadEndFromNode(const Node& node,
                                                       EdgeId newEdgeId)
{
    Edge newEdge;
    newEdge.id = newEdgeId;
    newEdge.length = 0;
    newEdge.time = 0;
    newEdge.isTarget = false;

    // selects input edges, which have several output targets
    for (EdgeId inEdgeId : node.inEdges) {
        const Edge& inEdge = roadNetwork_.edge(inEdgeId);
        int outTargets = 0;
        geolib3::Polyline2 geom;
        for (EdgeId outEdgeId : inEdge.outEdges) {
            if (roadNetwork_.edge(outEdgeId).isTarget) {
                geom = roadNetwork_.edge(outEdgeId).geom;
                outTargets++;
            }
        }
        if (outTargets > 1 || (!inEdge.isTarget && outTargets > 0)) {
            newEdge.inEdges.push_back(inEdgeId);
            // geom value is not important
            newEdge.geom = geom;
        }
    }
    roadNetwork_.addEdge(newEdge);
}

/// Decides which overhead start is connected to which overhead end, aiming
/// to minimise the total travel time of the overhead paths.
///
/// Starts are processed one at a time:
///  1. For each start, PathSearch looks for the closest end that is still
///     unmatched. The graph is augmented with one backward connection
///     end -> start per already chosen pair, weighted by
///     -0.999 * (time of that pair's path).
///  2. Each backward connection on the found path removes its pair.
///  3. Each positive segment from an overhead start to an overhead end
///     becomes a new pair.
/// This resembles an augmenting-path (successive shortest path) matching
/// scheme. NOTE(review): it relies on PathSearch handling negative weights;
/// confirm against PathSearch. The 0.999 factor presumably keeps cycles that
/// go through a backward connection strictly positive; confirm.
///
/// All temporary backward connections are disconnected before returning.
///
/// @param overheadStartsSet ids of the overhead start edges.
/// @param overheadEndsSet ids of the overhead end edges.
/// @return chosen (overhead start id, overhead end id) pairs.
/// REQUIRE fails if a found path does not alternate between starts and ends
/// (this includes an empty search result), or if a backward connection on
/// the path has no matching current pair.
std::set<std::pair<EdgeId, EdgeId>> OptimalOverheadCreator::getOptimalStartAndEndPairs(
    const OverheadStartsSet& overheadStartsSet,
    const OverheadEndsSet& overheadEndsSet)
{
    std::unordered_set<EdgeId> allEdges = roadNetwork_.getSetOfEdgeIds();
    std::set<std::pair<EdgeId, EdgeId>> startAndEndPairs;

    // Records a new pair and connects its backward (negative) connection.
    // NOTE(review): the path search below runs while backward connections of
    // other pairs are still connected. The measured time may therefore
    // include negative connections rather than only the forward path time;
    // confirm this is intended.
    auto addStartAndEndPair = [&](EdgeId overheadStart, EdgeId overheadEnd) {
        startAndEndPairs.insert({overheadStart, overheadEnd});
        auto path = PathSearch(roadNetwork_,
                               overheadStart,
                               std::unordered_set<EdgeId>{overheadEnd},
                               allEdges).getResult();
        // insert corresponding backward edge with a negative weight
        // (a scaled-down negated path time)
        roadNetwork_.connectEdges(
            overheadEnd,
            overheadStart,
            -0.999 * getPathTime(roadNetwork_, path));
    };

    // Drops an existing pair and disconnects its backward connection.
    auto removeStartAndEndPair = [&](EdgeId overheadStart, EdgeId overheadEnd) {
        REQUIRE(startAndEndPairs.count({overheadStart, overheadEnd}),
                "no reverse overhead");
        startAndEndPairs.erase({overheadStart, overheadEnd});
        // remove corresponding backward negative edge
        roadNetwork_.disconnectEdges(overheadEnd, overheadStart);
    };

    // Ends that can still terminate a search. Already matched ends stay in
    // the graph, so paths may pass through them via backward connections.
    std::unordered_set<EdgeId> remainingOverheadEnds = overheadEndsSet;

    for (EdgeId startEdgeId : overheadStartsSet) {
        // find the closest end point (maybe using negative edges)
        std::vector<EdgeId> path = PathSearch(
            roadNetwork_, startEdgeId,
            remainingOverheadEnds,
            allEdges).getResult();
        // path contains positive and negative edges. It has this structure:
        // {overheadStart} -> positive path ->
        // [overheadEnd -> negativeEdge -> overheadStart -> positive path] x N
        // -> overheadEnd
        // Each positive path represents new start and end pair.
        // Each negative edge means that we should delete the
        // corresponding start and end pair from the current pairs.

        // filter only start and end points
        std::vector<EdgeId> startAndEndPoints{startEdgeId};
        for (EdgeId pathEdge : path) {
            if (overheadEndsSet.count(pathEdge)
                || overheadStartsSet.count(pathEdge)) {
                startAndEndPoints.push_back(pathEdge);
            }
        }
        // An odd size (e.g. only startEdgeId, if PathSearch returned an empty
        // path) means the points do not alternate start/end.
        REQUIRE(startAndEndPoints.size() % 2 == 0, "broken startAndEndPoints");
        // each even point is an overhead start, each odd point is an
        // overhead end

        // for each negative part of the path:
        // point i (odd) is an end and point i+1 is the start reached from it
        // through the backward connection, so the pair (i+1, i) is removed.
        // The REQUIRE above guarantees size() >= 2, so size()-1 cannot wrap.
        for (size_t i = 1; i < startAndEndPoints.size()-1; i += 2) {
            removeStartAndEndPair(startAndEndPoints[i+1], startAndEndPoints[i]);
        }
        // for each positive part of the path: pair (i, i+1)
        for (size_t i = 0; i < startAndEndPoints.size(); i += 2) {
            addStartAndEndPair(startAndEndPoints[i], startAndEndPoints[i+1]);
        }
        // Only the final end is newly matched; the intermediate ones were
        // re-matched within the loops above.
        remainingOverheadEnds.erase(startAndEndPoints.back());
    }

    // disconnect all the temporary negative edges from the graph
    for (auto& it : startAndEndPairs) {
        roadNetwork_.disconnectEdges(it.second, it.first);
    }
    return startAndEndPairs;
}

void OptimalOverheadCreator::addNecessaryEdgesCopies(
    std::set<std::pair<EdgeId, EdgeId>>& startAndEndPairs)
{
    std::unordered_set<EdgeId> allEdges = roadNetwork_.getSetOfEdgeIds();
    double overheadLength = 0;
    for (auto& it : startAndEndPairs) {
        EdgeId overheadStart = it.first;
        EdgeId overheadEnd = it.second;
        std::vector<EdgeId> path = PathSearch(
            roadNetwork_,
            overheadStart,
            std::unordered_set<EdgeId>{overheadEnd},
            allEdges).getResult();
        path.pop_back();
        for (EdgeId pathEdge : path) {
            overheadLength += roadNetwork_.edge(pathEdge).length;
            EdgeId newEdgeId = roadNetwork_.generateNewEdgeId();
            addEdgeCopy(pathEdge, newEdgeId);
            allEdges.insert(newEdgeId);
        }
    }
    INFO() << "overhead length: " << overheadLength;
}

/// Returns the id of the original edge that `edgeId` was copied from by
/// addEdgeCopy, or `edgeId` itself when it is not such a copy.
EdgeId OptimalOverheadCreator::getOriginalId(EdgeId edgeId)
{
    // One lookup instead of count() followed by operator[].
    const auto it = originalId_.find(edgeId);
    return it != originalId_.end() ? it->second : edgeId;
}

} // namespace maps::mrc::gen_targets
