#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/for_each_passage.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/feature_positioner.h>
#include <maps/libs/geolib/include/conversion.h>
#include <maps/libs/geolib/include/direction.h>
#include <maps/libs/geolib/include/distance.h>

#include <algorithm>

namespace maps {
namespace mrc {
namespace adapters {
namespace {

// Consecutive track points at least this many meters apart are not joined
// into a segment (see toTrackSegments).
constexpr size_t MAX_DISTANCE_INTERVAL_IN_METERS = 1000;
// Time limit as a plain number of seconds; handed to common::forEachPassage.
constexpr int MAX_TIME_INTERVAL_SECONDS = 60;
// The same limit as a duration: used as the maximal time gap between joined
// track points and as the padding of the track point query window.
constexpr std::chrono::seconds MAX_TIME_INTERVAL{MAX_TIME_INTERVAL_SECONDS};

// Builds one segment per pair of consecutive track points, keeping only the
// pairs that are closer than MAX_TIME_INTERVAL in time and closer than
// MAX_DISTANCE_INTERVAL_IN_METERS in space. Each segment carries the
// timestamps of its two end points. The input order is used as is.
TrackSegments toTrackSegments(const db::TrackPoints& trackPoints)
{
    TrackSegments segments;
    // At most size() - 1 segments can be produced, so this is one slot spare.
    segments.reserve(trackPoints.size());

    const auto end = trackPoints.end();
    auto prev = trackPoints.begin();
    if (prev == end) {
        return segments;
    }

    auto cur = prev;
    for (++cur; cur != end; ++prev, ++cur) {
        const bool closeInTime =
            (cur->timestamp() - prev->timestamp()) < MAX_TIME_INTERVAL;
        if (!closeInTime) {
            continue;
        }
        const bool closeInSpace =
            geolib3::fastGeoDistance(prev->geodeticPos(), cur->geodeticPos())
            < MAX_DISTANCE_INTERVAL_IN_METERS;
        if (!closeInSpace) {
            continue;
        }
        segments.push_back({{prev->geodeticPos(), cur->geodeticPos()},
                            prev->timestamp(),
                            cur->timestamp()});
    }
    return segments;
}

// Strict weak ordering by source id first and by timestamp second.
// Works for any pair of types exposing sourceId() and timestamp().
auto byPassages = [](const auto& lhs, const auto& rhs) {
    if (lhs.sourceId() < rhs.sourceId()) {
        return true;
    }
    if (rhs.sourceId() < lhs.sourceId()) {
        return false;
    }
    // Same source: the earlier timestamp goes first.
    return lhs.timestamp() < rhs.timestamp();
};

// Heading of a geodetic segment, taken from the direction of the segment
// after converting it to Mercator.
geolib3::Heading heading(const geolib3::Segment2& geodetic)
{
    const auto mercator = geolib3::convertGeodeticToMercator(geodetic);
    return geolib3::Direction2(mercator).heading();
}

// Time key of any object exposing timestamp(); the fallback used by byTime
// when no more specific overload of time() matches.
template <class T>
chrono::TimePoint time(const T& timestamped)
{
    return timestamped.timestamp();
}

// Time key of a bare time point: the time point itself. Lets byTime compare
// a container element against a plain timestamp.
chrono::TimePoint time(const chrono::TimePoint& timePoint)
{
    return timePoint;
}

// Time key of a classified track interval: its end.
chrono::TimePoint time(const track_classifier::TrackInterval& interval)
{
    return interval.end;
}

// Time key of a track segment: its end time.
chrono::TimePoint time(const TrackSegment& segment)
{
    return segment.endTime;
}

auto byTime = [](const auto& lhs, const auto& rhs) {
    return time(lhs) < time(rhs);
};

// Returns the part of [first, last) whose time() keys lie within
// [startTime, endTime], both bounds included.
// The input range must already be ordered by byTime (binary search is used).
template <class SortedIt>
std::pair<SortedIt, SortedIt> subrange(SortedIt first,
                                       SortedIt last,
                                       chrono::TimePoint startTime,
                                       chrono::TimePoint endTime)
{
    // First element not earlier than startTime.
    const SortedIt lower = std::lower_bound(first, last, startTime, byTime);
    // First element strictly later than endTime, searched from `lower` on.
    const SortedIt upper = std::upper_bound(lower, last, endTime, byTime);
    return std::pair<SortedIt, SortedIt>(lower, upper);
}

// Places the feature on the segment of `path` that covers its timestamp.
//
// The segment is found by binary search on segment end times, so `path` must
// be ordered by byTime. On success the feature gets a position interpolated
// linearly by time along the segment (the middle of the segment when the
// segment has zero duration) and the heading of that segment.
//
// Returns false, leaving the feature untouched, when no segment covers the
// feature's timestamp.
bool interpolateFeaturePosition(const TrackSegments& path,
                                db::Feature& feature)
{
    // First segment ending at or after the feature's timestamp.
    const auto segIt =
        std::lower_bound(path.begin(), path.end(), feature, byTime);
    if (segIt == path.end() || segIt->startTime > feature.timestamp()) {
        return false;
    }

    const auto duration = segIt->endTime - segIt->startTime;
    double ratio = .5;
    if (duration.count()) {
        const auto elapsed = feature.timestamp() - segIt->startTime;
        ratio = static_cast<double>(elapsed.count())
                / static_cast<double>(duration.count());
    }

    feature.setGeodeticPos(segIt->segment.pointByPosition(ratio));
    feature.setHeading(heading(segIt->segment));
    return true;
}

// Moves every feature that has the same (sourceId, timestamp) key as
// `trackPoint` onto that track point. The heading is copied too, but only
// when the track point has one.
// The feature range must be ordered by byPassages (binary search is used).
template <class FeatureIt>
void snapFeaturesPositionsToTrackPointByTime(const db::TrackPoint& trackPoint,
                                             FeatureIt orderedByPassagesFirst,
                                             FeatureIt orderedByPassagesLast)
{
    const auto matching = std::equal_range(
        orderedByPassagesFirst, orderedByPassagesLast, trackPoint, byPassages);
    for (auto it = matching.first; it != matching.second; ++it) {
        db::Feature& feature = *it;
        feature.setGeodeticPos(trackPoint.geodeticPos());
        const auto pointHeading = trackPoint.heading();
        if (pointHeading) {
            feature.setHeading(*pointHeading);
        }
    }
}

// Maps a classified track type to the graph type it should be matched
// against.
//
// Returns std::nullopt for TrackType::Undefined and for any value that is not
// one of the enumerators handled below.
std::optional<db::GraphType> toGraphType(track_classifier::TrackType trackType)
{
    switch (trackType) {
        case track_classifier::TrackType::Undefined:
            return std::nullopt;
        case track_classifier::TrackType::Pedestrian:
            return db::GraphType::Pedestrian;
        case track_classifier::TrackType::Vehicle:
            return db::GraphType::Road;
    }
    // Deliberately no `default:` label, so the compiler can still warn about
    // enumerators added later. This return keeps control from flowing off the
    // end of the function (undefined behaviour) for out-of-range values.
    return std::nullopt;
}

} // anonymous namespace

// Stores the three collaborators; each argument is taken by value and moved
// into the corresponding member.
FeaturePositioner::FeaturePositioner(GraphTypeToMatcherMap graphTypeToMatcherMap,
                                     TrackPointProvider trackPointProvider,
                                     TrackClassifier trackClassifier)
    : graphTypeToMatcherMap_(std::move(graphTypeToMatcherMap)),
      trackPointProvider_(std::move(trackPointProvider)),
      trackClassifier_(std::move(trackClassifier))
{}

// Sorts `features` in place by (sourceId, timestamp) and then positions the
// whole collection through the iterator overload.
// Returns that overload's result.
bool FeaturePositioner::operator()(db::Features& features) const
{
    const auto first = features.begin();
    const auto last = features.end();
    std::sort(first, last, byPassages);
    return this->operator()(first, last);
}

// Recomputes position, heading and graph binding of the features in
// [first, last), one passage at a time.
//
// Returns true when the (hasGraph, graph) pair of at least one feature
// changed.
//
// This overload does not sort: the container overload sorts with byPassages
// before delegating here. Inside a passage the features are binary-searched
// by time (see subrange) and by (sourceId, timestamp) (see
// snapFeaturesPositionsToTrackPointByTime), so they must already be in that
// order.
// NOTE(review): how exactly common::forEachPassage splits the range, and the
// meaning of its last argument, are defined in for_each_passage.h — confirm
// there.
bool FeaturePositioner::operator()(db::Features::iterator first, db::Features::iterator last) const
{
    bool changeGraph = false;
    common::forEachPassage(first, last, [&](auto begin, auto end) {
        // Query track points of the passage's source in a window padded by
        // MAX_TIME_INTERVAL on both sides of the passage.
        auto startTime = begin->timestamp() - MAX_TIME_INTERVAL;
        auto endTime = std::prev(end)->timestamp() + MAX_TIME_INTERVAL;
        auto trackPoints
            = trackPointProvider_(begin->sourceId(), startTime, endTime);
        auto trackIntervals = std::vector<track_classifier::TrackInterval>{};
        // The classifier is only invoked on a non-empty track.
        if (!trackPoints.empty()) {
            std::sort(trackPoints.begin(), trackPoints.end(), byTime);
            trackIntervals = trackClassifier_(trackPoints, graphTypeToMatcherMap_);
            // Intervals are ordered by their end time (see time(TrackInterval)).
            std::sort(trackIntervals.begin(), trackIntervals.end(), byTime);
        }
        for (const auto& trackInterval : trackIntervals) {
            // May be nullptr: undefined track type or no matcher registered.
            auto matcher = getMatcherByType(trackInterval.type);
            // Features and track points whose timestamps fall within
            // [trackInterval.begin, trackInterval.end], bounds included.
            auto photoRng = subrange(begin,
                                     end,
                                     trackInterval.begin,
                                     trackInterval.end);
            auto pointRng = subrange(trackPoints.begin(),
                                     trackPoints.end(),
                                     trackInterval.begin,
                                     trackInterval.end);
            auto unmatchedPoints =
                db::TrackPoints{pointRng.first, pointRng.second};
            // Raw segments between consecutive points: the fallback geometry.
            auto unmatchedSegments = toTrackSegments(unmatchedPoints);
            // Without a matcher there are no matched segments at all.
            auto matchedSegments = matcher == nullptr
                                       ? TrackSegments{}
                                       : matcher->match(unmatchedPoints);
            for (auto it = photoRng.first; it != photoRng.second; ++it) {
                auto graphBefore = std::pair{it->hasGraph(), it->graph()};
                // 1) Prefer a position on the graph-matched track.
                //    Dereferencing toGraphType() is safe here: matched
                //    segments exist only if a matcher was found, and
                //    getMatcherByType returns nullptr when toGraphType yields
                //    no value.
                if (interpolateFeaturePosition(matchedSegments, *it)) {
                    it->setGraph(*toGraphType(trackInterval.type));
                }
                // 2) Otherwise use the raw track and drop the graph binding.
                else if (interpolateFeaturePosition(unmatchedSegments, *it)) {
                    it->resetGraph();
                }
                // 3) No segment covers the feature: clear graph, position and
                //    heading.
                else {
                    it->resetGraph().resetPos().resetHeading();
                }
                auto graphAfter = std::pair{it->hasGraph(), it->graph()};
                changeGraph |= (graphBefore != graphAfter);
            }
        }
        // Runs after interpolation, so for features with the same
        // (sourceId, timestamp) as an augmented track point that point's
        // position (and heading, when it has one) overrides the interpolated
        // values. The graph binding is not touched by this step.
        for (const auto& trackPoint : trackPoints) {
            if (trackPoint.isAugmented()) {
                snapFeaturesPositionsToTrackPointByTime(trackPoint, begin, end);
            }
        }
    },
    MAX_TIME_INTERVAL_SECONDS);
    return changeGraph;
}

// Looks up the matcher registered for the graph type that corresponds to
// `trackType`.
// Returns nullptr when the track type maps to no graph type or when no
// matcher is registered for that graph type.
const Matcher*
FeaturePositioner::getMatcherByType(track_classifier::TrackType trackType) const
{
    if (const auto graphType = toGraphType(trackType)) {
        const auto found = graphTypeToMatcherMap_.find(*graphType);
        if (found != graphTypeToMatcherMap_.end()) {
            return found->second;
        }
    }
    return nullptr;
}

// Creates a classifier that labels the whole track with `trackType`.
//
// The returned callable ignores the matcher map and yields a single interval
// spanning from the earliest to the latest track point timestamp; the points
// do not have to be sorted. For an empty track it yields no intervals, since
// there is no time span to report and the end iterator must not be
// dereferenced.
TrackClassifier classifyAs(track_classifier::TrackType trackType)
{
    return [trackType](const db::TrackPoints& trackPoints,
                       const GraphTypeToMatcherMap& /*graphTypeToMatcherMap*/)
               -> std::vector<track_classifier::TrackInterval> {
        if (trackPoints.empty()) {
            return {};
        }
        // One pass over the points finds both extremes.
        const auto extremes = std::minmax_element(
            trackPoints.begin(), trackPoints.end(), byTime);
        track_classifier::TrackInterval result;
        result.begin = extremes.first->timestamp();
        result.end = extremes.second->timestamp();
        result.type = trackType;
        return std::vector<track_classifier::TrackInterval>{{result}};
    };
}

} // adapters
} // mrc
} // maps
