#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/accuracy.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/dataset.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/db_loader.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/geojson.h>
#include <maps/wikimap/mapspro/services/mrc/tools/experiment_sign_position_accuracy/lib/include/outdated_signs.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/chrono/include/time_point.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/sign_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/compact_graph_matcher_adapter.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sensors_loader.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sensors_feature_positioner_pool.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/separate_rides.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/object_ray.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/object_positioner.h>
#include <maps/libs/log8/include/log8.h>

#include <maps/libs/geolib/include/distance.h>

#include <boost/optional.hpp>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <map>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

// Short aliases for the namespaces used throughout this tool.
namespace tws = maps::mrc::tracks_with_sensors;
namespace sfp = maps::mrc::sensors_feature_positioner;
namespace mrc = maps::mrc;
namespace db = maps::mrc::db;

// Rides whose summed gps length (main only sums segments shorter than 50 m)
// is below this value are skipped:
// sensors algorithm needs some car movements to calculate camera orientation
const double MIN_TRACK_METERS = 3000;
// Sign side length handed to sfp::constructRay when a ray is built from a
// sign bbox. NOTE(review): presumably used to estimate the distance to the
// sign from its apparent bbox size — confirm in object_ray.
const double SIGN_SIDE_SIZE_METERS = 0.7;

/**
 * Vault secret passed to maps::vault_boy::loadContextWithYaVault() when the
 * mrc config is built (presumably it holds the credentials the config needs
 * for the db pool and the MDS client).
 * @see
 * https://yav.yandex-team.ru/secret/sec-01cp2k8gbhycereqjrqg2va311/explore/versions
 */
const auto CONFIG_SECRET = "sec-01cp2k8gbhycereqjrqg2va311";

/**
 * Production config template used together with CONFIG_SECRET.
 * NOTE(review): the path is relative; if the config loader opens it as-is it
 * is resolved against the current working directory, i.e. the tool has to be
 * launched from its own directory — confirm.
 * @see
 * https://a.yandex-team.ru/arc/trunk/arcadia/maps/wikimap/mapspro/services/mrc/libs/config/cfg/t-config.production.xml
 */
const auto CONFIG_PATH = "../../../libs/config/cfg/t-config.production.xml";

