#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sfm_signs_positioning/bin/sfm_positioning_evaluation.h>

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/sign_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/graph_matcher_adapter.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sensors_feature_positioner_pool.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sensors_loader.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/separate_rides.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sign_ray.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/accuracy.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/dataset.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/geojson.h>
#include <maps/libs/log8/include/log8.h>

#include <boost/optional.hpp>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <exception>
#include <optional>
#include <string>
#include <vector>

// Short aliases for the project namespaces referenced in this file.
namespace tws = maps::mrc::tracks_with_sensors;
namespace sfp = maps::mrc::sensors_feature_positioner;
namespace sfm = maps::mrc::experiment_sfm_positioning;
namespace mrc = maps::mrc;
namespace pimp = maps::mrc::pos_improvment;
namespace db = maps::mrc::db;

namespace {

// Minimal length of a track that is handed over to the sensors positioner:
// the sensors algorithm needs some car movement to calculate the camera
// orientation. The value is compared against the result of
// computeTrackLength(); shorter tracks are skipped.
constexpr double MIN_TRACK_METERS = 3000;

// Returns the accumulated distance between consecutive points of the track.
// Only segments shorter than MAX_SEGMENT_LENGTH contribute to the sum; longer
// ones are left out (presumably jumps in the recorded positions - confirm).
// A track with fewer than two points has zero length.
double computeTrackLength(const sfp::Track& track)
{
    // Same units as fastGeoDistance() returns (presumably meters, since the
    // result is compared with MIN_TRACK_METERS).
    constexpr double MAX_SEGMENT_LENGTH = 50;

    const auto& points = track.trackPoints;

    double total = 0;
    for (size_t next = 1; next < points.size(); ++next) {
        const double segment = maps::geolib3::fastGeoDistance(
            points[next - 1].geodeticPos(), points[next].geodeticPos());
        // Keep the comparison in this form: anything that is not strictly
        // below the limit must be ignored.
        if (segment < MAX_SEGMENT_LENGTH) {
            total += segment;
        }
    }
    return total;
}

// DEBUG
// Thin wrapper over tws::loadPhotos() that serves as a hook for narrowing
// the dataset down to a hand-picked set of features while debugging.
// The filter is currently disabled, so all loaded photos are returned.
//
// To re-enable it, keep the loaded photos in a local variable `photos` and
// apply the following before returning it:
//
//     const auto isNotInSet = [](const tws::Photo& photo) {
//         static std::set<db::TId> theSet
//             = {53475428, 53475432, 53475450, 53475451, 53475452, 53475453,
//                53475453, 53475454, 53475456, 53475457, 53475458};
//         return !theSet.count(photo.featureId);
//     };
//     photos.erase(std::remove_if(photos.begin(), photos.end(), isNotInSet),
//         photos.end());
tws::Photos loadPhotos(const std::string& datasetPrefix)
{
    return tws::loadPhotos(datasetPrefix);
}

// DEBUG
// Thin wrapper over tws::loadGtSigns() that serves as a hook for loading
// only the ground truth signs matching the located ones while debugging.
// The restriction is currently disabled: the second argument is ignored and
// the ground truth signs are loaded by the dataset prefix alone.
//
// To re-enable it, name the second parameter `signs` and replace the body
// with:
//
//     std::unordered_set<tws::SignGtId> gtIds;
//     for (const auto& sign : signs) {
//         gtIds.insert(*sign.id);
//     }
//     return tws::loadGtSigns(datasetPrefix, gtIds);
tws::Signs loadGtSigns(
    const std::string& datasetPrefix,
    const tws::Signs& /* signs */)
{
    auto groundTruth = tws::loadGtSigns(datasetPrefix);
    return groundTruth;
}

// Adds every photo that can be positioned to the geojson saver as a ray.
//
// For each photo the positioner whose track time range
// [trackStartTime(), trackEndTime()] contains the photo timestamp is looked
// up and asked for the improved position at that moment. Photos that are
// not covered by any positioner, or for which the positioner yields no
// position, are skipped.
//
// The lookup relies on the positioners being ordered by time with
// non-overlapping ranges (the same assumption the search direction has
// always been based on).
void renderGeojsonPhotos(tws::GeojsonSaver& saver,
    const sfp::SensorsFeaturePositioners& featurePositioners,
    const tws::Photos& photos)
{
    // The iterator is kept between photos and used as a search hint, so
    // consecutive photos of the same ride are resolved without rescanning.
    auto featurePositionerIt = featurePositioners.begin();
    const auto getImprovedPosition = [&](maps::chrono::TimePoint timestamp)
        -> std::optional<pimp::ImprovedGpsEvent> {
        const auto first = featurePositioners.begin();
        const auto last = featurePositioners.end();

        // Rewind while the hint is past the end or its track starts after
        // the timestamp.
        while (featurePositionerIt != first
            && (featurePositionerIt == last
                || timestamp < featurePositionerIt->trackStartTime())) {
            --featurePositionerIt;
        }
        // Advance while the track under the hint ends before the timestamp.
        while (featurePositionerIt != last
            && featurePositionerIt->trackEndTime() < timestamp) {
            ++featurePositionerIt;
        }
        // Either there is no track at or after the timestamp, or the
        // timestamp falls into a gap before the track we stopped at (e.g. a
        // ride that was skipped as too short). Report "no position" instead
        // of bouncing between the two neighbouring tracks forever.
        if (featurePositionerIt == last
            || timestamp < featurePositionerIt->trackStartTime()) {
            return std::nullopt;
        }
        return featurePositionerIt->getPositionByTime(timestamp);
    };

    for (const auto& photo : photos) {
        const auto improvedPosition = getImprovedPosition(photo.timestamp);
        if (improvedPosition) {
            saver.addPhotoAsRay(*improvedPosition);
        }
    }
}

// Evaluates the two-photos sign positioning on a single dataset.
//
// Loads the track, the sensor events, the photos and the sign detections
// stored under `datasetPrefix`, builds a sensors feature positioner for
// every sufficiently long ride, locates the signs and compares them with the
// ground truth signs of the dataset.
//
// @param datasetPrefix        prefix of the dataset files
// @param graphMatcherAdapter  matcher passed to every feature positioner
// @param exportToGeojson      when set, the ground truth signs, the located
//                             signs and the photo rays are saved to
//                             "<datasetPrefix>.geojson"
// @param featuresLimit        forwarded to sfm::locateSignsByTwoPhotos
// @return the accuracy of the located signs, or std::nullopt when the
//         dataset yields no rides
//
// The REQUIRE checks reject datasets with too few track points or gyroscope
// events (presumably by throwing maps::Exception - see the handler in main).
std::optional<tws::Accuracy> evaluateDataset(const std::string& datasetPrefix,
    const mrc::adapters::GraphMatcherAdapter& graphMatcherAdapter,
    bool exportToGeojson,
    std::size_t featuresLimit)
{
    INFO() << "Start with dataset " << datasetPrefix;

    mrc::db::TrackPoints trackPoints = tws::loadTrack(datasetPrefix);
    INFO() << "Loaded " << trackPoints.size() << " track points";
    REQUIRE(trackPoints.size() > 1, "Too few track points");

    tws::SensorEvents sensorEvents = tws::loadSensors(datasetPrefix);
    REQUIRE(sensorEvents.gyroEvents.size() > 10, "Too few sensors events");
    INFO() << "Loaded " << sensorEvents.gyroEvents.size() << " sensors";

    // Must split into tracks that satisfy the sensors positioning
    //      requirements.
    // WARN: The tracks are referenced by a feature positioner so must
    //       outlive it.
    sfp::Tracks tracks
        = sfp::splitIntoSeparateRides(trackPoints, sensorEvents);
    if (tracks.empty()) {
        WARN() << "The ground truth track is empty, skipping this dataset";
        return std::nullopt;
    }

    sfp::SensorsFeaturePositioners featurePositioners;

    double tracksLength = 0;
    // Iterate by reference: the positioner is constructed from members of
    // the element stored in `tracks`. Iterating by value would hand it
    // members of a per-iteration copy that is destroyed at the end of the
    // loop body, leaving the positioner with dangling references.
    for (auto& track : tracks) {
        const auto curTrackLength = computeTrackLength(track);
        if (curTrackLength < MIN_TRACK_METERS) {
            WARN() << "skipped short track with " << curTrackLength
                   << " meters lenght";
            continue;
        }
        tracksLength += curTrackLength;

        featurePositioners.emplace_back(
            graphMatcherAdapter, track.trackPoints, track.sensorEvents);
    }
    INFO() << "TRACKS LENGTH = " << tracksLength / 1000.0 << " km";

    const tws::Photos photos = loadPhotos(datasetPrefix);
    const tws::SignPhotos signPhotos = tws::loadSigns(datasetPrefix);
    const tws::Signs locatedSigns = sfm::locateSignsByTwoPhotos(
        featurePositioners, photos, signPhotos, featuresLimit);

    const tws::Signs gtSigns = loadGtSigns(datasetPrefix, locatedSigns);

    if (exportToGeojson) {
        // Ground truth signs are drawn in green, the located ones in red,
        // followed by the rays of the positioned photos.
        tws::GeojsonSaver geojsonSaver;
        geojsonSaver.addSignsAsPoints(gtSigns, "gt", "#00ff00");
        geojsonSaver.addSignsAsPoints(locatedSigns, "calc", "#ff0000");
        renderGeojsonPhotos(geojsonSaver, featurePositioners, photos);
        geojsonSaver.save(datasetPrefix + ".geojson");
    }

    INFO() << "Accuracy evaluated for dataset " << datasetPrefix;
    const auto accuracy = tws::compareWithGroundTruth(
        locatedSigns, gtSigns, tws::SignMatching::ById);
    tws::printAccuracy(accuracy);
    return accuracy;
}

// Combines per-dataset accuracies into a single one.
//
// avgError, recall and deviation are averaged with weights proportional to
// the number of errors (sortedErrors.size()) of each dataset;
// severalSignsForOneGt is summed up and the errors of all datasets are
// merged into one sorted sequence.
//
// When there is nothing to weigh by (no accuracies at all, or none of them
// has any errors) the averaged fields stay zero instead of turning into NaN.
tws::Accuracy aggregateAccuracies(
    const std::vector<tws::Accuracy>& accuracies)
{
    double totalWeight = 0;
    for (const auto& accuracy : accuracies) {
        totalWeight += accuracy.sortedErrors.size();
    }

    tws::Accuracy aggregate{0, {}, 0, 0, 0};
    for (const auto& accuracy : accuracies) {
        // Guard against 0 / 0: with a zero total weight every dataset has
        // zero errors, so none of them contributes to the averages.
        const double contribution = totalWeight > 0
            ? accuracy.sortedErrors.size() / totalWeight
            : 0.0;
        aggregate.avgError += contribution * accuracy.avgError;
        aggregate.recall += contribution * accuracy.recall;
        aggregate.deviation += contribution * accuracy.deviation;

        aggregate.severalSignsForOneGt += accuracy.severalSignsForOneGt;
        aggregate.sortedErrors.insert(aggregate.sortedErrors.end(),
            accuracy.sortedErrors.begin(), accuracy.sortedErrors.end());
    }
    // Per-dataset errors are concatenated above, so restore the ordering.
    std::sort(aggregate.sortedErrors.begin(), aggregate.sortedErrors.end());

    return aggregate;
}
}

