
#include <maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/find_missing_object/lib/common.h>
#include <maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/find_missing_object/lib/util.h>
#include <maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/find_missing_object/lib/yt_serialization.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/catboost_visibility_predictor.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/superpoints_match.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/yt_serialization.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/types.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/image_box.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/keypoints.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/load.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>

#include <maps/wikimap/mapspro/services/mrc/libs/carsegm/include/carsegm.h>
#include <maps/wikimap/mapspro/services/mrc/libs/superpoint/include/superpoint.h>
#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/superglue_matcher.h>
#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/matcher.h>

#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <maps/libs/common/include/base64.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/introspection/include/hashing.h>
#include <maps/libs/introspection/include/comparison.h>
#include <maps/libs/introspection/include/tuple_for_each.h>
#include <maps/libs/json/include/value.h>
#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/geolib/include/point.h>
#include <maps/libs/geolib/include/line.h>
#include <maps/libs/stringutils/include/join.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <util/generic/size_literals.h>
#include <util/generic/yexception.h>
#include <util/generic/iterator_range.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <filesystem>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace maps::mrc::eye::pov {
using maps::introspection::operator==;
using maps::introspection::operator<;

// Base URL for downloading frame images from MDS storage.
// The mapper appends "<mds group id>/<mds path>" of a feature to it.
const std::string MDS_PREFIX = "http://storage-int.mds.yandex.net/get-maps_mrc/";

/// Returns the last component of a (YT or filesystem) path.
/// Files attached to a YT job are placed in the job sandbox under this name.
///
/// @param path  path such as "//home/dir/file.json"
/// @return      the file name component ("file.json"), or an empty string
///              when the path ends with a separator
std::string getFileName(const std::string& path)
{
    // `path` is already a std::string; no temporary copy is needed.
    return std::filesystem::path(path).filename().string();
}


/// Builds the spec of the visibility map-reduce operation.
///
/// The spec is based on baseGpuOperationSpec with the AdHoc pool type:
///  - each mapper job gets one GPU and 16 GB of memory;
///  - each reducer job gets 16 GB of memory;
///  - up to 30 failed jobs are tolerated.
NYT::TNode createPredictObjectVisibilityOperationSpec() {
    const std::string title = "Predict object visibility";

    NYT::TNode spec = yt::baseGpuOperationSpec(title, yt::PoolType::AdHoc);

    NYT::TNode mapperSpec = yt::baseGpuWorkerSpec();
    mapperSpec("gpu_limit", 1);
    mapperSpec("memory_limit", 16_GB);
    spec("mapper", mapperSpec);

    NYT::TNode reducerSpec = yt::baseGpuWorkerSpec();
    reducerSpec("memory_limit", 16_GB);
    spec("reducer", reducerSpec);

    spec("max_failed_job_count", 30);
    return spec;
}

/// Builds a db::Feature (dataset db::Dataset::Agents) from a YT row node.
///
/// @param node  map node with "source_id", "pos", "heading", "date",
///              "mds_group_id", "mds_path", "width", "height",
///              "orientation" and FEATURE_ID_COLUMN fields
/// @return      pair of (feature id, deserialized feature)
std::pair<db::TId, db::Feature> deserializeToFeature(const NYT::TNode& node)
{
    db::Feature feature = db::Feature(
            node["source_id"].AsString(),
            yt::deserialize<geolib3::Point2>(node["pos"]),
            yt::deserialize<geolib3::Heading>(node["heading"]),
            chrono::formatSqlDateTime(yt::deserialize<chrono::TimePoint>(node["date"])),
            mds::Key(node["mds_group_id"].AsString(), node["mds_path"].AsString()),
            db::Dataset::Agents)
        .setSize(node["width"].IntCast<int>(), node["height"].IntCast<int>())
        .setOrientation(yt::deserialize<common::ImageOrientation>(node["orientation"]));

    const db::TId featureId = node[FEATURE_ID_COLUMN].AsInt64();
    return {featureId, std::move(feature)};
}