// Loads the db::Sign that corresponds to each sign photo.
//
// A photo is matched to a db::SignFeature by its feature id and by its bbox:
// the stored SignFeature box is passed through revertByImageOrientation()
// with the feature's size and orientation and then compared with
// signPhoto.bbox for equality.
//
// Returns a vector parallel to `signPhotos`: element i is a copy of the db
// sign matched to signPhotos[i], with its type overridden by
// signPhotos[i].signType. The same db sign may therefore appear several times.
//
// Throws maps::Exception if a feature, a matching sign feature or a sign
// cannot be found in the database.
db::Signs loadСorrespondingSignsFromDb(const tws::SignPhotos& signPhotos,
                                       maps::wiki::common::PoolHolder& poolHolder)
{
    // Nothing to match: skip the transaction and the `IN ()` queries.
    if (signPhotos.empty()) {
        return db::Signs{};
    }

    auto txn = poolHolder.pool().slaveTransaction();
    db::FeatureGateway featureGtw(*txn);

    std::vector<db::TId> featureIds;
    featureIds.reserve(signPhotos.size());
    for (const auto& signPhoto : signPhotos) {
        featureIds.push_back(signPhoto.featureId);
    }
    db::Features features = featureGtw.load(
        db::table::Feature::id.in(featureIds));
    // Pointers into `features`; valid because the vector is not modified below.
    std::map<db::TId, const db::Feature*> featuresMap;
    for (const auto& feature : features) {
        featuresMap[feature.id()] = &feature;
    }

    db::SignFeatureGateway signFeatureGtw(*txn);
    db::SignFeatures signFeatures = signFeatureGtw.load(
        db::table::SignFeature::featureId.in({featureIds}));

    std::vector<db::TId> signIds;
    signIds.reserve(signFeatures.size());
    for (const auto& signFeature : signFeatures) {
        signIds.push_back(signFeature.signId());
    }

    db::SignGateway signGtw(*txn);
    db::Signs signs = signGtw.load(db::table::Sign::id.in({signIds}));
    // Pointers into `signs`; valid because the vector is not modified below.
    std::map<db::TId, const db::Sign*> signsMap;
    for (const auto& sign : signs) {
        signsMap[sign.id()] = &sign;
    }

    // Finds the sign attached to `featureId` whose (orientation-reverted)
    // box equals `bbox`.
    auto getSignId = [&](db::TId featureId,
                         const maps::mrc::common::ImageBox& bbox) {
        // Use find() rather than operator[]: a missing feature must not turn
        // into a null pointer dereference.
        const auto featureIt = featuresMap.find(featureId);
        if (featureIt == featuresMap.end()) {
            throw maps::Exception("can't find feature " + std::to_string(featureId));
        }
        const db::Feature& feature = *featureIt->second;

        for (const auto& sf : signFeatures) {
            if (sf.featureId() != featureId) {
                continue;
            }
            const auto curBbox = revertByImageOrientation(
                sf.imageBox(),
                feature.size(),
                feature.orientation());
            if (curBbox == bbox) {
                return sf.signId();
            }
        }
        throw maps::Exception("can't find sign " + std::to_string(featureId));
    };

    db::Signs resultSigns;
    resultSigns.reserve(signPhotos.size());

    for (const auto& signPhoto : signPhotos) {
        const auto signId = getSignId(signPhoto.featureId,
                                      signPhoto.bbox);
        const auto signIt = signsMap.find(signId);
        if (signIt == signsMap.end()) {
            throw maps::Exception("can't find sign " + std::to_string(signId));
        }
        resultSigns.push_back(*signIt->second);
        resultSigns.back().setType(signPhoto.signType);
    }

    return resultSigns;
}

// Converts a maps::chrono::TimePoint into a boost::posix_time::ptime.
// The conversion goes through std::time_t, so the result has whole-second
// resolution: any sub-second part of `time` is dropped (whether
// system_clock::to_time_t rounds or truncates is implementation-defined).
boost::posix_time::ptime toBoost(const maps::chrono::TimePoint& time)
{
    using SystemClock = std::chrono::system_clock;

    const auto systemTime =
        std::chrono::time_point_cast<SystemClock::time_point::duration>(time);
    const auto secondsSinceEpoch = SystemClock::to_time_t(systemTime);

    return boost::posix_time::from_time_t(secondsSinceEpoch);
}

// Sorts `trackPoints` by timestamp in place and removes duplicates.
// Among consecutive points with the same timestamp, only the first (earliest)
// one is kept. "Same" means identical timestamps or an equal toBoost()
// value. Because toBoost() has whole-second resolution, this collapses all
// points that fall into the same second.
void sortUniqueByTime(db::TrackPoints& trackPoints)
{
    const auto isEarlier = [](const db::TrackPoint& first,
                              const db::TrackPoint& second) {
        return first.timestamp() < second.timestamp();
    };

    const auto isSameTime = [](const db::TrackPoint& first,
                               const db::TrackPoint& second) {
        if (first.timestamp() == second.timestamp()) {
            return true;
        }
        return toBoost(second.timestamp()) == toBoost(first.timestamp());
    };

    std::sort(trackPoints.begin(), trackPoints.end(), isEarlier);

    const auto newEnd =
        std::unique(trackPoints.begin(), trackPoints.end(), isSameTime);
    trackPoints.erase(newEnd, trackPoints.end());
}

