#include "get_route_time.h"

#include <maps/wikimap/mapspro/services/editor/src/branch_helpers.h>
#include <maps/wikimap/mapspro/services/editor/src/common.h>
#include <maps/wikimap/mapspro/services/editor/src/magic_strings.h>
#include <maps/wikimap/mapspro/services/editor/src/utils.h>

#include <maps/wikimap/mapspro/services/editor/src/configs/categories_strings.h>

#include "common.h"
#include "config.h"
#include "helper.h"
#include "route_condition.h"
#include "route_element.h"
#include "util.h"

#include <yandex/maps/wiki/revision/revisionsgateway.h>
#include <yandex/maps/wiki/routing/exception.h>
#include <yandex/maps/wiki/routing/route.h>

#include <yandex/maps/wiki/common/misc.h>
#include <yandex/maps/wiki/common/string_utils.h>

#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/bounding_box.h>
#include <maps/libs/geolib/include/point.h>
#include <maps/libs/geolib/include/polyline.h>

#include <boost/optional.hpp>

#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <vector>


namespace maps {
namespace wiki {

namespace {

// Name returned by GetRouteTime::taskName().
const std::string STR_TASK_METHOD_NAME{"GetRouteTime"};

// Stop snap tolerance (meters) handed to tryRestore() by GetRouteTime::control().
constexpr double STOP_SEARCH_RADIUS_METERS = 30.0;

// Routing configuration used for bus threads.
//
// Routes over CATEGORY_RD_EL elements with an empty area of interest. The
// direction of an element is Direction::Both when STR_BACK_BUS is set, otherwise
// it is read from the STR_WAY attribute. Every element is accepted and every
// transition has weight 1. Only CATEGORY_COND conditions whose STR_ACCEESS_ID
// includes AccessId::Bus are taken into account: U-turn conditions become
// turnabouts, every other one is treated as forbidden.
RoutingConfig busThreadRoutinConfig() {
    const auto directionOf = [](const RouteElement& el) {
        return el[STR_BACK_BUS] ? Direction::Both : el[STR_WAY].as<Direction>();
    };

    const auto acceptEveryElement = [](const RouteElement& /*el*/) {
        return true;
    };

    const auto unitWeight = [](const RouteElement& /*lhs*/, const RouteElement& /*rhs*/) {
        return 1;
    };

    const auto appliesToBuses = [](const RouteCondition& cond) {
        return cond.categoryId() == CATEGORY_COND
            && isSet(AccessId::Bus, cond[STR_ACCEESS_ID]);
    };

    const auto conditionTypeOf = [](const RouteCondition& cond) {
        return cond[STR_COND_TYPE] == common::ConditionType::Uturn
            ? RouteConditionType::Turnabout
            : RouteConditionType::Forbidden;
    };

    return RoutingConfig {
        CATEGORY_RD_EL,
        /* aoi  = */ geolib3::BoundingBox(),
        /* direction = */ directionOf,
        /* filter = */ acceptEveryElement,
        /* weight = */ unitWeight,
        /* condition = */ RoutingConditionConfig {appliesToBuses, conditionTypeOf},
        0,
        0
    };
}

// Routing configuration used for tram threads.
//
// Routes over CATEGORY_TRANSPORT_TRAM_EL elements with an empty area of interest.
// Every element is traversable in both directions, every element is accepted,
// every transition has weight 1, and no conditions are configured.
RoutingConfig tramThreadRoutinConfig() {
    const auto bothWays = [](const RouteElement& /*el*/) {
        return Direction::Both;
    };

    const auto acceptEveryElement = [](const RouteElement& /*el*/) {
        return true;
    };

    const auto unitWeight = [](const RouteElement& /*lhs*/, const RouteElement& /*rhs*/) {
        return 1;
    };

    return RoutingConfig {
        CATEGORY_TRANSPORT_TRAM_EL,
        /* aoi  = */ geolib3::BoundingBox(),
        /* direction = */ bothWays,
        /* filter = */ acceptEveryElement,
        /* weight = */ unitWeight,
        /* condition = */ boost::none,
        0,
        0
    };
}

// Selects the routing configuration for a thread category.
//
// `stops` is only validated to be non-empty (REQUIRE). Bus and tram thread
// categories are supported. Any other category raises a wiki logic error with
// ERR_ROUTING_UNSUPPORTED_ROUTE_CATEGORY.
RoutingConfig createRoutingConfig(const std::string& threadCategoryId, const routing::Stops& stops)
{
    REQUIRE(!stops.empty(), "Empty stop list");

    const bool isBusThread = (threadCategoryId == CATEGORY_TRANSPORT_BUS_THREAD);
    if (isBusThread) {
        return busThreadRoutinConfig();
    }

    const bool isTramThread = (threadCategoryId == CATEGORY_TRANSPORT_TRAM_THREAD);
    if (isTramThread) {
        return tramThreadRoutinConfig();
    }

    THROW_WIKI_LOGIC_ERROR(
        ERR_ROUTING_UNSUPPORTED_ROUTE_CATEGORY,
        "'" << threadCategoryId << "' is unsupported thread category id"
    );
}

// Speed constants, converted from km/h via KMH_TO_MS_RATIO.
// MIN_SPEED is the lower bound returned by speed().
constexpr double MIN_SPEED = 5.0 * KMH_TO_MS_RATIO;
// speed() maps an element speed of MAX_SPEED onto MAX_TRANSPORT_SPEED.
constexpr double MAX_SPEED = 150.0 * KMH_TO_MS_RATIO;
constexpr double MAX_TRANSPORT_SPEED = 80.0 * KMH_TO_MS_RATIO;
// Penalty added around non-via stops; added to travelTime() values, so seconds.
constexpr double ACCELERATION_TIME = 15.0;

// Estimated transport speed on `element`.
//
// Takes the lower end of the element's speed interval (m/s) and damps it by
// exp(decay * v). The decay is chosen so that v == MAX_SPEED maps exactly to
// MAX_TRANSPORT_SPEED. The result never drops below MIN_SPEED.
double speed(const RouteElement& element)
{
    static const double decay = std::log(MAX_TRANSPORT_SPEED / MAX_SPEED) / MAX_SPEED;

    const double limit = element.speedIntervalMetersPerSecond().min;
    const double damped = limit * std::exp(decay * limit);

    // Same selection rule as std::max(MIN_SPEED, damped).
    return MIN_SPEED < damped ? damped : MIN_SPEED;
}

// Time to traverse the whole of `element`: length in meters divided by the
// estimated speed in m/s, i.e. seconds. speed() is bounded below by MIN_SPEED.
double travelTime(const RouteElement& element)
{
    const double metersPerSecond = speed(element);
    return element.lengthMeters() / metersPerSecond;
}

// Restores the thread route through `elements` visiting `stops`, and returns its trace.
//
// Failures reported by routing::restore are turned into
// LogicExceptionWithLocation errors located at the offending stop's geometry:
//  - no path to a stop, or an ambiguous forward path to a stop
//    -> ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE (only the first error is reported);
//  - routing::ImpossibleSnapStopError ("too distant from transport graph")
//    -> ERR_ROUTING_IMPOSSIBLE_SNAP_STOP for the first stop id it lists.
// Raises an internal error if an error refers to a stop id not present in `stops`.
routing::Trace tryRestore(
        const routing::Elements& elements,
        const routing::Stops& stops,
        const routing::Conditions& conditions,
        double stopSnapToleranceMeters)
{
    // Linear search over `stops`; only used to attach a location to error reports.
    auto getStopGeomById = [&stops](TOid id) {
        for (const auto& stop: stops) {
            if (stop.id() == id) {
                return stop.geom();
            }
        }

        // Separator added so the id is not glued onto the message text.
        THROW_WIKI_INTERNAL_ERROR("Impossible find stop with id " << id);
    };

    try {
        routing::RestoreResult result = routing::restore(
            elements,
            stops,
            conditions,
            stopSnapToleranceMeters
        );

        if (!result.noPathErrors().empty()) {
            const routing::NoPathError& error = result.noPathErrors().front();

            throw LogicExceptionWithLocation(
                ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE,
                getStopGeomById(error.toStopId())
            ) << "Impossible find path to stop " << error.toStopId();
        }

        if (!result.forwardAmbiguousPathErrors().empty()) {
            const routing::AmbiguousPathError& error = result.forwardAmbiguousPathErrors().front();

            throw LogicExceptionWithLocation(
                ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE,
                getStopGeomById(error.toStopId())
            ) << "Ambiguous path to stop " << error.toStopId();
        }

        return result.trace();
    } catch (const routing::ImpossibleSnapStopError& error) {
        // Catch by const reference: the error is only inspected, never modified.
        ASSERT(!error.stopIds().empty());
        const TOid stopId = *error.stopIds().begin();

        throw LogicExceptionWithLocation(
            ERR_ROUTING_IMPOSSIBLE_SNAP_STOP,
            getStopGeomById(stopId)
        ) << "Stop " << stopId << " is too distant from transport graph";
    }
}

} // namespace

// Human-readable one-line description of the request (used by printRequest()).
// Each field is emitted as " name: value"; the element id lists are
// comma-joined inside square brackets.
std::string GetRouteTime::Request::dump() const
{
    std::stringstream stream;

    // Caller and branch context.
    stream << " user: " << user;
    stream << " token: " << token;
    stream << " branch: " << branchId;

    // Thread object being queried.
    stream << " revisionId: " << revisionId;
    stream << " categoryId: " << categoryId;

    // Element edits requested on top of the stored thread.
    stream << " addElementIds [" << common::join(addElementIds, ',') << ']';
    stream << " removeElementIds [" << common::join(removeElementIds, ',') << ']';

    stream << " threadStops: " << threadStopSequence.toString();

    return stream.str();
}

// Stores a copy of `request` for control(). The observer collection is not used
// by this controller. `asyncTaskID` is forwarded to the base controller.
GetRouteTime::GetRouteTime(
        const ObserverCollection& /*observers*/,
        const Request& request,
        taskutils::TaskID asyncTaskID)
    : controller::BaseController<GetRouteTime>(BOOST_CURRENT_FUNCTION, asyncTaskID)
    , request_(request)
{
}

// Textual form of the stored request; delegates to Request::dump().
std::string GetRouteTime::printRequest() const
{
    return request_.dump();
}

// Estimates the travel time of every leg (stop i -> stop i + 1) of a transport
// thread and stores the estimates in result_->time.
//
// Outline:
//  1. Acquire a read context for request_.branchId and take a snapshot at the
//     branch head commit.
//  2. Complete request_.threadStopSequence (restoreStopIds / restoreAttrs) and
//     load the stop objects it refers to.
//  3. Collect the thread's element ids via getThreadElementIds (object of
//     request_.revisionId, addElementIds, removeElementIds — presumably the
//     stored elements with the requested edits applied; verify in helper) and
//     restore the route through the stops with tryRestore().
//  4. Walk the restored trace, summing travelTime() of the traversed (parts of)
//     elements. travelTime() divides meters by m/s, so the sums are seconds.
//
// result_->time receives one entry per stop snap in the trace after the first
// (stops.size() - 1 entries are reserved). Each double is converted implicitly
// to size_t on push_back, i.e. truncated toward zero.
//
// Errors: fewer than 2 stops, or a trace that does not start with a stop snap,
// raise ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE. An unsupported request_.categoryId
// and restore failures are raised by createRoutingConfig() / tryRestore().
void GetRouteTime::control()
{
    BranchContextFacade facade(request_.branchId);

    auto context = facade.acquireRead(request_.branchId, request_.token);
    auto gateway = revision::RevisionsGateway(context.txnCore(), context.branch);
    auto snapshot = gateway.snapshot(gateway.headCommitId());

    // Presumably fills in stop ids / attributes missing from the request — confirm in helper.
    restoreStopIds(snapshot, request_.threadStopSequence);
    restoreAttrs(snapshot, request_.threadStopSequence);
    const std::vector<TOid> stopIds = request_.threadStopSequence.stopIds();

    const IdToRevision idToStopRevision = load(snapshot, TOIds{stopIds.begin(), stopIds.end()});
    const routing::Stops stops = toStopSequence(idToStopRevision, stopIds);
    const RoutingConfig config = createRoutingConfig(request_.categoryId, stops);

    const TOIds threadElementIds = getThreadElementIds(
        snapshot,
        request_.revisionId.objectId(),
        request_.addElementIds,
        request_.removeElementIds
    );

    WIKI_REQUIRE(
        stops.size() >= 2,
        ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE,
        "Need at least 2 stops"
    );

    const auto idToRevision = load(snapshot, threadElementIds);

    const routing::Trace trace = tryRestore(
        toElements(snapshot, config, idToRevision),
        stops,
        loadConditions(snapshot, config, threadElementIds),
        STOP_SEARCH_RADIUS_METERS
    );

    const auto idToRouteElement = mapIdToRouteElement(idToRevision);

    // The first trace entry must carry the snap of the first stop; the
    // short-circuit keeps it->stopSnap() from being evaluated on an empty trace.
    auto it = trace.begin();
    WIKI_REQUIRE(
        trace.size() >= 1 && it->stopSnap(),
        ERR_ROUTING_IMPOSSIBLE_RESTORE_ROUTE,
        "Invalid trace after restore"
    );

    // Share of the start element remaining after the first stop. Assumes
    // locationOnElement() is a fraction in [0, 1] along the directed element —
    // TODO confirm.
    const double startRatio = 1 - it->stopSnap()->locationOnElement();
    const RouteElement& startElement = idToRouteElement.at(it->directedElementId().id());

    // Time accumulated for the leg in progress. Leaving the first stop costs
    // ACCELERATION_TIME.
    double current = ACCELERATION_TIME + startRatio * travelTime(startElement) ;

    std::vector<size_t> result;
    result.reserve(stops.size() - 1);

    for (++it; it != trace.end(); ++it) {
        const RouteElement& element = idToRouteElement.at(it->directedElementId().id());
        const double time = travelTime(element);

        // Elements without a stop are traversed completely.
        if (!it->stopSnap()) {
            current += time;
            continue;
        }

        const double ratio = it->stopSnap()->locationOnElement();

        const auto before = it - 1;
        if (before->stopSnap() && before->directedElementId() == it->directedElementId()) {
            // The previous stop lies on this same directed element, so `current`
            // already holds its remainder (1 - previousRatio) * time.
            // Subtracting (1 - ratio) * time leaves the stretch between the two stops.
            current -= (1 - ratio) * time ;
        } else {
            // Add the part of this element up to the stop.
            current += ratio * time;
        }

        // Position of this stop in threadStopSequence. Assumes stop snaps in the
        // trace follow the sequence order, with stop 0 at trace.begin() — TODO confirm.
        const size_t threadStopIdx = result.size() + 1;
        const auto& threadStopAttrs = request_.threadStopSequence.at(threadStopIdx).attrs;

        const TOid stopId = it->stopSnap()->stopId();
        const auto stopAttrs = common::AttrsWrap::extract(idToStopRevision.at(stopId));

        // A via stop is a waypoint stop, or a thread stop where both embarkation
        // and disembarkation are disallowed.
        const bool isViaStop = stopAttrs[ATTR_TRANSPORT_STOP_WAYPOINT] || (
            threadStopAttrs[ATTR_TRANSPORT_THREAD_STOP_NO_EMBARKATION]
                && threadStopAttrs[ATTR_TRANSPORT_THREAD_STOP_NO_DISEMBARKATION]
        );

        if (isViaStop) {
            // Passed without any acceleration penalty. The next leg starts with
            // the rest of this element.
            result.push_back(current);
            current = (1 - ratio) * time;
        } else {
            // ACCELERATION_TIME is charged once for arriving here and once for
            // departing on the next leg.
            result.push_back(current + ACCELERATION_TIME);
            current = ACCELERATION_TIME + (1 - ratio) * time;
        }
    }

    result_->time = std::move(result);
}

// Name of this task; refers to the file-level STR_TASK_METHOD_NAME constant.
const std::string& GetRouteTime::taskName()
{
    return STR_TASK_METHOD_NAME;
}

} // namespace wiki
} // namespace maps