int main(int argc, char* argv[]) try

{
    maps::log8::setLevel(maps::log8::Level::INFO);

    maps::cmdline::Parser parser("Evaluate results of two-photos SfM "
                                 "approach to compute sign positions");
    auto geojson = parser.flag("geojson").help(
        "Export positioning results to geojson");

    auto staticGraphDir
        = parser.dir("static-graph-dir").help("Path to a static graph data");
    // E.g. /var/spool/yandex/maps/graph/19.01.17-1

    auto featuresLimit = parser.num("feture-matches")
                             .help("lower bound of matched feature points on "
                                   "a couple of photos")
                             .defaultValue(8);

    parser.parse(argc, argv);
    REQUIRE(!parser.argv().empty(),
        "You must specify a list of dataset prefixes!");

    mrc::adapters::GraphMatcherAdapter graphMatcherAdapter(staticGraphDir);

    std::vector<tws::Accuracy> accuracies;
    for (const auto& datasetPrefix : parser.argv()) {
        const auto accuracy = evaluateDataset(
            datasetPrefix, graphMatcherAdapter, geojson, featuresLimit);
        if (accuracy) {
            accuracies.push_back(*accuracy);
        }
    }

    INFO() << "===================";
    INFO() << "Aggregated accuracy";
    INFO() << "===================";
    tws::printAccuracy(aggregateAccuracies(accuracies));

    return EXIT_SUCCESS;

} catch (const maps::Exception& e) {
    INFO() << e;
    return EXIT_FAILURE;
}