/// Matches pairs of frames and evaluates fundamental matrix for them.
/// For each object on a frame evaluates corresponding epipolar lines
/// in other frame.
class PredictObjectVisibilityMapper : public maps::mrc::yt::Mapper
{
public:
    Y_SAVELOAD_JOB(objectFileName_, clusterFileName_);

    PredictObjectVisibilityMapper() = default;

    PredictObjectVisibilityMapper(
        const std::string& objectFileName,
        const std::string& clusterFileName)
        : objectFileName_(objectFileName)
        , clusterFileName_(clusterFileName)
    { }

    void Do(TReader* reader, TWriter* writer) override
    {
        readFeaturesPairs(reader);
        loadClusters();
        loadDetections();
        loadImages();
        makeCarsMasks();
        detectKeypoints();
        matchPairs(writer);
    }

private:
    void readFeaturesPairs(TReader* reader)
    {
        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode &inpRow = reader->GetRow();
            auto [feature1Id, feature1] = deserializeToFeature(inpRow["first"]);
            auto [feature2Id, feature2] = deserializeToFeature(inpRow["second"]);
            if (feature1Id > feature2Id) {
                continue;
            }
            featurePairs_.emplace_back(feature1Id, feature2Id);
            idToFeature_.emplace(feature1Id, std::move(feature1));
            idToFeature_.emplace(feature2Id, std::move(feature2));
        }
    }

    void loadDetections() {
        auto objects = json::Value::fromFile(objectFileName_);
        auto featuresObjects = objects["features_objects"];
        for (const auto& featureObjectsItem : featuresObjects) {
            db::TId featureId = featureObjectsItem[FEATURE_ID_COLUMN].as<db::TId>();

            if (!idToFeature_.count(featureId)) {
                continue;
            }

            auto orientationField = featureObjectsItem["orientation"];
            if (orientationField.exists()) {
                const auto orientation = common::ImageOrientation::fromExif(
                        orientationField.as<int>()
                    );
                const auto featureOrientation = idToFeature_.at(featureId).orientation();
                if (orientation != featureOrientation) {
                    WARN() << "The oirientation of detections "
                        << (int) orientation
                        << " differs from the orientation of the feature "
                        << (int) featureOrientation << " for feature " << featureId
                        << ". Ignoring these detections";
                        continue;
                }
            }

            for (auto detectionJson : featureObjectsItem["objects"]) {
                auto detection = Detection::fromJson(detectionJson);
                auto objectId = ObjectId{
                    .featureId = featureId,
                    .detectionId = detection.id};
                if (objectIdToClusterId.count(objectId)) {
                    featureIdToDetections_.emplace(featureId, std::move(detection));
                }
            }
        }
    }

    void loadClusters() {
        auto value = json::Value::fromFile(clusterFileName_);
        auto clusters = value["clusters"];
        for (const auto& cluster : clusters) {
            db::TId clusterId = cluster["cluster_id"].as<db::TId>();
            for (const auto& object : cluster["objects"]) {
                objectIdToClusterId.emplace(
                    ObjectId{.featureId = object[FEATURE_ID_COLUMN].as<db::TId>(),
                             .detectionId = object["object_id"].as<db::TId>()},
                    clusterId
                );
            }
        }
    }

    void loadImages() {
        for(const auto& [id, feature] : idToFeature_) {
            const std::string url = MDS_PREFIX + feature.mdsGroupId()
                + "/" + feature.mdsPath();
            featureIdToImageData_.emplace(
                id,
                ImageData(common::load(url),feature.orientation())
            );
        }
    }

    void makeCarsMasks() {
        constexpr int MASK_DILATE_KERNEL_SZ = 5;
        static const cv::Mat kernel = cv::Mat::ones(MASK_DILATE_KERNEL_SZ, MASK_DILATE_KERNEL_SZ, CV_8UC1);

        const carsegm::CarSegmentator carSegmentator;
        for (auto& [_, imageData] : featureIdToImageData_) {
            cv::Mat mask;
            cv::dilate(
                carSegmentator.segment(imageData.scaledDecodedImage()),
                mask, kernel);
            imageData.setScaledCarMask(mask);
        }
    }

    // ищем ключевые точки удаляем маску машинок и сжатую картинку, потому что теперь
    // они больше не нужны, достаточно только ключевых точек
    void detectKeypoints() {
        const superpoint::SuperpointDetector sptsDetector;
        for (auto& [_, imageData] : featureIdToImageData_) {
            auto kpts = sptsDetector.detect(imageData.scaledDecodedImage());
            auto filteredKpts = filterPointsByMask(kpts, imageData.scaledCarMask());
            imageData.setScaledKpts(std::move(filteredKpts));
            imageData.resetScaledCarMask();
            imageData.resetEncodedImage();
        }
    }

    MatchedKeypoints matchFeatures(
        const keypoints_matcher::SuperglueMatcher& matcher,
        db::TId id0, db::TId id1)
    {
        REQUIRE(featureIdToImageData_.count(id0), "id0 was not found in featureIdToImageData_");
        const ImageData& data0 = featureIdToImageData_.at(id0);
        REQUIRE(featureIdToImageData_.count(id1), "id1 was not found in featureIdToImageData_");
        const ImageData& data1 = featureIdToImageData_.at(id1);

        MatchedKeypoints match
            = getKeypointsMatch(matcher, data0.scaledKpts(), data0.scale(), data1.scaledKpts(), data1.scale());

        return match;
    }

    void matchPairs(TWriter* writer) {
        constexpr double SUPERGLUE_CONFIDENCE_THRESHOLD = 0.65;

        keypoints_matcher::SuperglueMatcher matcher(SUPERGLUE_CONFIDENCE_THRESHOLD);

        for (const auto& [id0, id1] : featurePairs_) {
            const auto& feature0 = idToFeature_.at(id0);
            const auto& feature1 = idToFeature_.at(id1);
            auto matchedKeypoints = matchFeatures(matcher, id0, id1);
            if (const auto matchInfo = getFrameMatch(matchedKeypoints.points0, matchedKeypoints.points1)){
                for (const auto& [_, detection] : MakeIteratorRange(featureIdToDetections_.equal_range(id0))) {
                    REQUIRE(featureIdToImageData_.count(id1), "matchPairs: id1 was not found in featureIdToImageData_");
                    const auto transformedBbox = transformByImageOrientation(
                        detection.bbox,
                        feature0.size(),
                        feature0.orientation());
                    const auto matchPlace = FrameMatchPlace::Second;
                    auto epilines = calcEpilines(matchPlace, matchInfo->fundMatrix, transformedBbox);

                    auto objectId = ObjectId{id0, detection.id};
                    REQUIRE(objectIdToClusterId.count(objectId), "matchPairs: id0 was not found in objectIdToClusterId "
                        << id0 << " " << detection.id);
                    writer->AddRow(
                        yt::serialize(
                            ObjectMatchInfo {
                                .clusterId = objectIdToClusterId.at(objectId),
                                .objectId = ObjectId{.featureId = id0, .detectionId = detection.id},
                                .predictedFeatureId = id1,
                                .matchPlace = matchPlace,
                                .framesMatch = matchInfo.value(),
                                .epilines = std::move(epilines)
                            }
                        )
                    );
                }
                for (const auto& [_, detection] : MakeIteratorRange(featureIdToDetections_.equal_range(id1))) {
                    REQUIRE(featureIdToImageData_.count(id0), "matchPairs: id0 was not found in featureIdToImageData_");
                    const auto transformedBbox = transformByImageOrientation(
                        detection.bbox,
                        feature1.size(),
                        feature1.orientation());
                    const auto matchPlace = FrameMatchPlace::First;
                    auto epilines = calcEpilines(matchPlace, matchInfo->fundMatrix, transformedBbox);
                    auto objectId = ObjectId{id1, detection.id};
                    REQUIRE(objectIdToClusterId.count(objectId), "matchPairs: id1 was not found in objectIdToClusterId "
                        << id1 << " " << detection.id);
                    writer->AddRow(
                        yt::serialize(
                            ObjectMatchInfo {
                                .clusterId = objectIdToClusterId.at(objectId),
                                .objectId = ObjectId{.featureId = id1, .detectionId = detection.id},
                                .predictedFeatureId = id0,
                                .matchPlace = matchPlace,
                                .framesMatch = matchInfo.value(),
                                .epilines = std::move(epilines)
                            }
                        )
                    );
                }
            }
        }
    }

    TString objectFileName_;
    TString clusterFileName_;

    std::unordered_map<db::TId, db::Feature> idToFeature_;
    std::vector<std::pair<db::TId, db::TId>> featurePairs_;
    std::unordered_multimap<db::TId, Detection> featureIdToDetections_;
    std::unordered_map<ObjectId, db::TId, introspection::Hasher> objectIdToClusterId;
    std::unordered_map<db::TId, ImageData> featureIdToImageData_;
};

