#include "road_utils.h"
#include "misc.h"

#include <yandex/maps/wiki/common/rd/cond_type.h>
#include <maps/libs/common/include/exception.h>

namespace maps::wiki::validator::utils {

using categories::RD_EL;
using categories::RD_JC;
using categories::COND;

/* Checks that junction `jc` is either the start or the end junction of road
   element `rdEl`; when it is neither, the condition handed to REQUIRE
   (declared elsewhere) is false and the message names both ids.
   Both arguments are pointer-like expressions. They are parenthesised so
   that compound expressions (e.g. `ptr + 1`, `&element`) expand correctly.
   Note: each argument is evaluated more than once. */
#define REQUIRE_INCIDENCE(jc, rdEl) \
    REQUIRE((jc)->id() == (rdEl)->startJunction() || \
            (jc)->id() == (rdEl)->endJunction(), \
        "Junction id=" << (jc)->id() << \
        " is not incident to element id=" << (rdEl)->id())

/* Returns the z-level of `rdEl` at the end attached to `rdJc`:
   fromZlevel() when `rdJc` is the element's start junction,
   toZlevel() otherwise.
   `rdJc` must be incident to `rdEl` (checked by REQUIRE_INCIDENCE). */
int
getRdElZLevelByJc(const RoadElement* rdEl, const Junction* rdJc)
{
    REQUIRE_INCIDENCE(rdJc, rdEl);

    if (rdJc->id() == rdEl->startJunction()) {
        return rdEl->fromZlevel();
    }
    // Incidence was verified above, so the junction is the end junction here.
    return rdEl->toZlevel();
}

/* Tests the direction flags of `rdEl` relative to `rdJc`:
   the Forward flag is looked up when `rdJc` is the element's start junction,
   the Backward flag otherwise. The result is whatever isSet() (declared
   elsewhere) reports for that flag against rdEl->direction().
   `rdJc` must be incident to `rdEl` (checked by REQUIRE_INCIDENCE). */
bool
isRdElCanBeFrom(const RoadElement* rdEl, const Junction* rdJc)
{
    REQUIRE_INCIDENCE(rdJc, rdEl);

    const bool jcIsStart = rdJc->id() == rdEl->startJunction();
    const auto neededFlag = jcIsStart
        ? RoadElement::Direction::Forward
        : RoadElement::Direction::Backward;
    return isSet(neededFlag, rdEl->direction());
}

/* Counterpart of isRdElCanBeFrom with the flags swapped:
   the Backward flag is looked up when `rdJc` is the element's start
   junction, the Forward flag otherwise. The result is whatever isSet()
   (declared elsewhere) reports for that flag against rdEl->direction().
   `rdJc` must be incident to `rdEl` (checked by REQUIRE_INCIDENCE). */
bool
isRdElCanBeTo(const RoadElement* rdEl, const Junction* rdJc)
{
    REQUIRE_INCIDENCE(rdJc, rdEl);

    if (rdJc->id() == rdEl->startJunction()) {
        return isSet(RoadElement::Direction::Backward, rdEl->direction());
    }
    return isSet(RoadElement::Direction::Forward, rdEl->direction());
}

/* Groups the road elements listed by junction `rdJc` (first its
   inElements(), then its outElements()) by their z-level at that junction.
   An element is taken into account only if it is loaded in the RD_EL view
   of `context` and isSet(accessId, element->accessId()) holds.
   The key of each group is getRdElZLevelByJc(element, rdJc). */
RdElsByZLevel
groupRdElByZLevForJc(
    CheckContext* context,
    const Junction* rdJc,
    common::AccessId accessId)
{
    RdElsByZLevel grouped;
    auto rdElView = context->objects<RD_EL>();

    auto collect = [&](const std::vector<TId>& ids) {
        for (auto id : ids) {
            // Elements outside the loaded set are silently skipped.
            if (!rdElView.loaded(id)) {
                continue;
            }
            auto element = rdElView.byId(id);
            if (!isSet(accessId, element->accessId())) {
                continue;
            }
            grouped[getRdElZLevelByJc(element, rdJc)].push_back(element);
        }
    };

    collect(rdJc->inElements());
    collect(rdJc->outElements());

    return grouped;
}

/* Visits every COND object of `context` and collects, per via-junction id,
   the (from element, to element) pairs taken from conditions that:
     - have type ConditionType::Prohibited,
     - have no schedules,
     - have no vehicle restriction parameters,
     - share at least one bit with `accessId` (accessId & cond->accessId()),
     - list at most one "to" element.
   A condition that passes these filters but has no "to" element at all
   trips the REQUIRE below. */
RdElContinuationsByJcId
findProhibitedContinuations(
    CheckContext* context,
    common::AccessId accessId)
{
    RdElContinuationsByJcId prohibited;

    context->objects<COND>().visit(
        [&](const Condition* cond) {
            const bool isPlainProhibition =
                cond->type() == common::ConditionType::Prohibited &&
                cond->schedules().empty() &&
                !cond->vehicleRestrictionParameters();
            if (!isPlainProhibition || !(accessId & cond->accessId())) {
                return;
            }

            const auto& targets = cond->toRoadElements();
            // Conditions with several "to" elements are not handled here.
            if (targets.size() > 1) {
                return;
            }

            REQUIRE(!targets.empty(),
                "Missed \"to\" element for condition id=" << cond->id());

            result_slot:
            prohibited[cond->viaJunction()].emplace_back(
                cond->fromRoadElement(),
                targets.front().second);
        });

    return prohibited;
}

bool
canDrive(
    const RoadElement* from,
    const Junction* via,
    const RoadElement* to,
    const RdElContinuationsByJcId& prohibitedContinuations,
    RdElFilter backLaneFilter)
{
    REQUIRE_INCIDENCE(via, from);
    REQUIRE_INCIDENCE(via, to);

    /* No need for loop edges in graph of rd_el's reachability.
       It slows calculations only. */
    if (from->id() == to->id()) {
        return false;
    }

    int fromZLevel = 0;
    int toZLevel = 0;

    if (from->startJunction() == via->id()) {
        if (!isSet(RoadElement::Direction::Backward, from->direction()) &&
            !backLaneFilter(from)) {
                return false;
        }
        fromZLevel = from->fromZlevel();
    } else {
        if (!isSet(RoadElement::Direction::Forward, from->direction()) &&
            !backLaneFilter(from)) {
                return false;
        }
        fromZLevel = from->toZlevel();
    }

    if (to->startJunction() == via->id()) {
        if (!isSet(RoadElement::Direction::Forward, to->direction()) &&
            !backLaneFilter(to)) {
                return false;
        }
        toZLevel = to->fromZlevel();
    } else {
        if (!isSet(RoadElement::Direction::Backward, to->direction()) &&
            !backLaneFilter(to)) {
                return false;
        }
        toZLevel = to->toZlevel();
    }

    if (fromZLevel != toZLevel) {
        return false;
    }

    auto iter = prohibitedContinuations.find(via->id());
    if (iter == prohibitedContinuations.end()) {
        return true;
    }
    for (const auto& continuation : iter->second) {
        if (continuation.first == from->id() &&
            continuation.second == to->id()) {
                return false;
        }
    }
    return true;
}

/* Collects the road elements that continue the element `primaryRdElId`
   through its start junction and then through its end junction.

   For each of the two junctions (skipped when not loaded in the RD_JC view)
   the junction's inElements() and then outElements() are scanned. A loaded
   candidate is appended to the result when it passes `graphFilter` and
   `directionFilter(candidate, junction id)` and, depending on `incidence`:
     - In   : canDrive(candidate -> primary) holds;
     - Out  : canDrive(primary -> candidate) holds;
     - Both : either of the two holds.
   Returns an empty collection when the primary element itself is not loaded.

   NOTE(review): the start and end junction are scanned unconditionally, so
   if both ids are equal the same junction is processed twice — confirm
   whether callers expect that. */
RdEls
findRdElContinuations(
    CheckContext* context,
    TId primaryRdElId,
    RdElFilter graphFilter,
    const RdElContinuationsByJcId& prohibitedContinuations,
    IncidenceDirection incidence,
    RdElFilter backLaneFilter,
    RdElDirectionFilter directionFilter)
{
    auto rdElView = context->objects<RD_EL>();
    if (!rdElView.loaded(primaryRdElId)) {
        return {};
    }
    auto primaryRdEl = rdElView.byId(primaryRdElId);
    auto rdJcView = context->objects<RD_JC>();

    RdEls continuations;

    // candidate -> via -> primary
    auto leadsIntoPrimary = [&](const RoadElement* candidate, const Junction* via) {
        return canDrive(candidate, via, primaryRdEl,
            prohibitedContinuations, backLaneFilter);
    };
    // primary -> via -> candidate
    auto leadsOutOfPrimary = [&](const RoadElement* candidate, const Junction* via) {
        return canDrive(primaryRdEl, via, candidate,
            prohibitedContinuations, backLaneFilter);
    };

    auto scanElements = [&](const Junction* via, const std::vector<TId>& ids) {
        for (auto id : ids) {
            if (!rdElView.loaded(id)) {
                continue;
            }
            auto candidate = rdElView.byId(id);
            if (!graphFilter(candidate) ||
                !directionFilter(candidate, via->id())) {
                continue;
            }

            // Stays false for an incidence value outside the three known ones.
            bool accepted = false;
            if (incidence == IncidenceDirection::In) {
                accepted = leadsIntoPrimary(candidate, via);
            } else if (incidence == IncidenceDirection::Out) {
                accepted = leadsOutOfPrimary(candidate, via);
            } else if (incidence == IncidenceDirection::Both) {
                accepted = leadsIntoPrimary(candidate, via) ||
                    leadsOutOfPrimary(candidate, via);
            }

            if (accepted) {
                continuations.push_back(candidate);
            }
        }
    };

    auto scanJunction = [&](TId jcId) {
        if (rdJcView.loaded(jcId)) {
            auto junction = rdJcView.byId(jcId);
            scanElements(junction, junction->inElements());
            scanElements(junction, junction->outElements());
        }
    };

    scanJunction(primaryRdEl->startJunction());
    scanJunction(primaryRdEl->endJunction());

    return continuations;
}

} // namespace maps::wiki::validator::utils
