#include <maps/wikimap/mapspro/services/mrc/libs/graph/impl/condition_index.h>

#include <maps/libs/log8/include/log8.h>

#include <maps/libs/common/include/exception.h>

#include <boost/algorithm/cxx11/all_of.hpp>
#include <boost/algorithm/cxx11/any_of.hpp>

#include <algorithm>
#include <cstddef>
#include <unordered_map>


namespace maps::mrc::graph {

namespace {

template<typename T, typename Container>
std::unordered_map<TId, T> byId(const Container& objects)
{
    std::unordered_map<TId, T> result;

    for (const auto& object: objects) {
        result.emplace(object.id(), object);
    }

    return result;
}

// A condition forbids passage for `accessId` when it is a prohibition or a
// barrier and its access mask overlaps `accessId`.
inline bool isForbidden(AccessId accessId, const Condition& condition)
{
    const auto type = condition.type();
    if (type != Condition::Type::Prohibited && type != Condition::Type::Barrier) {
        return false;
    }
    // Overlap in a single access bit is enough.
    return !!(accessId & condition.accessId());
}

// A U-turn condition applies only when its access mask matches `accessId`
// via `isSet`, not merely by overlap.
inline bool isUturn(AccessId accessId, const Condition& condition)
{
    if (condition.type() != Condition::Type::Uturn) {
        return false;
    }
    // NOTE(review): isSet is expected to require all bits ("need all") -- confirm.
    return isSet(accessId, condition.accessId());
}

// Returns the direction in which `element` must be traversed to arrive at
// `junctionId`: Forward if the junction is the element's end, Backward if it
// is the element's start. For an element whose start and end are both the
// junction, Forward is returned. Throws RuntimeError if the element does not
// touch the junction at all.
inline Direction directionToJunction(const RoadElement& element, TId junctionId)
{
    const bool arrivesAtEnd = element.endJunctionId() == junctionId;
    const bool arrivesAtStart = element.startJunctionId() == junctionId;

    if (!arrivesAtEnd && !arrivesAtStart) {
        throw RuntimeError() << "Road element " << element.id()
                             << " is not connected with " << junctionId;
    }

    return arrivesAtEnd ? Direction::Forward : Direction::Backward;
}

inline bool endWith(const Trace& trace, const Trace& suffix)
{
    return (suffix.size() <= trace.size()) &&
        std::equal(suffix.rbegin(), suffix.rend(), trace.rbegin());
}

// Appends every position list in `from` to the list stored under the same
// DirectedId in `to`, creating the entry in `to` when it does not exist yet.
//
// `from` is only read. Its elements are copied, because a const source cannot
// be moved from. Appending to a freshly created empty list is equivalent to
// assigning it, so no special case is needed.
void copy(const ConditionPositionsByDirectedId& from, ConditionPositionsByDirectedId& to)
{
    for (const auto& [directedId, positions]: from) {
        auto& current = to[directedId];
        current.insert(current.end(), positions.begin(), positions.end());
    }
}

} // namespace

// Builds the index from the conditions that apply to `accessId`.
//
// - Conditions that reference a road element missing from `elements` are
//   skipped silently.
// - Prohibited/barrier conditions are stored as forbidden traces: one
//   DirectedId per element, the "from" element first and then every "to"
//   element in order.
// - U-turn conditions record the "from" element, directed towards the via
//   junction, as allowing a U-turn.
// - Conditions that fail validation (RuntimeError) are logged and skipped
//   instead of aborting construction.
ConditionIndex::ConditionIndex(
        AccessId accessId,
        const RoadElements& elements,
        const Conditions& conditions)
{
    // Non-owning lookup table: the mapped values are references into `elements`.
    const auto elementById = byId<const RoadElement&>(elements);

    // True when the "from" element and every "to" element are present in `elements`.
    auto allElementsAvailable = [&](const Condition& condition) {
        return elementById.count(condition.fromElementId())
            && boost::algorithm::all_of(
                condition.toElementIds(),
                [&](auto elementId) {
                    return elementById.count(elementId);
                }
            );
    };

    // Records a prohibiting condition as a directed trace.
    // The trace and its per-DirectedId positions are built in locals and only
    // committed to the members after the loop. A condition that throws midway
    // therefore leaves no partial entries in the index.
    auto addForbiddenCondition = [this, &elementById](const Condition& cond)
    {
        TId junctionId = cond.viaJunctionId();

        Trace forbiddenTrace;
        ConditionPositionsByDirectedId forbiddenConditionPositionsByDirectedId;

        // Position 0 is the "from" element; positions 1..N are the "to" elements.
        for (size_t position = 0; position <= cond.toElementIds().size(); ++position) {
            const TId elementId = position == 0
                ? cond.fromElementId()
                : cond.toElementIds().at(position - 1);

            const auto& element = elementById.at(elementId);

            if (position != 0) {
                // Advance along the chain. Presumably oppositeJunction returns
                // the element's other endpoint (and fails if `junctionId` is not
                // one of its endpoints) -- confirm.
                junctionId = element.oppositeJunction(junctionId);
            } else {
                // The "from" element must touch the via junction.
                REQUIRE(
                    element.hasJunction(junctionId),
                    "Road element " << element.id() << " is not connected with " << junctionId
                );
            }

            // Each element is directed towards the junction it reaches in this trace.
            const DirectedId id(elementId, directionToJunction(element, junctionId));

            // `position` is the index of `id` inside forbiddenTrace.
            forbiddenConditionPositionsByDirectedId[id].emplace_back(cond.id(), position);
            forbiddenTrace.push_back(id);
        }

        forbiddenTraceByConditionId_[cond.id()] = std::move(forbiddenTrace);

        copy(
            forbiddenConditionPositionsByDirectedId,
            forbiddenConditionPositionsByDirectedId_
        );
    };

    // A valid U-turn condition has exactly one "to" element, equal to the
    // "from" element. The "from" element, directed towards the via junction,
    // is marked as allowing a U-turn.
    auto addUturnCondition = [this, &elementById](const Condition& cond) {
        const auto& from = elementById.at(cond.fromElementId());

        REQUIRE(cond.toElementIds().size() == 1, "Invalid length of uturn condition " << cond.id());
        const auto& toId = cond.toElementIds().front();

        REQUIRE(from.id() == toId, "Invalid uturn condition " << cond.id());

        isUturnAllowed_.emplace(from.id(), directionToJunction(from, cond.viaJunctionId()));
    };

    // Broken conditions are a rather common case: log them and move on.
    // NOTE(review): assumes REQUIRE raises RuntimeError -- confirm.
    for (const auto& condition: conditions) try { // broken condition is rather common case
        if (!allElementsAvailable(condition)) {
            continue;
        } else if (isForbidden(accessId, condition)) {
            addForbiddenCondition(condition);
        } else if (isUturn(accessId, condition)) {
            addUturnCondition(condition);
        }
    } catch (const RuntimeError& e) {
        ERROR() << e;
    }
}

// True when a U-turn condition marked `id` (an element directed towards the
// condition's via junction) as allowing a U-turn.
bool ConditionIndex::isUturnAllowed(const DirectedId& id) const
{
    return isUturnAllowed_.count(id) != 0;
}

bool ConditionIndex::hasForbiddenConditionAsSuffix(const Trace& trace) const
{
    if (trace.empty()) {
        return false;
    }

    const auto it = forbiddenConditionPositionsByDirectedId_.find(trace.back());
    if (it == forbiddenConditionPositionsByDirectedId_.end()) {
        return false;
    }

    return boost::algorithm::any_of(
        it->second,
        [&](const auto& condPos) {
            return endWith(trace, forbiddenTraceByConditionId_.at(condPos.conditionId));
        }
    );
}

// Returns the longest prefix of any forbidden trace that is also a suffix of
// `trace`, or an empty trace if there is none.
//
// Only conditions indexed under `trace.back()` are candidates. For each one,
// the candidate prefix ends at the position where that DirectedId occurs in
// the forbidden trace. Candidates are compared in place against the tail of
// `trace`; only the winning prefix is copied into the result. Among
// equal-length matches, the first one found is kept.
Trace ConditionIndex::longestSuffixForbiddenConditionPrefix(const Trace& trace) const
{
    if (trace.empty()) {
        return {};
    }

    const auto it = forbiddenConditionPositionsByDirectedId_.find(trace.back());
    if (it == forbiddenConditionPositionsByDirectedId_.end()) {
        return {};
    }

    const Trace* bestTrace = nullptr;
    std::size_t bestLength = 0;

    for (const auto& condPos: it->second) {
        const std::size_t length = condPos.position + 1;

        // Cannot beat the current best, or cannot be a suffix of `trace`.
        if (length <= bestLength || length > trace.size()) {
            continue;
        }

        const auto& forbiddenTrace = forbiddenTraceByConditionId_.at(condPos.conditionId);
        const auto prefixBegin = forbiddenTrace.begin();
        const auto prefixEnd = prefixBegin + length;

        if (std::equal(prefixBegin, prefixEnd, trace.end() - length)) {
            bestTrace = &forbiddenTrace;
            bestLength = length;
        }
    }

    if (bestTrace == nullptr) {
        return {};
    }
    return Trace(bestTrace->begin(), bestTrace->begin() + bestLength);
}

} // namespace maps::mrc::graph