REGISTER_MAPPER(PredictObjectVisibilityMapper);


/// Takes output of PredictObjectVisibilityMapper and also
/// a ground-truth information about clusters and computes
/// custom visibility factors for predicting if this object
/// should be visible on the other frame.
class ComputeCluterVisibilityFactorsReducer : public maps::mrc::yt::Reducer
{
public:
    Y_SAVELOAD_JOB(objectFileName_, clusterFileName_, featureFileName_);

    ComputeCluterVisibilityFactorsReducer() = default;

    ComputeCluterVisibilityFactorsReducer(
        const std::string& objectFileName,
        const std::string& clusterFileName,
        const std::string& featureFileName)
        : objectFileName_(objectFileName)
        , clusterFileName_(clusterFileName)
        , featureFileName_(featureFileName)
    { }

    void Do(maps::mrc::yt::Reader* reader, maps::mrc::yt::Writer* /*writer*/) override
    {
        constexpr int GOOD_POINTS_THRESHOLD = 20;
        for (; reader->IsValid(); reader->Next()) {
            auto info = yt::deserializeToObjectMatchInfo(reader->GetRow());
            if (info.framesMatch.goodPtsCnt < GOOD_POINTS_THRESHOLD) {
                continue;
            }
            auto clusterId = info.clusterId;
            clusterIds_.insert(clusterId);
            objectIds_.insert(info.objectId);
            featureIds_.insert(info.objectId.featureId);
            featureIds_.insert(info.predictedFeatureId);
            auto id = ClusterWithFeatureId{.clusterId = clusterId, .featureId = info.predictedFeatureId};
            clusterIdToObjectsMatchInfos_.emplace(id, std::move(info));
        }
    }