/**
 * Estimates traffic sign positions for one ride with the sensors-based
 * feature positioner and logs track / positioning statistics.
 *
 * Input comes either from a local dataset (--dataset) or from the mrc
 * database and MDS (--assignment-id together with --source-id), which is
 * checked with REQUIRE. The track is split into separate rides, and rides
 * shorter than MIN_TRACK_METERS are skipped. For every remaining ride:
 *  - rays towards detected signs are built from improved camera positions
 *    and combined by sfp::calculateObjectsPositions into sign positions;
 *  - with --find-outdated-signs the result is passed to tws::findOutdatedSigns;
 *  - in dataset mode the computed signs and the corresponding db signs are
 *    compared with ground-truth signs, and their accuracy is printed.
 * A ride that throws maps::Exception is logged and skipped. Sign rays, sign
 * groups and (in dataset mode) gt/db sign points are saved to --out-geojson
 * when it is given.
 */
int main(int argc, char* argv[]) {
    maps::log8::setLevel(maps::log8::Level::INFO);

    // NOTE(review): this description mentions gps/matched tracks, but the
    // geojson calls below add sign rays, sign groups and gt/db sign points
    // only — confirm the help text is still accurate.
    maps::cmdline::Parser parser("Loads gps track either from dataset or from"
                                 " mrc db and exports it to"
                                 " geojson file. The created file contains"
                                 " two tracks: red - gps track,"
                                 " blue - matched track");
    auto outputGeojson = parser.string("out-geojson")
        .help("output Geojson file");
    auto datasetPrefix = parser.string("dataset")
        .help("dataset with track and signs");
    auto assignmentId = parser.num("assignment-id");
    auto sourceId = parser.string("source-id");
    auto staticGraphDir = parser.string("static-graph-dir")
        .defaultValue("/var/lib/datasets/yandex-maps-mrc-graph_20.12.13-36408320");
    auto searchForOutdatedSigns = parser.flag("find-outdated-signs");
    parser.parse(argc, argv);

    REQUIRE(datasetPrefix.defined()
            ^ (assignmentId.defined() && sourceId.defined()),
            "provide either dataset path or assignmentId and sourceId");

    // The db pool is needed in dataset mode too: db signs are loaded for
    // comparison with ground truth.
    auto config = maps::mrc::common::Config(
        maps::vault_boy::loadContextWithYaVault(CONFIG_SECRET), CONFIG_PATH);
    auto mdsClient = config.makeMdsClient();

    maps::wiki::common::PoolHolder poolHolder = config.makePoolHolder();
    tws::GeojsonSaver geojsonSaver;

    mrc::db::TrackPoints trackPoints = datasetPrefix.defined()
        ? tws::loadTrack(datasetPrefix)
        : tws::loadTrack(poolHolder, assignmentId, sourceId);
    INFO() << "Loaded " << trackPoints.size() << " track points";
    REQUIRE(trackPoints.size() > 1, "Too few track points");
    // A sorted copy is used only to get the time range of the whole ride.
    auto trackCopy = trackPoints;
    sortUniqueByTime(trackCopy);

    tws::Photos allPhotos = datasetPrefix.defined()
        ? tws::loadPhotos(datasetPrefix,
                          trackCopy.front().timestamp(),
                          trackCopy.back().timestamp())
        : tws::loadPhotos(poolHolder,
                          assignmentId,
                          sourceId,
                          trackCopy.front().timestamp(),
                          trackCopy.back().timestamp());

    double sourceTrackLengthMeters = 0;
    double sourceTrackLengthWOGaps = 0;
    double improvedTrackLengthMeters = 0;

    sfp::SensorEvents sensorEvents = datasetPrefix.defined()
        ? tws::loadSensors(datasetPrefix)
        : sfp::loadAssignmentSensors(mdsClient,
                                     poolHolder.pool(),
                                     maps::chrono::TimePoint(),
                                     std::chrono::system_clock::now(),
                                     sourceId,
                                     assignmentId);
    INFO() << "Loaded " << sensorEvents.gyroEvents.size() << " gyro sensors";
    INFO() << "Loaded " << sensorEvents.accEvents.size() << " acc sensors";
    REQUIRE(sensorEvents.gyroEvents.size() > 10, "Too few sensors events");
    REQUIRE(sensorEvents.accEvents.size() > 10, "Too few sensors events");

    sfp::Tracks tracks = sfp::splitIntoSeparateRides(trackPoints,
                                                     sensorEvents);

    std::unordered_set<maps::mrc::db::TId> featuresWithPosition;
    // By reference: each ride owns its track points and sensor events, and
    // copying them per iteration is needless.
    for (auto& track : tracks) {
        double curTrackLength = 0;
        try {
            // Segments of 50 m or longer are treated as gaps and not summed.
            for (size_t i = 1; i < track.trackPoints.size(); i++) {
                double l = maps::geolib3::fastGeoDistance(track.trackPoints[i - 1].geodeticPos(),
                                                          track.trackPoints[i].geodeticPos());
                if (l < 50) {
                    curTrackLength += l;
                }
            }
            sourceTrackLengthMeters += curTrackLength;
            if (curTrackLength < MIN_TRACK_METERS) {
                WARN() << "skipped short track with "
                       << curTrackLength << " meters length";
                continue;
            }
            sourceTrackLengthWOGaps += curTrackLength;

            sfp::SensorsFeaturePositioner featurePositioner(
                mrc::adapters::CompactGraphMatcherAdapter(staticGraphDir),
                track.trackPoints,
                track.sensorEvents);
            INFO() << "track times " << featurePositioner.trackStartTime().time_since_epoch().count()
                   << " " << featurePositioner.trackEndTime().time_since_epoch().count();

            // Length of the improved track, sampled once per second; jumps of
            // 50 m or more are ignored the same way as for the source track.
            double curImprovedTrackLength = 0;
            std::optional<maps::mrc::pos_improvment::ImprovedGpsEvent> prevEvent;
            for (auto t = featurePositioner.trackStartTime();
                 t < featurePositioner.trackEndTime(); t += std::chrono::seconds(1)) {
                auto curEvent = featurePositioner.getPositionByTime(t);
                if (curEvent && prevEvent) {
                    double l = maps::geolib3::fastGeoDistance(curEvent->geodeticPosition(),
                                                              prevEvent->geodeticPosition());
                    if (l < 50) {
                        improvedTrackLengthMeters += l;
                        curImprovedTrackLength += l;
                    }
                }
                prevEvent = curEvent;
            }

            INFO() << "curTrackLength = " << curTrackLength / 1000 << " km,"
                   << " curImprovedTrackLength = " << curImprovedTrackLength / 1000 << " km";

            tws::Photos photos = datasetPrefix.defined()
                ? tws::loadPhotos(datasetPrefix,
                                  track.trackPoints.front().timestamp(),
                                  track.trackPoints.back().timestamp())
                : tws::loadPhotos(poolHolder,
                                  assignmentId,
                                  sourceId,
                                  track.trackPoints.front().timestamp(),
                                  track.trackPoints.back().timestamp());
            for (const auto& photo : photos) {
                if (!featurePositioner.getPositionByTime(photo.timestamp)) {
                    WARN() << "can't calculate position for photo";
                } else {
                    featuresWithPosition.insert(photo.featureId);
                }
            }
            std::unordered_set<mrc::db::TId> featureIds;
            featureIds.reserve(photos.size());
            for (const auto& photo : photos) {
                featureIds.insert(photo.featureId);
            }

            tws::SignPhotos signPhotos = datasetPrefix.defined()
                ? tws::loadSigns(datasetPrefix,
                                 featureIds)
                : tws::loadSignBboxes(poolHolder, photos);
            sfp::Rays signRays;
            signRays.reserve(signPhotos.size());

            size_t rayIdGenerator = 0;
            for (const auto& signPhoto : signPhotos) {
                std::optional<mrc::pos_improvment::ImprovedGpsEvent> cameraPos
                    = featurePositioner.getPositionByTime(
                        signPhoto.featureTs);
                if (!cameraPos) {
                    WARN() << "no improved position for a ray";
                    continue;
                }
                signRays.push_back(sfp::constructRay(
                    signPhoto.bbox,
                    *cameraPos,
                    signPhoto.photoSize,
                    rayIdGenerator++,
                    signPhoto.featureId,
                    sourceId,
                    SIGN_SIDE_SIZE_METERS));
                signRays.back().objectTypeId = static_cast<size_t>(signPhoto.signType);
                // Only diagnosed: the ray is kept even if the position is NaN.
                if (std::isnan(cameraPos->odometerMercatorPosition().x())) {
                    WARN() << "sign ray coordinate is nan";
                }
            }

            auto resultSignGroups = sfp::calculateObjectsPositions(signRays);
            std::vector<tws::Signs> signGroups(resultSignGroups.size());
            tws::Signs aggregatedSigns;

            for (size_t i = 0; i < resultSignGroups.size(); i++) {
                // The group's sign type is taken from its first ray.
                signGroups[i].push_back(
                    tws::Sign{resultSignGroups[i].mercatorPos,
                              static_cast<maps::mrc::traffic_signs::TrafficSign>(resultSignGroups[i].rays[0].objectTypeId),
                              sfp::calculateObjectHeading(resultSignGroups[i])});
                tws::Sign aggregatedSign = signGroups[i][0];
                aggregatedSigns.push_back(aggregatedSign);
            }

            if (searchForOutdatedSigns) {
                tws::findOutdatedSigns(poolHolder, photos, aggregatedSigns);
            }

            geojsonSaver.addSignsAsRays(signRays);
            geojsonSaver.addSignGroupsAsPolylines(signGroups);

            // Ground truth is only available in dataset mode.
            if (!datasetPrefix.defined()) {
                continue;
            }

            std::unordered_set<int> signGtIds;
            for (const auto& signPhoto : signPhotos) {
                // Dereferencing an empty optional is UB; fail this track
                // instead (caught below as maps::Exception).
                REQUIRE(signPhoto.signGtId,
                        "sign photo without ground truth id in dataset");
                signGtIds.insert(*signPhoto.signGtId);
            }
            std::vector<tws::Sign> gtSigns = tws::loadGtSigns(datasetPrefix, signGtIds);
            geojsonSaver.addSignsAsPoints(gtSigns, "gt", "#00ff00");

            INFO() << "Calculated signs accuracy";
            tws::printAccuracy(
                tws::compareWithGroundTruth(
                    aggregatedSigns, gtSigns, tws::SignMatching::Nearest));

            // correspondingDbSigns[i] belongs to signPhotos[i].
            mrc::db::Signs correspondingDbSigns
                = loadСorrespondingSignsFromDb(signPhotos, poolHolder);
            tws::Signs dbSigns;
            // Keyed by the full db id type: an `int` set would truncate ids.
            std::unordered_set<mrc::db::TId> addedIds;
            for (size_t i = 0; i < signPhotos.size(); i++) {
                const auto& dbSign = correspondingDbSigns[i];
                // Several photos may see the same db sign; export it once.
                if (!addedIds.insert(dbSign.id()).second) {
                    continue;
                }
                dbSigns.push_back(tws::Sign{dbSign.mercatorPos(),
                                            dbSign.type(),
                                            dbSign.heading(),
                                            signPhotos[i].signGtId});
            }
            geojsonSaver.addSignsAsPoints(dbSigns, "db", "#ff0000");
            INFO() << "Database signs accuracy";
            tws::printAccuracy(
                tws::compareWithGroundTruth(
                    dbSigns, gtSigns, tws::SignMatching::ById));

        } catch (const maps::Exception& e) {
            INFO() << "Failed to handle track of length " << curTrackLength / 1000.0 << " km";
            INFO() << e;
            continue;
        }
    }
    if (outputGeojson.defined()) {
        geojsonSaver.save(outputGeojson);
    }

    INFO() << "source track length km: " << sourceTrackLengthMeters / 1000.0;
    INFO() << "source track length without gaps km: " << sourceTrackLengthWOGaps / 1000.0;
    INFO() << "improved track length km: " << improvedTrackLengthMeters / 1000.0;
    INFO() << "handled " << featuresWithPosition.size() << " photos of " << allPhotos.size();

    return EXIT_SUCCESS;
}
