#include <maps/wikimap/mapspro/services/mrc/libs/graph/include/graph.h>

#include <maps/wikimap/mapspro/services/mrc/libs/graph/impl/condition_index.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph/impl/trace_index.h>

#include <map>
#include <type_traits>
#include <utility>

namespace maps::mrc::graph {

namespace {

// Junction id -> directed road elements attached to that junction.
using IncidenceMap = std::map<TId, DirectedIds>;

// Directed road elements grouped by junction:
//   in  - junction -> directed elements that arrive at it;
//   out - junction -> directed elements that depart from it.
struct Incidence {
    IncidenceMap in;
    IncidenceMap out;
};

// Groups the directed road elements usable with `accessId` by junction.
//
// Elements whose access mask does not include `accessId` are ignored. For the
// remaining ones:
//   * if travel Forward is allowed, (id, Forward) departs from the start
//     junction and arrives at the end junction;
//   * if travel Backward is allowed, (id, Backward) departs from the end
//     junction and arrives at the start junction.
// Each junction's list follows the order of `elements`, with the Forward entry
// of an element placed before its Backward entry.
Incidence computeIncidence(const RoadElements& elements, AccessId accessId)
{
    Incidence incidence;

    for (const auto& element: elements) {
        if (isSet(accessId, element.accessId())) {
            const auto elementId = element.id();
            const auto startId = element.startJunctionId();
            const auto endId = element.endJunctionId();

            if (ymapsdf::rd::isSet(Direction::Forward, element.direction())) {
                incidence.out[startId].emplace_back(elementId, Direction::Forward);
                incidence.in[endId].emplace_back(elementId, Direction::Forward);
            }

            if (ymapsdf::rd::isSet(Direction::Backward, element.direction())) {
                incidence.in[startId].emplace_back(elementId, Direction::Backward);
                incidence.out[endId].emplace_back(elementId, Direction::Backward);
            }
        }
    }

    return incidence;
}

// Returns `trace` (taken by value, so the caller's trace is untouched) with
// `directedId` appended as its new last element.
Trace addToEnd(Trace trace, const DirectedId& directedId)
{
    trace.emplace_back(directedId);
    return trace;
}

} // namespace

// Directed road element -> directed road elements that may follow it through
// their shared junction. Built in Graph::Impl's constructor; forbidden
// conditions are not applied here (Graph::edges() consults conditionIndex).
using Adjacency = std::map<DirectedId, DirectedIds>;

// Private state of Graph.
// Members are initialized in declaration order; keep that in mind when
// reordering (the constructor initializes accessId and conditionIndex).
struct Graph::Impl {
    Impl(
            AccessId accessId,
            const RoadElements& elements,
            const Conditions& conditions);