    void Finish(maps::mrc::yt::Writer* writer) override
    {
        constexpr size_t MATCHES_LIMIT = 20;
        const auto idToFeature = loadFeatures();
        const auto clusterFeaturePairs = loadClusterFeaturePairs();
        const auto objectIdToDetection = loadDetections();

        for (auto bucketBeginIt = clusterIdToObjectsMatchInfos_.begin();
            bucketBeginIt != clusterIdToObjectsMatchInfos_.end();)
        {
            auto bucketEndIt = clusterIdToObjectsMatchInfos_.upper_bound(bucketBeginIt->first);
            auto [clusterId, featureId] = bucketBeginIt->first;
            const bool isVisible =
                clusterFeaturePairs.count(std::make_pair(clusterId, featureId));
            const auto& feature = idToFeature.at(featureId);
            std::vector<DetectionMatchData> matches;
            std::vector<MatchEpilines> epilinesVec;
            for (const auto& [_, objMatchInfo]:
                 MakeIteratorRange(bucketBeginIt, bucketEndIt)) {
                const auto& featureWithDetection =
                    idToFeature.at(objMatchInfo.objectId.featureId);
                if (!belongsToSamePassage(feature, featureWithDetection)) {
                    const auto& detection =
                        objectIdToDetection.at(objMatchInfo.objectId);
                    const auto transformedBbox = transformByImageOrientation(
                        detection.bbox,
                        featureWithDetection.size(),
                        featureWithDetection.orientation());

                    matches.push_back(DetectionMatchData{
                        .matchPlace = objMatchInfo.matchPlace,
                        .match = objMatchInfo.framesMatch,
                        .detectionFrameMercPosition =
                            featureWithDetection.mercatorPos(),
                        .detectionFrameHeading = featureWithDetection.heading(),
                        .detectionBbox = transformedBbox,
                    });
                    epilinesVec.push_back(objMatchInfo.epilines);
                }
                if (matches.size() >= MATCHES_LIMIT) {
                    break;
                }
            }
            auto visibilityFactors = calcObjectVisibilityFactors(
                common::transformByImageOrientation(
                    feature.size(), feature.orientation()),
                feature.mercatorPos(),
                feature.heading(),
                matches,
                epilinesVec);

            writer->AddRow(yt::serialize(CatboostDatasetRow{
                .key = std::to_string(static_cast<int>(isVisible)),
                .value = visibilityFactorsToTsv(visibilityFactors)}));
            bucketBeginIt = bucketEndIt;
        }
    }

private:

