#include "graph.h"

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/parallel_for_each.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/geometry.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/track_point_gateway.h>

#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/log8/include/log8.h>

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include <boost/range/adaptor/filtered.hpp>

namespace maps::mrc::export_gen {

namespace {

/// Progress is logged every BATCH_SIZE processed photos / generated edges.
constexpr auto BATCH_SIZE = size_t(1'000'000);

/// One relation between a photo and a road graph edge it was matched to.
struct EdgePhoto {
    road_graph::EdgeId edgeId;
    PhotoRef photo;  // non-owning reference to the photo (built with std::cref)
    fb::PolylinePosition begin;  // start of the related part of the edge polyline
    fb::PolylinePosition end;    // end of it; begin == end is an empty relation
    uint8_t order;  // index of this edge among the edges related to the same photo
    // Origin::Fb for relations reused from the previous photo-to-edge data
    // (findPrevEdgePhotosIfActual); value-initialized for freshly matched ones
    // (evalEdgePhotos does not set it).
    Origin origin;
};

using CameraDeviationSet = std::set<db::CameraDeviation>;
/// Relations are appended concurrently from matching worker threads.
using EdgePhotos = tbb::concurrent_vector<EdgePhoto>;
using EdgePhotoIterator = EdgePhotos::const_iterator;
using PhotoRefIterator = PhotoRefs::const_iterator;
using SubpolylineWithPhotoPtr = common::geometry::SubPolylineWithValue<const Photo*>;

/// @return true if @p v lies in the closed interval [@p lo, @p hi].
/// Only operator<= of T is used.
template <class T>
bool between(const T& v, const T& lo, const T& hi)
{
    if (!(lo <= v)) {
        return false;
    }
    return v <= hi;
}

/// Restricts @p ratio to the closed unit interval [0, 1].
float clampRatio(float ratio)
{
    constexpr float lowest = 0.0f;
    constexpr float highest = 1.0f;
    return std::clamp(ratio, lowest, highest);
}

const chrono::TimePoint& time(const Photo& photo)
{
    return photo.timestamp;
}

const chrono::TimePoint& time(const adapters::TrackSegment& trackSegment)
{
    return trackSegment.endTime;
}

auto lessTime = [](const auto& lhs, const auto& rhs) {
    return time(lhs) < time(rhs);
};

/// Appends every element of @p from to @p to (push_back of
/// tbb::concurrent_vector, so @p to may be shared between threads).
template <class T>
void copy(const tbb::concurrent_vector<T>& from, tbb::concurrent_vector<T>& to)
{
    std::for_each(from.begin(), from.end(), [&to](const T& element) {
        to.push_back(element);
    });
}

/// @return the end of the group that starts at @p first: the first position
/// whose element is not @p equal to its predecessor, or @p last.
/// Only neighbouring elements are compared, so groups may be chains.
template <class ForwardIt, class Equal>
ForwardIt findGroupEnd(ForwardIt first, ForwardIt last, const Equal& equal)
{
    const auto boundary = std::adjacent_find(
        first, last, [&](const auto& lhs, const auto& rhs) {
            return !equal(lhs, rhs);
        });
    if (boundary == last) {
        return last;
    }
    // boundary is the last element of the group; the group ends after it.
    return std::next(boundary);
}

/// Sorts @p range in place with @p less (tbb::parallel_sort) and splits it
/// into groups of adjacent elements related by @p equal.
/// @return an iterator to the first element of every group, in range order
template <class Range, class Less, class Equal>
auto groupBy(Range& range, const Less& less, const Equal& equal)
{
    tbb::parallel_sort(range.begin(), range.end(), less);
    auto groupStarts = std::vector<typename Range::const_iterator>{};
    auto cursor = range.begin();
    while (cursor != range.end()) {
        groupStarts.push_back(cursor);
        cursor = findGroupEnd(cursor, range.end(), equal);
    }
    return groupStarts;
}

/// Orders relations by edge id, so that relations of one edge become adjacent.
bool lessEdge(const EdgePhoto& lhs, const EdgePhoto& rhs)
{
    return std::less<>{}(lhs.edgeId, rhs.edgeId);
}

/// True when both relations refer to the same edge.
bool equalEdge(const EdgePhoto& lhs, const EdgePhoto& rhs)
{
    return std::equal_to<>{}(lhs.edgeId, rhs.edgeId);
}

bool lessPassage(const Photo& lhs, const Photo& rhs)
{
    return std::tie(lhs.sourceId, lhs.timestamp, lhs.featureId) <
           std::tie(rhs.sourceId, rhs.timestamp, rhs.featureId);
}

bool equalPassage(const Photo& lhs, const Photo& rhs)
{
    return lhs.sourceId == rhs.sourceId &&
           abs(lhs.timestamp - rhs.timestamp) <= MAX_TIME_INTERVAL;
}

/// Translates an edge id of the previous graph into the current graph.
/// @return @p prevEdgeId itself when both persistent indices have the same
///         version; otherwise the id mapped through the persistent (long) id;
///         std::nullopt when there is no previous index or no mapping
std::optional<road_graph::EdgeId> findCurrentEdgeId(
    const Context& ctx,
    road_graph::EdgeId prevEdgeId)
{
    if (!ctx.prevPersistentIndex) {
        return std::nullopt;
    }
    const auto& prevIndex = *ctx.prevPersistentIndex;
    const auto& currentIndex = ctx.persistentIndex;
    if (prevIndex.version() == currentIndex.version()) {
        return prevEdgeId;  // same graph version: the id is reused as is
    }
    if (auto longId = prevIndex.findLongId(prevEdgeId)) {
        return currentIndex.findShortId(longId.value());
    }
    return std::nullopt;
}

/// @return previous visible edges of photo if geometry or photo weren’t changed
///
/// Relations are read from the previous photo-to-edge data and their edge ids
/// are translated to the current graph; reused relations get Origin::Fb.
/// An empty result means "not actual". That happens when there is no previous
/// data, when photo.origin is not Origin::Fb (presumably: the photo changed
/// since the previous export — NOTE(review): confirm), when the photo has no
/// previous relations, or when any previous edge has no current counterpart.
EdgePhotos findPrevEdgePhotosIfActual(const Context& ctx, const Photo& photo)
{
    if (!ctx.prevPhotoToEdge || photo.origin != Origin::Fb) {
        return {};  // not actual
    }
    auto reused = EdgePhotos{};
    for (const auto& prevRelation :
         ctx.prevPhotoToEdge->lookupByFeatureId(photo.featureId)) {
        const auto currentEdgeId = findCurrentEdgeId(
            ctx, road_graph::EdgeId(prevRelation.edgeId()));
        if (!currentEdgeId) {
            return {};  // not actual
        }
        reused.push_back(EdgePhoto{.edgeId = *currentEdgeId,
                                   .photo = std::cref(photo),
                                   .begin = prevRelation.begin(),
                                   .end = prevRelation.end(),
                                   .order = prevRelation.order(),
                                   .origin = Origin::Fb});
    }
    return reused;
}

/// Builds a synthetic track for one passage from its own photos.
///
/// Emits one track point per photo in [@param first, @param last) with the
/// photo's source id, position, heading and timestamp. For a non-empty range
/// one extra point is appended ahead of the last photo, along its heading.
///
/// @return track points in input order, followed by the extra point (if any)
db::TrackPoints toTrackPoints(PhotoRefIterator first, PhotoRefIterator last)
{
    auto result = db::TrackPoints{};
    std::for_each(first, last, [&](const Photo& photo) {
        result.emplace_back()
            .setSourceId(std::string{photo.sourceId})
            .setGeodeticPos(photo.geodeticPos)
            .setHeading(photo.direction.heading())
            .setTimestamp(photo.timestamp);
    });
    if (first != last) {
        /**
         * Trackpoints are created from photos if ride doesn't contain them.
         * The last photo must be inside track,
         * as coverage is limited by matched track.
         */
        const Photo& photo = *std::prev(last);
        // Estimated time to travel COVERAGE_METERS_MAX_PATH past the last
        // photo. It is a flat 5 s when the straight-line distance between the
        // first and last photo is below COVERAGE_METERS_MAX_PATH. Otherwise it
        // is the passage duration scaled by COVERAGE_METERS_MAX_PATH / distance
        // (i.e. at the passage's average speed), clamped to [2 s, 10 s].
        auto averageTimeShift = [&] {
            using std::chrono::seconds;
            auto meters =
                fastGeoDistance(first->get().geodeticPos, photo.geodeticPos);
            if (meters < COVERAGE_METERS_MAX_PATH) {
                return seconds(5);
            }
            auto ratio = clampRatio(COVERAGE_METERS_MAX_PATH / meters);
            auto duration = time(photo) - time(*first);
            auto result = duration_cast<seconds>(duration * ratio);
            return std::clamp(result, seconds(2), seconds(10));
        }();
        // Extra point: last photo position shifted by COVERAGE_METERS_MAX_PATH
        // times its direction vector (presumably a unit vector, i.e. that many
        // meters ahead — NOTE(review): confirm fastGeoShift units).
        result.emplace_back()
            .setSourceId(std::string{photo.sourceId})
            .setGeodeticPos(fastGeoShift(
                photo.geodeticPos,
                COVERAGE_METERS_MAX_PATH * photo.direction.vector()))
            .setHeading(photo.direction.heading())
            .setTimestamp(time(photo) + averageTimeShift);
    }
    return result;
}

/// @return the point of trackSegment.segment at the same relative position as
/// @p timePoint within [startTime, endTime]; the segment midpoint when the
/// time interval is empty.
/// Requires @p timePoint to lie inside the segment's time interval.
geolib3::Point2 interpolatePoint(const adapters::TrackSegment& trackSegment,
                                 chrono::TimePoint timePoint)
{
    REQUIRE(between(timePoint, trackSegment.startTime, trackSegment.endTime),
            "trackSegment doesn't contain timePoint");
    const auto totalTicks =
        (trackSegment.endTime - trackSegment.startTime).count();
    if (!totalTicks) {
        return trackSegment.segment.midpoint();
    }
    const auto elapsedTicks = (timePoint - trackSegment.startTime).count();
    return trackSegment.segment.pointByPosition(
        clampRatio(elapsedTicks / static_cast<double>(totalTicks)));
}

/// Locates @p point on @p polyline.
///
/// The segment is found by Polyline2::segmentIndex. The relative position
/// inside it is the distance from the segment start divided by the segment
/// length, clamped to [0, 1].
///
/// @return std::nullopt (with a warning) if no segment index is found for
///         @p point; position 0 on a zero-length segment
std::optional<fb::PolylinePosition> positionByPoint(
    const geolib3::Polyline2& polyline,
    const geolib3::Point2& point)
{
    auto segmentIdx = polyline.segmentIndex(point);
    if (!segmentIdx) {
        WARN() << "wrong matched Point{.x=" << point.x() << ", .y=" << point.y()
               << "}";
        return std::nullopt;
    }
    auto segment = polyline.segmentAt(*segmentIdx);
    auto segmentSquaredLength = geolib3::squaredLength(segment);
    if (segmentSquaredLength <= 0.) {
        // Degenerate segment: 0 / 0 would produce NaN, and std::clamp passes
        // NaN through unchanged, so guard explicitly. Every point of such a
        // segment coincides with its start.
        return fb::PolylinePosition(*segmentIdx, 0.f);
    }
    auto ratio = std::sqrt(geolib3::squaredLength(point - segment.start()) /
                           segmentSquaredLength);
    return fb::PolylinePosition(*segmentIdx, clampRatio(ratio));
}

/**
 * Relates a photo to the edges of its matched track.
 *
 * Locates the track segment whose [startTime, endTime] contains the photo
 * timestamp and interpolates the photo position on it. It then walks the
 * track forward for up to COVERAGE_METERS_MAX_PATH meters (fastGeoDistance)
 * and records, for every distinct edge passed, the part of the edge covered.
 *
 * The walk stops early when:
 * - a segment has no edge id;
 * - the current point cannot be located on the edge geometry;
 * - the point on the first edge is farther than COVERAGE_METERS_SNAP_THRESHOLD
 *   from the photo position (the first relation is kept, empty);
 * - the segment direction deviates from the photo direction by more than
 *   COVERAGE_ANGLE_DIFF_THRESHOLD;
 * - the next segment does not start where the previous one ends.
 *
 * @param graph road graph providing edge geometry
 * @param track matched track segments; std::lower_bound by endTime assumes
 *        they are ordered by time — NOTE(review): confirm
 * @param photo photo to relate
 * @return relations in walk order with order = 0, 1, ...; empty when the
 *         photo cannot be placed on any edge
 * REQUIRE fails when more than 254 edges would be related to one photo
 * (uint8_t order).
 */
EdgePhotos evalEdgePhotos(const road_graph::Graph& graph,
                          const adapters::TrackSegments& track,
                          const Photo& photo)
{
    auto result = EdgePhotos{};
    // First segment whose endTime is not earlier than the photo timestamp.
    auto it = std::lower_bound(track.begin(), track.end(), photo, lessTime);
    if (it == track.end() || it->startTime > photo.timestamp) {
        return result;
    }
    auto point = interpolatePoint(*it, photo.timestamp);
    auto metersCountdown = COVERAGE_METERS_MAX_PATH;
    auto edgeGeom = geolib3::Polyline2{};
    do {
        if (!it->edgeId) {
            break;
        }
        // Entering a new edge: open a relation starting at the current point.
        if (result.empty() || result.back().edgeId != *it->edgeId) {
            REQUIRE(result.size() < std::numeric_limits<uint8_t>::max(),
                    "too many edges");
            edgeGeom =
                geolib3::Polyline2(graph.edgeData(*it->edgeId).geometry());
            auto pos = positionByPoint(edgeGeom, point);
            if (!pos) {
                break;
            }
            /**
             * Create an empty relation (begin == end). Move end position later.
             * If it's not possible, then leave "as is" to avoid re-match.
             * Empty relations are not used in coverage.
             */
            // order is the size before push_back, i.e. this edge's index.
            result.push_back(EdgePhoto{.edgeId = *it->edgeId,
                                       .photo = std::cref(photo),
                                       .begin = *pos,
                                       .end = *pos,  // to avoid re-match
                                       .order = uint8_t(result.size())});
            if (result.size() == 1 &&
                fastGeoDistance(point, photo.geodeticPos) >
                    COVERAGE_METERS_SNAP_THRESHOLD) {
                break;
            }
        }
        if (angleBetween(toDirection(it->segment), photo.direction) >
            COVERAGE_ANGLE_DIFF_THRESHOLD) {
            break;
        }
        // Advance the point: to the segment end, or to metersCountdown meters
        // ahead if the remaining budget runs out inside this segment.
        auto meters = fastGeoDistance(point, it->segment.end());
        if (metersCountdown < meters) {
            // num = distance from segment start to the new point.
            auto denom =
                fastGeoDistance(it->segment.start(), it->segment.end());
            auto num = denom - meters + metersCountdown;
            point = it->segment.pointByPosition(clampRatio(num / denom));
            metersCountdown = 0.;
        }
        else {
            point = it->segment.end();
            metersCountdown -= meters;
        }
        // Extend the current relation up to the new point.
        auto pos = positionByPoint(edgeGeom, point);
        if (!pos) {
            break;
        }
        result.back().end = *pos;
        ++it;
    } while (metersCountdown > 0. && it != track.end() &&
             std::prev(it)->segment.end() == it->segment.start());
    return result;
}

/**
 * Matches photos that have no reusable previous relations.
 *
 * @p photos are sorted (in place) and split into passages: runs of photos
 * with the same sourceId whose neighbouring timestamps are at most
 * MAX_TIME_INTERVAL apart. Each passage is processed by a parallelForEach
 * task. The task requests track points for the passage time span widened by
 * MAX_TIME_INTERVAL on both sides; when none are returned, it builds a
 * synthetic track from the passage's photos (toTrackPoints). It then matches
 * the track and relates every photo to edges with evalEdgePhotos.
 *
 * @param photos photos to match; reordered by groupBy
 * @param result receives the relations; appended to from several tasks
 *        (tbb::concurrent_vector)
 */
void match(const Context& ctx,
           PhotoRefs& photos,
           const TrackPointProvider& trackPointProvider,
           EdgePhotos& result)
{
    auto photoIterators = groupBy(photos, lessPassage, equalPassage);
    // Shared between tasks, hence atomic.
    auto processed = std::atomic_size_t{0};
    auto unmatched = std::atomic_size_t{0};
    common::parallelForEach<MAX_THREADS>(
        photoIterators.begin(),
        photoIterators.end(),
        [&](auto& /*guard*/, auto first) {
            auto last = findGroupEnd(first, photos.cend(), equalPassage);
            auto sourceId = first->get().sourceId;
            auto startTime = time(*first) - MAX_TIME_INTERVAL;
            auto endTime = time(*std::prev(last)) + MAX_TIME_INTERVAL;
            auto trackPoints = trackPointProvider(sourceId, startTime, endTime);
            if (trackPoints.empty()) {
                trackPoints = toTrackPoints(first, last);
            }
            auto track = ctx.matcher.match(trackPoints);
            std::for_each(first, last, [&](const Photo& photo) {
                auto edgePhotos =
                    evalEdgePhotos(ctx.matcher.graph(), track, photo);
                if (edgePhotos.empty()) {
                    ++unmatched;
                }
                else {
                    copy(edgePhotos, result);
                }
                // snapshot is this task's own increment result, so exactly one
                // task logs each multiple of BATCH_SIZE.
                if (auto snapshot = ++processed; snapshot % BATCH_SIZE == 0) {
                    INFO() << ctx.graphType << ": processed "
                           << format(snapshot) << " photos";
                }
            });
        });
    INFO() << ctx.graphType << ": processed " << format(processed)
           << " photos (unmatched " << format(unmatched) << ")";
}

/// Collects photo-to-edge relations for the photos of ctx.graphType.
///
/// Skips photos of other graphs, of standalone-photo datasets, and photos
/// without a source id. Previous relations are reused when still actual;
/// all other photos are matched against tracks.
///
/// @return reused and newly matched relations
EdgePhotos makeEdgePhotos(const Context& ctx,
                          const Photos& photos,
                          const TrackPointProvider& trackPointProvider)
{
    // not interested in single images that cannot be navigated
    const auto isSkipped = [&](const auto& photo) {
        return photo.graph != ctx.graphType ||
               isStandalonePhotosDataset(photo.dataset) ||
               photo.sourceId == db::feature::NO_SOURCE_ID;
    };

    auto result = EdgePhotos{};
    auto toMatch = PhotoRefs{};
    for (const auto& photo : photos) {
        if (isSkipped(photo)) {
            continue;
        }
        if (auto reused = findPrevEdgePhotosIfActual(ctx, photo);
            !reused.empty()) {
            copy(reused, result);
        }
        else {
            toMatch.push_back(std::cref(photo));
        }
    }
    INFO() << ctx.graphType << ": found " << format(result.size())
           << " matches";
    match(ctx, toMatch, trackPointProvider, result);
    INFO() << ctx.graphType << ": total " << format(result.size())
           << " matches";
    return result;
}

/// Translates an edge id of the current graph into the previous graph.
/// @return @p edgeId itself when both persistent indices have the same
///         version; otherwise the id mapped through the persistent (long) id;
///         std::nullopt when there is no previous index or no mapping
std::optional<road_graph::EdgeId> findPrevEdgeId(const Context& ctx,
                                                 road_graph::EdgeId edgeId)
{
    if (!ctx.prevPersistentIndex) {
        return std::nullopt;
    }
    const auto& prevIndex = *ctx.prevPersistentIndex;
    const auto& currentIndex = ctx.persistentIndex;
    if (prevIndex.version() == currentIndex.version()) {
        return edgeId;  // same graph version: the id is reused as is
    }
    if (auto longId = currentIndex.findLongId(edgeId)) {
        return prevIndex.findShortId(longId.value());
    }
    return std::nullopt;
}

/// @return true if every feature referenced by any covered subpolyline of any
/// coverage of @p edge is in @p featureIds (true for an edge without coverages)
bool featuresContainEdgeCoverage(const db::TIdSet& featureIds,
                                 const fb::TEdge& edge)
{
    const auto isKnown = [&](const auto& coveredSubpolyline) {
        return featureIds.contains(coveredSubpolyline.featureId());
    };
    return std::all_of(
        edge.coverages.begin(), edge.coverages.end(),
        [&](const auto& coverage) {
            return std::all_of(coverage.coveredSubpolylines.begin(),
                               coverage.coveredSubpolylines.end(),
                               isKnown);
        });
}

/// @return previous edge coverage if geometry or photos weren’t changed
/// Each EdgePhoto from [@param first, @param last) must have the same edgeId
std::optional<fb::TEdge> findPrevEdgeIfActual(const Context& ctx,
                                              EdgePhotoIterator first,
                                              EdgePhotoIterator last)
{
    auto edgeId = first->edgeId;
    auto photoIds = db::TIdSet{};
    auto originMin = Origin::Fb;
    std::for_each(first, last, [&](const EdgePhoto& edgePhoto) {
        photoIds.insert(edgePhoto.photo.get().featureId);
        originMin = std::min(originMin, edgePhoto.origin);
    });
    auto result = std::optional<fb::TEdge>{};
    if (ctx.prevGraph && originMin == Origin::Fb) {
        if (auto prevEdgeId = findPrevEdgeId(ctx, edgeId)) {
            if (result = ctx.prevGraph->edgeById(prevEdgeId->value())) {
                result->id = edgeId.value();
            }
        }
    }
    if (result && !featuresContainEdgeCoverage(photoIds, *result)) {
        result.reset();  // not actual
    }
    return result;
}

/// @return the distinct camera deviations of the photos in [first, last)
CameraDeviationSet toCameraDeviationSet(EdgePhotoIterator first,
                                        EdgePhotoIterator last)
{
    auto deviations = CameraDeviationSet{};
    for (auto it = first; it != last; ++it) {
        deviations.insert(it->photo.get().cameraDeviation);
    }
    return deviations;
}

// All photos from @param subpolylinesWithPhotos must have the same cameraDeviation
/// Builds one coverage record from photo subpolylines of an edge.
///
/// The record lists every subpolyline with its photo's feature id. Its
/// actualization date is the earliest photo timestamp (or now() when there
/// are no subpolylines), and its privacy is the highest photo privacy
/// (FeaturePrivacy::Min when empty). coverageFraction stays 0 unless the
/// covered length reaches MIN_COVERAGE_METERS. Otherwise it is the covered
/// share of the edge length, clamped to [0, 1]. Values of 1.001 or more are
/// rejected by REQUIRE.
fb::TEdgeCoverage makeEdgeCoverage(
    const geolib3::Polyline2& edgeGeom,
    const std::vector<SubpolylineWithPhotoPtr>& subpolylinesWithPhotos)
{
    auto coverage =
        fb::TEdgeCoverage{.coverageFraction = 0.,
                          .actualizationDate = chrono::TimePoint::clock::now(),
                          .privacy = db::FeaturePrivacy::Min};
    auto coveredMeters = 0.;
    for (const auto& item : subpolylinesWithPhotos) {
        const auto& subPolyline = item.subPolyline;
        const auto begin = fb::PolylinePosition(
            subPolyline.begin().segmentIdx(),
            subPolyline.begin().segmentRelPosition());
        const auto end = fb::PolylinePosition(
            subPolyline.end().segmentIdx(),
            subPolyline.end().segmentRelPosition());
        coveredMeters += geoLength(partition(edgeGeom, begin, end));

        const Photo& photo = *item.value;
        coverage.actualizationDate =
            std::min(coverage.actualizationDate, photo.timestamp);
        coverage.cameraDeviation = photo.cameraDeviation;
        coverage.privacy = std::max(coverage.privacy, photo.privacy);
        coverage.coveredSubpolylines.emplace_back(photo.featureId, begin, end);
    }

    if (coveredMeters >= MIN_COVERAGE_METERS) {
        coverage.coverageFraction = coveredMeters / geoLength(edgeGeom);
        // 100% + epsilon
        REQUIRE(coverage.coverageFraction < 1.001,
                "invalid coverageFraction " << coverage.coverageFraction);
        coverage.coverageFraction = clampRatio(coverage.coverageFraction);
    }
    return coverage;
}

/// temporarily avoid nexar
/// @see https://st.yandex-team.ru/MAPSMRC-3793
bool lessPriority(const Photo* lhs, const Photo* rhs)
{
    return std::make_tuple(lhs->dataset != db::Dataset::NexarDashcams,
                           lhs->dayPart,
                           lhs->timestamp,
                           lhs->featureId) <
           std::make_tuple(rhs->dataset != db::Dataset::NexarDashcams,
                           rhs->dayPart,
                           rhs->timestamp,
                           rhs->featureId);
}

/// Builds the coverages of one edge from its photo relations.
///
/// Each distinct camera deviation is handled in rounds of decreasing privacy.
/// A round takes the relations whose photo has that deviation and a privacy
/// not above the current bound, merges their subpolylines
/// (common::geometry::merge with lessPriority), and turns the result into one
/// coverage. The next bound is one level below that coverage's privacy.
/// Rounds end after a coverage with FeaturePrivacy::Min. Only coverages with
/// a positive coverageFraction are stored.
///
/// All EdgePhoto of [@param first, @param last) must have the same edgeId
fb::TEdge evalEdge(const geolib3::Polyline2& edgeGeom,
                   EdgePhotoIterator first,
                   EdgePhotoIterator last)
{
    auto result = fb::TEdge{.id = first->edgeId.value()};
    auto cameraDeviations = toCameraDeviationSet(first, last);
    for (auto cameraDeviation : cameraDeviations) {
        auto privacy = db::FeaturePrivacy::Max;
        // Captures privacy by reference: lowering it below narrows the filter
        // for the next round.
        auto filter = [&](const EdgePhoto& edgePhoto) {
            return edgePhoto.photo.get().cameraDeviation == cameraDeviation &&
                   edgePhoto.photo.get().privacy <= privacy;
        };
        while (true) {
            auto edgePhotos = boost::make_iterator_range(first, last) |
                              boost::adaptors::filtered(filter);

            auto subpolylines = std::vector<SubpolylineWithPhotoPtr>{};
            for (const auto& edgePhoto : edgePhotos) {
                subpolylines.emplace_back(
                    common::geometry::SubPolyline(
                        convertToCommonPolylinePosition(edgePhoto.begin),
                        convertToCommonPolylinePosition(edgePhoto.end)),
                    std::addressof(edgePhoto.photo.get()));
            }

            auto coverage = makeEdgeCoverage(
                edgeGeom, common::geometry::merge(subpolylines, lessPriority));
            if (coverage.coverageFraction > 0.) {
                REQUIRE(coverage.cameraDeviation == cameraDeviation,
                        "invalid cameraDeviation");
                REQUIRE(coverage.privacy <= privacy, "invalid privacy");
                result.coverages.push_back(coverage);
            }
            // An empty round also yields FeaturePrivacy::Min, which ends the
            // loop; otherwise privacy strictly decreases every round.
            if (coverage.privacy == db::FeaturePrivacy::Min) {
                break;
            }
            privacy = db::FeaturePrivacy(db::toIntegral(coverage.privacy) - 1);
        }
    }
    return result;
}

std::vector<fb::PhotoToEdgePair> toPhotoToEdgePairs(
    const EdgePhotos& edgePhotos)
{
    auto result = std::vector<fb::PhotoToEdgePair>{};
    for (const auto& edgePhoto : edgePhotos) {
        result.emplace_back(edgePhoto.photo.get().featureId,
                            edgePhoto.edgeId.value(),
                            edgePhoto.begin,
                            edgePhoto.end,
                            edgePhoto.order);
    }
    return result;
}

}  // namespace

/// Opens the current road graph matcher and persistent index. On a
/// best-effort basis it also opens the previous export's data used for reuse:
/// the coverage graph, its persistent index and the photo-to-edge relations.
///
/// The version check between the current road graph and persistent index is
/// outside the try block, so its failure is not handled here (presumably it
/// propagates to the caller).
/// Any std::exception while opening or validating the previous data is logged
/// as a warning, and all prev* members are reset (no reuse).
/// A previous coverage graph with an outdated schema is dropped on its own,
/// so the previous photo-to-edge relations can still be reused.
Context::Context(const std::string& graphPath,
                 db::GraphType graphType_,
                 const std::string& prevMrcVersion,
                 const std::string& prevGraphPath,
                 const std::string& prevPhotoToEdgePath)
    : matcher(graphPath,
              adapters::MATCHER_CONFIG_RESOURCE,
              EMappingMode::Precharged)
    , persistentIndex(graphPath + "/" + EDGES_PERSISTENT_INDEX_FILE,
                      EMappingMode::Precharged)
    , graphType(graphType_)
{
    REQUIRE(matcher.graph().version() == persistentIndex.version(),
            ROAD_GRAPH_FILE << " version " << matcher.graph().version()
                            << " and " << EDGES_PERSISTENT_INDEX_FILE
                            << " version " << persistentIndex.version()
                            << " are not equal");
    try {
        REQUIRE(!prevGraphPath.empty(), "empty prevGraphPath");
        prevGraph.emplace(prevGraphPath + "/" + GRAPH_COVERAGE_FILE,
                          EMappingMode::Precharged);
        prevPersistentIndex.emplace(
            prevGraphPath + "/" + EDGES_PERSISTENT_INDEX_FILE,
            EMappingMode::Precharged);
        REQUIRE(prevGraph->mrcVersion() == prevMrcVersion,
                GRAPH_COVERAGE_FILE << " versions " << prevGraph->mrcVersion()
                                    << " and " << prevMrcVersion
                                    << " are not equal");
        REQUIRE(prevGraph->version() == prevPersistentIndex->version(),
                GRAPH_COVERAGE_FILE
                    << " version " << prevGraph->version() << " and "
                    << EDGES_PERSISTENT_INDEX_FILE << " version "
                    << prevPersistentIndex->version() << " are not equal");
        REQUIRE(!prevPhotoToEdgePath.empty(), "empty prevPhotoToEdgePath");
        prevPhotoToEdge.emplace(prevPhotoToEdgePath + "/" + PHOTO_TO_EDGE_FILE,
                                EMappingMode::Precharged);
        REQUIRE(prevPhotoToEdge->mrcVersion() == prevMrcVersion,
                PHOTO_TO_EDGE_FILE << " versions "
                                   << prevPhotoToEdge->mrcVersion() << " and "
                                   << prevMrcVersion << " are not equal");
        REQUIRE(prevPhotoToEdge->schemaVersion() ==
                    CURRENT_PHOTO_TO_EDGE_SCHEMA_VERSION,
                PHOTO_TO_EDGE_FILE << " old schema version");
        INFO() << graphType
               << " previous mrc version: " << prevMrcVersion;

        // changing the coverage scheme doesn't require rematching
        // prevPhotoToEdge can be used
        if (prevGraph->schemaVersion() != CURRENT_GRAPH_SCHEMA_VERSION) {
            WARN() << graphType << ": old schema of " << GRAPH_COVERAGE_FILE;
            prevGraph.reset();
        }
    }
    catch (const std::exception& e) {
        // Previous data is optional: fall back to a full re-match.
        prevGraph.reset();
        prevPersistentIndex.reset();
        prevPhotoToEdge.reset();
        WARN() << graphType << ": " << e.what();
    }
}

/// Builds the coverage graph of ctx.graphType from @p photos.
///
/// Photos are related to edges (makeEdgePhotos), and the relations are
/// grouped by edge. Each edge either reuses its previous coverage
/// (findPrevEdgeIfActual) or gets a freshly evaluated one (evalEdge). Only
/// edges with at least one coverage are added to the graph. Edges are
/// handled by parallelForEach tasks, with pushes guarded by a lock, so the
/// order of result.graph.edges is presumably not deterministic.
///
/// @return the graph plus all photo-to-edge relations, including those of
///         edges that ended up without coverage
GraphSummary makeGraphSummary(const Context& ctx,
                              const Photos& photos,
                              const TrackPointProvider& trackPointProvider)
{
    auto result =
        GraphSummary{.graph = fb::TGraph{.version = static_cast<std::string>(
                                             ctx.matcher.graph().version())}};

    auto edgePhotos = makeEdgePhotos(ctx, photos, trackPointProvider);
    auto edgePhotoIterators = groupBy(edgePhotos, lessEdge, equalEdge);
    common::parallelForEach<MAX_THREADS>(
        edgePhotoIterators.begin(),
        edgePhotoIterators.end(),
        [&](auto& guard, auto first) {
            auto last = findGroupEnd(first, edgePhotos.cend(), equalEdge);
            auto edge = findPrevEdgeIfActual(ctx, first, last);
            if (!edge) {
                auto edgeGeom = geolib3::Polyline2(
                    ctx.matcher.graph().edgeData(first->edgeId).geometry());
                edge = evalEdge(edgeGeom, first, last);
            }
            if (!edge->coverages.empty()) {
                std::lock_guard lock{guard};
                // edge is not used afterwards: move its coverages instead of
                // deep-copying them while holding the lock.
                result.graph.edges.push_back(std::move(*edge));
                if (result.graph.edges.size() % BATCH_SIZE == 0) {
                    INFO() << ctx.graphType << ": generated "
                           << format(result.graph.edges.size()) << " edges";
                }
            }
        });
    INFO() << ctx.graphType << ": generated "
           << format(result.graph.edges.size()) << " edges";
    result.photoToEdgePairs = toPhotoToEdgePairs(edgePhotos);
    return result;
}

/// @return a provider that ignores its arguments and never returns track
/// points; match() then falls back to tracks built from the photos.
TrackPointProvider makeTrackPointProviderStub()
{
    auto noTrackPoints = [](auto&&... /*ignored*/) -> db::TrackPoints {
        return {};
    };
    return noTrackPoints;
}

/// @return a provider loading track points of @p sourceId within
/// [startTime, endTime] from a slave transaction of @p pool. Each load is
/// retried up to 7 times with a 1 s initial cooldown that doubles each time.
///
/// The provider keeps a pointer to @p pool: the pool must outlive it.
TrackPointProvider makeTrackPointProvider(pgpool3::Pool& pool)
{
    // The returned lambda escapes this function, so capture only the pool and
    // do so explicitly, instead of a default [&] capture of a reference
    // parameter.
    return [poolPtr = &pool](std::string_view sourceId,
                             chrono::TimePoint startTime,
                             chrono::TimePoint endTime) {
        auto retryPolicy = maps::common::RetryPolicy{}
                               .setTryNumber(7)
                               .setInitialCooldown(std::chrono::seconds(1))
                               .setCooldownBackoff(2);
        return maps::common::retry(
            [&] {
                return db::TrackPointGateway{*poolPtr->slaveTransaction()}.load(
                    db::table::TrackPoint::sourceId.equals(
                        std::string{sourceId}) &&
                    db::table::TrackPoint::timestamp.between(startTime,
                                                             endTime));
            },
            retryPolicy);
    };
}

}  // namespace maps::mrc::export_gen