    // Access id the graph was built for.
    AccessId accessId;
    // Bidirectional mapping between node ids and traces
    // (see Graph::getNodeId(const Trace&) / Graph::getTrace()).
    TraceIndex traceIndex;
    // Built from the elements and conditions for accessId; queried by
    // Graph::edges() for forbidden conditions and u-turn permission.
    ConditionIndex conditionIndex;
    // Junction transitions between accessible directed elements.
    Adjacency adjacency;
    // Every element passed to the constructor, keyed by id, regardless of
    // whether it is accessible with accessId.
    ElementById elementById;
};

// Builds the static part of the graph for `accessId`:
//   * elementById - every element of `elements` keyed by id (no access filter);
//   * adjacency   - for every junction, each arriving directed element is linked
//     to each departing one, provided both meet the junction at the same
//     z-level and belong to different road elements (turning back onto the
//     same element is handled separately as a u-turn in Graph::edges()).
// Forbidden conditions are indexed by conditionIndex and applied later.
Graph::Impl::Impl(AccessId accessId, const RoadElements& elements, const Conditions& conditions)
    : accessId(accessId)
    , conditionIndex(accessId, elements, conditions)
{
    for (const auto& element: elements) {
        elementById.emplace(element.id(), element);
    }

    // Z-level of the road element behind `directedId` at `junctionId`.
    // Returned as int, exactly like the original local variables.
    const auto zLevelAt = [this](const auto& directedId, const auto& junctionId) -> int {
        return elementById.at(directedId.roadElementId()).zLevel(junctionId);
    };

    const auto [in, out] = computeIncidence(elements, accessId);

    for (const auto& [junctionId, arriving]: in) {
        const auto departing = out.find(junctionId);

        if (departing != out.end()) {
            for (const auto& toId: departing->second) {
                const int toZLevel = zLevelAt(toId, junctionId);

                for (const auto& fromId: arriving) {
                    const int fromZLevel = zLevelAt(fromId, junctionId);

                    if (fromZLevel == toZLevel && fromId.roadElementId() != toId.roadElementId()) {
                        adjacency[fromId].push_back(toId);
                    }
                }
            }
        }
    }
}

// Builds the graph of `elements` usable with `accessId`; `conditions` are
// indexed for the forbidden-condition and u-turn checks done in edges().
Graph::Graph(AccessId accessId, const RoadElements& elements, const Conditions& conditions)
    : impl_(new Impl(accessId, elements, conditions))
{}

// Defined out of line, where Graph::Impl is a complete type.
// NOTE(review): assumes impl_ (declared in graph.h) is an owning smart pointer
// such as std::unique_ptr<Impl>; with a raw pointer this would leak Impl.
Graph::~Graph() = default;

// Access id this graph was built for (as passed to the constructor).
AccessId Graph::accessId() const
{
    return impl_->accessId;
}

// True if an element with `id` was passed to the constructor, whether or not
// it is accessible with this graph's access id.
bool Graph::hasElement(TId id) const
{
    return impl_->elementById.find(id) != impl_->elementById.end();
}

// Returns the element with the given id. Any element passed to the constructor
// is found, including ones not accessible with accessId().
// An unknown id fails the REQUIRE check with "No element with id <id>"
// (presumably by throwing - confirm against the REQUIRE macro).
const RoadElement& Graph::elementById(TId id) const
{
    const auto it = impl_->elementById.find(id);
    REQUIRE(it != impl_->elementById.end(), "No element with id " << id);
    return it->second;
}

// Road element underlying a directed id; the direction itself is ignored.
// Fails like elementById() when the element is unknown.
const RoadElement& Graph::elementByDirectedId(const DirectedId& id) const
{
    const auto elementId = id.roadElementId();
    return elementById(elementId);
}

// Road element of the last directed element in the node's trace.
const RoadElement& Graph::elementByNodeId(NodeId nodeId) const
{
    const auto elementId = getId(nodeId);
    return elementById(elementId);
}

// Read-only view over every element passed to the constructor (including ones
// not accessible with accessId()), in elementById's iteration order.
// The view refers to the graph's internal storage: do not use it after the
// Graph is destroyed.
Graph::ElementRange Graph::elements() const
{
    // std::as_const (from <utility>) makes the adaptor iterate the map as const,
    // so the exposed elements cannot be modified through the range.
    return boost::adaptors::values(std::as_const(impl_->elementById));
}

// Returns the outgoing edges of node `nodeId`.
//
// A node is identified by a non-empty trace of directed road elements; its last
// element (`fromId`) is the current position. Candidate moves are:
//   * every directed element adjacent to `fromId` (Impl::adjacency). The move
//     is dropped if conditionIndex reports a forbidden condition as a suffix of
//     the resulting trace. Otherwise the target node's trace is the part of the
//     extended trace that conditionIndex still tracks (or just the new element
//     when that part is empty);
//   * a u-turn onto `fromId.reverse()` when conditionIndex allows it.
// Edge weight is (fastLength(from) + fastLength(to)) / 2.
//
// Non-const because getNodeId() goes through the non-const
// traceIndex.nodeId(), presumably registering traces not seen before.
Edges Graph::edges(NodeId nodeId)
{
    // Copy the trace instead of binding a reference into the trace index:
    // getNodeId() below may add traces to the index, which can invalidate
    // references into its storage (e.g. if it is a std::vector).
    // NOTE(review): whether that happens depends on TraceIndex's container;
    // the copy is correct either way.
    const Trace fromTrace = getTrace(nodeId);

    ASSERT(!fromTrace.empty());
    // Refers into the local copy above, so it stays valid for the whole call.
    const DirectedId& fromId = fromTrace.back();

    auto distance = [this](const DirectedId& sourceId, const DirectedId& targetId)
    {
        const RoadElement& from = elementByDirectedId(sourceId);
        const RoadElement& to = elementByDirectedId(targetId);

        return (from.fastLength() + to.fastLength()) / 2;
    };

    Edges edges;

    const auto it = impl_->adjacency.find(fromId);
    if (it != impl_->adjacency.end()) {
        for (const auto& toId: it->second) {
            // Judging by the ConditionIndex API: the longest suffix of the
            // extended trace that is still a prefix of some forbidden condition.
            const Trace trace = impl_->conditionIndex.longestSuffixForbiddenConditionPrefix(
                addToEnd(fromTrace, toId)
            );

            if (!impl_->conditionIndex.hasForbiddenConditionAsSuffix(trace)) {
                const Trace toTrace = trace.empty() ? Trace{toId} : trace;

                edges.emplace_back(nodeId, getNodeId(toTrace), distance(fromId, toId));
            }
        }
    }

    if (impl_->conditionIndex.isUturnAllowed(fromId)) {
        const DirectedId toId = fromId.reverse();

        const auto nextNodeId = getNodeId(toId);
        ASSERT(nextNodeId);

        edges.emplace_back(nodeId, *nextNodeId, distance(fromId, toId));
    }

    return edges;
}

// Road element id of the last directed element in the node's trace.
TId Graph::getId(NodeId nodeId) const
{
    const auto& lastDirectedId = getDirectedId(nodeId);
    return lastDirectedId.roadElementId();
}

// Last directed element of the node's trace, i.e. the node's current position.
// The returned reference points into the trace held by the trace index.
// NOTE(review): getNodeId() calls the non-const traceIndex.nodeId(); if that
// can grow the index's storage, this reference may be invalidated - confirm
// against TraceIndex before holding it across such calls.
const DirectedId& Graph::getDirectedId(NodeId nodeId) const
{
    const Trace& trace = getTrace(nodeId);
    ASSERT(!trace.empty());
    return trace.back();
}

// Node for the single-element trace {directedId}, or std::nullopt when the
// element cannot be travelled in directedId's direction or is not accessible
// with this graph's access id. Fails like elementById() for an unknown element.
OptionalNodeId Graph::getNodeId(const DirectedId& directedId)
{
    const RoadElement& element = elementByDirectedId(directedId);

    if (!isSet(directedId.direction(), element.direction())
        || !isSet(accessId(), element.accessId()))
    {
        return std::nullopt;
    }

    return getNodeId(Trace{directedId});
}

// Node id assigned to `trace` by the trace index (non-const: goes through
// traceIndex.nodeId()). An empty trace fails the REQUIRE check.
NodeId Graph::getNodeId(const Trace& trace)
{
    REQUIRE(!trace.empty(), "Empty trace!");

    const NodeId id = impl_->traceIndex.nodeId(trace);
    return id;
}

// Trace (sequence of directed road elements) that identifies node `nodeId`,
// as stored by the trace index.
const Trace& Graph::getTrace(NodeId nodeId) const
{
    return impl_->traceIndex.trace(nodeId);
}

} // namespace maps::mrc::graph