    std::unordered_map<db::TId, db::Feature>
    loadFeatures() const
    {
        std::unordered_map<db::TId, db::Feature> result;
        TIFStream stream(featureFileName_);
        NYT::TTableReaderPtr<NYT::TNode> frameReader = NYT::CreateTableReader<NYT::TNode>(&stream);
        for (; frameReader->IsValid(); frameReader->Next()) {
            auto [featureId, feature] = deserializeToFeature(frameReader->GetRow());
            if (!featureIds_.count(featureId)) {
                continue;
            }
            result.emplace(featureId, std::move(feature));
        }
        return result;
    }

    std::set<std::pair<db::TId, db::TId>>
    loadClusterFeaturePairs() const
    {
        std::set<std::pair<db::TId, db::TId>> clusterFeaturePairs;
        auto value = json::Value::fromFile(clusterFileName_);
        auto clusters = value["clusters"];
        for (const auto& cluster : clusters) {
            db::TId clusterId = cluster["cluster_id"].as<db::TId>();
            for (const auto& object : cluster["objects"]) {
                clusterFeaturePairs.emplace(
                    clusterId,
                    object[FEATURE_ID_COLUMN].as<db::TId>()
                );
            }
        }
        return clusterFeaturePairs;
    }

    std::unordered_map<ObjectId, Detection, introspection::Hasher>
    loadDetections() {
        std::unordered_map<ObjectId, Detection, introspection::Hasher> result;
        auto objects = json::Value::fromFile(objectFileName_);
        auto featuresObjects = objects["features_objects"];
        for (const auto& featureObjectsItem : featuresObjects) {
            db::TId featureId = featureObjectsItem[FEATURE_ID_COLUMN].as<db::TId>();
            if (!featureIds_.count(featureId)) {
                continue;
            }
            for (auto detectionJson : featureObjectsItem["objects"]) {
                auto detection = Detection::fromJson(detectionJson);
                db::TId detectionId = detection.id;
                result.emplace(
                    ObjectId{.featureId = featureId, .detectionId = detectionId},
                    std::move(detection));
            }
        }
        return result;
    }

    NYT::TNode visibilityFactorsToNode(const std::vector<float>& factors)
    {
        auto node = NYT::TNode::CreateMap();
        for (size_t i = 0; i < factors.size(); ++i) {
            node(std::to_string(i), factors[i]);
        }
        return node;
    }

    std::string visibilityFactorsToTsv(const std::vector<float>& factors)
    {
        std::ostringstream os;
        csv::OutputStream csvWriter(os, '\t');

        for (auto value : factors) {
            csvWriter << value;
        }

        return os.str();
    }


    TString objectFileName_;
    TString clusterFileName_;
    TString featureFileName_;

    std::unordered_set<db::TId> featureIds_;
    std::unordered_set<db::TId> clusterIds_;
    std::unordered_set<ObjectId, introspection::Hasher> objectIds_;
    std::multimap<ClusterWithFeatureId, ObjectMatchInfo> clusterIdToObjectsMatchInfos_;

};

REGISTER_REDUCER(ComputeCluterVisibilityFactorsReducer);


// Writes columns description table in format expected by catboost
// https://docs.yandex-team.ru/catboost/concepts/input-data_column-descfile#yt-tables
void writeColumnsDescriptions(NYT::IClientPtr& ytClient, const std::string& path)
{
    size_t columnIdx = 0;
    auto writer = ytClient->CreateTableWriter<NYT::TNode>(TString(path));
    writer->AddRow(yt::serialize(
        CatboostDatasetRow{
            .key = std::to_string(columnIdx++),
            .value = "Label"
        }
    ));

    const auto emptyFactors = calcObjectVisibilityFactors({}, {}, {}, {}, {});
    for (size_t i = 0; i < emptyFactors.size(); ++i) {
        writer->AddRow(yt::serialize(CatboostDatasetRow{
            .key = std::to_string(columnIdx++),
            .value = toTsv(std::make_tuple("Num", std::to_string(i)))}));
    }
}

} // namespace maps::mrc::eye::pov


// Runs the visibility-factors map-reduce on the "hahn" YT cluster and writes
// the factors table plus its catboost columns description table.
int main(int argc, char* argv[]) try {
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Calculates visibility factors for objects");
    auto featureTableYtPath = parser.string("feature-table")
        .required()
        .help("Path to table with features");

    auto pairTableYtPath = parser.string("pair-table")
        .required()
        .help("Path to table with features pairs");

    auto clusterFileYtPath = parser.string("cluster-file")
        .required()
        .help("Input yt json file with clusters");

    auto objectFileYtPath = parser.string("object-file")
        .required()
        .help("Input yt json file with detections");

    auto outputFactorsTableYtPath = parser.string("output-factors-table")
        .required()
        .help("Output table with result");

    auto outputColumnsDescriptionTableYtPath = parser.string("output-columns-description-table")
        .required()
        .help("Output table with catboost columns description");

    parser.parse(argc, argv);

    auto client = NYT::CreateClient("hahn");

    namespace pov = maps::mrc::eye::pov;

    // About PAIRS_IN_JOB input pairs per map job, but at least one job.
    // Computed in size_t so large row counts are not truncated by an int cast.
    constexpr size_t PAIRS_IN_JOB = 500;
    auto rowsCount = pov::getRowsCount(client, pairTableYtPath);
    const size_t jobsCount =
        std::max<size_t>(1, static_cast<size_t>(rowsCount) / PAIRS_IN_JOB);

    pov::writeColumnsDescriptions(client, outputColumnsDescriptionTableYtPath);

    client->MapReduce(
        NYT::TMapReduceOperationSpec()
        .AddInput<NYT::TNode>(TString(pairTableYtPath))
        .MapperSpec(NYT::TUserJobSpec()
            .AddFile(NYT::TRichYPath(TString(clusterFileYtPath)))
            .AddFile(NYT::TRichYPath(TString(objectFileYtPath)))
        )
        .ReducerSpec(NYT::TUserJobSpec()
            .AddFile(NYT::TRichYPath(TString(clusterFileYtPath)))
            .AddFile(NYT::TRichYPath(TString(objectFileYtPath)))
            .AddFile(NYT::TRichYPath(TString(featureTableYtPath)).Format("yson"))
        )
        .AddOutput<NYT::TNode>(TString(outputFactorsTableYtPath))
        .MapJobCount(jobsCount)
        .PartitionCount(10)
        .SortBy({pov::CLUSTER_ID_COLUMN})
        .ReduceBy({pov::CLUSTER_ID_COLUMN}),
        new pov::PredictObjectVisibilityMapper(
            pov::getFileName(objectFileYtPath),
            pov::getFileName(clusterFileYtPath)
        ),
        new pov::ComputeCluterVisibilityFactorsReducer(
            pov::getFileName(objectFileYtPath),
            pov::getFileName(clusterFileYtPath),
            pov::getFileName(featureTableYtPath)
        ),
        NYT::TOperationOptions()
            .Spec(pov::createPredictObjectVisibilityOperationSpec())
    );

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    ERROR() << e;
    return EXIT_FAILURE;
}
catch (const yexception& e) {
    FATAL() << e.what();

    if (e.BackTrace()) {
        FATAL() << e.BackTrace()->PrintToString();
    }

    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    ERROR() << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    ERROR() << "unknown error";
    return EXIT_FAILURE;
}
