#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/secure_config.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/superpoints_match.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/match.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/store.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/yt_serialization.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/types.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/carsegm/include/carsegm.h>
#include <maps/wikimap/mapspro/services/mrc/libs/superpoint/include/superpoint.h>
#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/superglue_matcher.h>
#include <maps/wikimap/mapspro/services/mrc/libs/superpoint/include/superpoint.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <maps/libs/http/include/http.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/common/include/exception.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <util/generic/size_literals.h>

#include <opencv2/opencv.hpp>

#include <cmath>
#include <filesystem>
#include <string>
#include <utility>

namespace maps::mrc::eye {

namespace {

// Columns of the YT table that describes frames (one row per frame).
static const TString COLUMN_NAME_FRAME_ID    = "frame_id";
static const TString COLUMN_NAME_URL_CONTEXT = "url_context";
static const TString COLUMN_NAME_ORIENTATION = "orientation";

// Columns of the YT tables that describe pairs of frames: the input pairs,
// the matched keypoints of each pair and the final frame match data.
static const TString COLUMN_NAME_FRAME_ID0        = "frame_id0";
static const TString COLUMN_NAME_FRAME_ID1        = "frame_id1";
static const TString COLUMN_NAME_KEYPOINTS0       = "keypoints0";
static const TString COLUMN_NAME_KEYPOINTS1       = "keypoints1";
static const TString COLUMN_NAME_FRAME_MATCH_DATA = "frame_match_data";

// Upper bound on the number of keypoints kept per frame by filterPointsByMask().
constexpr int MAX_PTS_COUNT = 2500;

// Builds the YT operation spec for the keypoint detection/matching map operation.
//
// ytConcurrency - value written to resource_limits/user_slots
// useGpu        - selects the GPU operation/worker base specs (with gpu_limit)
//                 instead of the CPU ones (with cpu_limit)
// ytPoolType    - forwarded to the base operation spec helper
NYT::TNode createSuperpointOperationSpec(size_t ytConcurrency, bool useGpu, std::optional<yt::PoolType> ytPoolType) {
    static const std::string TITLE = "Superpoints Matcher";

    NYT::TNode spec;
    NYT::TNode mapperSpec;
    if (useGpu) {
        spec = yt::baseGpuOperationSpec(TITLE, ytPoolType);
        mapperSpec = yt::baseGpuWorkerSpec();
        mapperSpec("gpu_limit", 1);
    } else {
        spec = yt::baseCpuOperationSpec(TITLE, ytPoolType);
        mapperSpec = NYT::TNode::CreateMap();
        mapperSpec("cpu_limit", 1);
    }
    // The memory limit is the same for both kinds of workers.
    mapperSpec("memory_limit", 16_GB);

    spec("mapper", mapperSpec);
    spec("resource_limits", NYT::TNode()("user_slots", ytConcurrency));
    spec("max_failed_job_count", 30);
    return spec;
}

// Builds the YT operation spec for the map operation that estimates the
// fundamental matrix of every frame pair. Each mapper gets 1 CPU and 1 GB
// of memory in both branches.
//
// useGpu     - selects which worker base spec is used for the "mapper" section
// ytPoolType - forwarded to the base operation spec helper
NYT::TNode createFundMatOperationSpec(bool useGpu, std::optional<yt::PoolType> ytPoolType) {
    static const std::string TITLE = "[Superpoint matcher] Find fund matrix";

    if (useGpu) {
        // NOTE(review): this branch combines the CPU operation spec with the
        // GPU worker spec and sets no gpu_limit; presumably the operation is
        // CPU-only but must still fit GPU workers - confirm this is intended.
        return yt::baseCpuOperationSpec(TITLE, ytPoolType)
            ("mapper", yt::baseGpuWorkerSpec()
                ("cpu_limit", 1)
                ("memory_limit", 1_GB)
            );
    } else {
        return yt::baseCpuOperationSpec(TITLE, ytPoolType)
            ("mapper", NYT::TNode::CreateMap()
                ("cpu_limit", 1)
                ("memory_limit", 1_GB)
            );
    }
}

// Data needed to download a frame image and rotate it to its upright position.
struct FrameImageInfo {
    // Serialized description of where to load the image from.
    json::Value urlContext{json::null};
    // EXIF orientation code of the image. Zero-initialized so that a
    // default-constructed object never holds an indeterminate value;
    // real values are always assigned by the code that fills the struct.
    int16_t orientation{0};
};

std::map<db::TId, FrameImageInfo> collectFrameInfos(
    const DetectionStore& store,
    const std::vector<std::pair<db::TId, db::TId>>& frameIdPairs)
{
    std::map<db::TId, FrameImageInfo> result;
    for (size_t i = 0; i < frameIdPairs.size(); i++) {
        const std::pair<db::TId, db::TId>& framePair = frameIdPairs[i];
        ASSERT(framePair.first != framePair.second);
        if (result.find(framePair.first) == result.end()) {
            const db::eye::Frame& frame = store.frameById(framePair.first);
            result[framePair.first] = {
                .urlContext = frame.urlContext().json(),
                .orientation = (int16_t)frame.orientation()
            };
        }
        if (result.find(framePair.second) == result.end()) {
            const db::eye::Frame& frame = store.frameById(framePair.second);
            result[framePair.second] = {
                .urlContext = frame.urlContext().json(),
                .orientation = (int16_t)frame.orientation()
            };
        }
    }
    return result;
};

// Writes one row per frame (frame_id, url_context, orientation) into the YT
// table frameTable and finishes the writer.
void uploadFrameImageInfo(
    NYT::IClientBasePtr client,
    const TString& frameTable,
    const std::map<db::TId, FrameImageInfo>& frameImageInfoById)
{
    NYT::TTableWriterPtr<NYT::TNode> writer = client->CreateTableWriter<NYT::TNode>(frameTable);
    for (const auto& [frameId, info] : frameImageInfoById) {
        NYT::TNode row;
        row(COLUMN_NAME_FRAME_ID, NYT::TNode(frameId));
        row(COLUMN_NAME_URL_CONTEXT, yt::serialize(info.urlContext));
        row(COLUMN_NAME_ORIENTATION, NYT::TNode(info.orientation));
        writer->AddRow(row);
    }
    writer->Finish();
}

// Writes one row per frame pair (frame_id0, frame_id1) into the YT table
// framePairsTable and finishes the writer.
void uploadFrameIdPairs(
    NYT::IClientBasePtr client,
    const TString& framePairsTable,
    const std::vector<std::pair<db::TId, db::TId>>& frameIdPairs)
{
    NYT::TTableWriterPtr<NYT::TNode> pairsWriter = client->CreateTableWriter<NYT::TNode>(framePairsTable);
    for (const auto& idPair : frameIdPairs) {
        NYT::TNode row;
        row(COLUMN_NAME_FRAME_ID0, NYT::TNode(idPair.first));
        row(COLUMN_NAME_FRAME_ID1, NYT::TNode(idPair.second));
        pairsWriter->AddRow(row);
    }
    pairsWriter->Finish();
}

// YT mapper that, for every input pair of frame ids, downloads both frames,
// masks out cars, detects keypoints and matches them between the two frames.
//
// Input rows:  frame_id0, frame_id1.
// Output rows: frame_id0, frame_id1, keypoints0, keypoints1.
// Frame descriptions (url context, orientation) are read from a local file
// whose name is frameTable_.
class TFrameKeypointMatcherMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>> {
public:
    Y_SAVELOAD_JOB(frameTable_, frameLoader_);

    TFrameKeypointMatcherMapper() = default;

    TFrameKeypointMatcherMapper(const std::string& frameTable, const FrameLoader& frameLoader)
        : frameTable_(frameTable.c_str())
        , frameLoader_(frameLoader) {}

    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        loadPairs(reader);
        loadFramesInfo();
        downloadImages();
        makeCarsMasks();
        detectKeypoints();
        matchPairs();
        writePairs(writer);
    }
private:
    /*
        The image is needed twice: to compute the car mask and to detect
        keypoints. It is kept in encoded (compressed) form, which on the one
        hand avoids downloading it twice and on the other hand takes
        considerably less memory than the decoded image.

        We have to keep the image because car masks are computed for all
        images first and only then keypoints are detected (again for all
        images). Downloading, decoding and computing both the mask and the
        keypoints image by image is not a good option, because then two
        different networks would be alternately loaded and run on the GPU
        for every image.
    */
    struct FrameData {
        FrameImageInfo imageInfo;
        std::string encodedImage;
        double scale = 1.;
        cv::Mat scaledCarMask;
        common::Keypoints scaledKpts;
        // Decodes the stored image and rotates it according to its orientation.
        cv::Mat decodedImage() const {
            return common::transformByImageOrientation(
                common::decodeImage(encodedImage),
                common::ImageOrientation::fromExif(imageInfo.orientation));
        }
    };

    TString frameTable_;
    FrameLoader frameLoader_;

    std::set<db::TId> usedFrameId_;
    std::vector<std::pair<db::TId, db::TId>> frameIdPairs_;
    std::map<db::TId, FrameData> frameDataById_;
    std::vector<MatchedKeypoints> matchedKeypoints_;

    // Loads the pairs of frame ids to be matched into frameIdPairs_ and at
    // the same time fills the set usedFrameId_.
    void loadPairs(NYT::TTableReader<NYT::TNode>* reader) {
        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode &inpRow = reader->GetRow();
            db::TId fid0 = inpRow[COLUMN_NAME_FRAME_ID0].AsInt64();
            db::TId fid1 = inpRow[COLUMN_NAME_FRAME_ID1].AsInt64();
            usedFrameId_.insert(fid0);
            usedFrameId_.insert(fid1);
            frameIdPairs_.emplace_back(fid0, fid1);
        }
    }
    // Loads frame descriptions (url context, orientation) from the file
    // frameTable_ into frameDataById_, skipping frames that are not used by
    // any pair of this job.
    void loadFramesInfo() {
        TIFStream stream(frameTable_);
        NYT::TTableReaderPtr<NYT::TNode> frameReader = NYT::CreateTableReader<NYT::TNode>(&stream);
        for (; frameReader->IsValid(); frameReader->Next()) {
            const NYT::TNode &inpRow = frameReader->GetRow();
            const db::TId fid = inpRow[COLUMN_NAME_FRAME_ID].AsInt64();
            if (0 == usedFrameId_.count(fid)) {
                continue;
            }
            // Fill the map entry in place instead of building a temporary
            // FrameData for every row and copying it into the map.
            FrameImageInfo& imageInfo = frameDataById_[fid].imageInfo;
            imageInfo.urlContext = yt::deserialize<json::Value>(inpRow[COLUMN_NAME_URL_CONTEXT]);
            imageInfo.orientation = static_cast<int16_t>(inpRow[COLUMN_NAME_ORIENTATION].AsInt64());
        }
    }
    // Downloads the encoded image of every frame by its url context.
    void downloadImages() {
        for (auto& entry : frameDataById_) {
            FrameData& frame = entry.second;
            frame.encodedImage = frameLoader_.loadRaw(frame.imageInfo.urlContext);
        }
    }

    // Returns the factor that shrinks the larger image side to
    // DESIRED_MAX_SIZE, or 1 when the image is already small enough.
    double getImageScale(const cv::Mat& image) const {
        static const int DESIRED_MAX_SIZE = 1024;

        int maxSize = std::max(image.cols, image.rows);
        if (maxSize > DESIRED_MAX_SIZE) {
            return DESIRED_MAX_SIZE / static_cast<double>(maxSize);
        } else {
            return 1.;
        }
    }

    // Computes the (dilated) car mask of every frame on the scaled image.
    void makeCarsMasks() {
        constexpr int MASK_DILATE_KERNEL_SZ = 5;
        static const cv::Mat kernel = cv::Mat::ones(MASK_DILATE_KERNEL_SZ, MASK_DILATE_KERNEL_SZ, CV_8UC1);

        const carsegm::CarSegmentator carSegmentator;
        for (auto& entry : frameDataById_) {
            FrameData& frame = entry.second;
            cv::Mat image = frame.decodedImage();

            frame.scale = getImageScale(image);
            if (frame.scale != 1.0) {
                cv::resize(image, image, cv::Size(), frame.scale, frame.scale);
            }

            cv::dilate(carSegmentator.segment(image), frame.scaledCarMask, kernel);
        }
    }
    // Detects keypoints, then drops the car mask and the encoded image:
    // from now on only the keypoints are needed.
    void detectKeypoints() {
        const superpoint::SuperpointDetector sptsDetector;
        for (auto& entry : frameDataById_) {
            FrameData& frame = entry.second;
            cv::Mat image = frame.decodedImage();

            if (frame.scale != 1.0) {
                cv::resize(image, image, cv::Size(), frame.scale, frame.scale);
            }

            frame.scaledKpts = filterPointsByMask(sptsDetector.detect(image), frame.scaledCarMask);
            frame.scaledCarMask.release();
            // clear() would keep the allocated capacity; swapping with an
            // empty temporary actually returns the image buffer.
            std::string().swap(frame.encodedImage);
        }
    }
    // Matches keypoints of every frame pair and stores the result in
    // matchedKeypoints_.
    void matchPairs() {
        constexpr double SUPERGLUE_CONFIDENCE_THRESHOLD = 0.65;
        keypoints_matcher::SuperglueMatcher matcher(SUPERGLUE_CONFIDENCE_THRESHOLD);

        matchedKeypoints_.reserve(matchedKeypoints_.size() + frameIdPairs_.size());
        for (const auto& idPair : frameIdPairs_) {
            const db::TId id0 = idPair.first;
            const db::TId id1 = idPair.second;
            // operator[] is kept on purpose: a frame that is missing from the
            // frame table is treated as a frame without keypoints.
            const FrameData& data0 = frameDataById_[id0];
            const FrameData& data1 = frameDataById_[id1];

            MatchedKeypoints match
                = getKeypointsMatch(matcher, data0.scaledKpts, data0.scale, data1.scaledKpts, data1.scale);

            match.frameId0 = id0;
            match.frameId1 = id1;
            // Move instead of copying the two point vectors.
            matchedKeypoints_.push_back(std::move(match));
        }
    }

    // Writes one output row per matched pair.
    void writePairs(NYT::TTableWriter<NYT::TNode>* writer) {
        for (const MatchedKeypoints& match : matchedKeypoints_) {
            writer->AddRow(
                NYT::TNode()
                (COLUMN_NAME_FRAME_ID0, match.frameId0)
                (COLUMN_NAME_FRAME_ID1, match.frameId1)
                (COLUMN_NAME_KEYPOINTS0, yt::serialize(match.points0))
                (COLUMN_NAME_KEYPOINTS1, yt::serialize(match.points1))
            );
        }
    }
};

REGISTER_MAPPER(TFrameKeypointMatcherMapper);

// YT mapper that turns matched keypoints of a frame pair into frame match
// data (see getFrameMatch). Pairs for which getFrameMatch returns no value
// produce no output row.
//
// Input rows:  frame_id0, frame_id1, keypoints0, keypoints1.
// Output rows: frame_id0, frame_id1, frame_match_data.
class TFrameFundMatMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>> {
public:
    TFrameFundMatMapper() = default;

    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        loadPairs(reader);
        matchPairs();
        writePairs(writer);
    }
private:
    std::vector<MatchedKeypoints> matchedKeypoints_;
    MatchedFramesPairs matchedFrames_;

    // Reads all input rows into matchedKeypoints_.
    void loadPairs(NYT::TTableReader<NYT::TNode>* reader) {
        while (reader->IsValid()) {
            const NYT::TNode& row = reader->GetRow();

            matchedKeypoints_.push_back(MatchedKeypoints{
                row[COLUMN_NAME_FRAME_ID0].AsInt64(),
                row[COLUMN_NAME_FRAME_ID1].AsInt64(),
                yt::deserialize<std::vector<cv::Point2f>>(row[COLUMN_NAME_KEYPOINTS0]),
                yt::deserialize<std::vector<cv::Point2f>>(row[COLUMN_NAME_KEYPOINTS1])
            });
            reader->Next();
        }
    }

    // Computes frame match data for every loaded pair. The point vectors are
    // moved out of matchedKeypoints_, only the frame ids are used afterwards.
    void matchPairs() {
        for (MatchedKeypoints& keypoints : matchedKeypoints_) {
            std::optional<FramesMatchData> frameMatch = getFrameMatch(
                std::move(keypoints.points0),
                std::move(keypoints.points1)
            );

            if (frameMatch.has_value()) {
                matchedFrames_.push_back(
                    MatchedFramesPair {
                        .id0 = keypoints.frameId0,
                        .id1 = keypoints.frameId1,
                        .match = std::move(*frameMatch)
                    }
                );
            }
        }
    }

    // Writes one output row per successfully matched pair.
    void writePairs(NYT::TTableWriter<NYT::TNode>* writer) {
        const size_t matchedCount = matchedFrames_.size();
        for (size_t idx = 0; idx != matchedCount; ++idx) {
            const MatchedFramesPair& framesPair = matchedFrames_[idx];
            NYT::TNode row;
            row(COLUMN_NAME_FRAME_ID0, framesPair.id0);
            row(COLUMN_NAME_FRAME_ID1, framesPair.id1);
            row(COLUMN_NAME_FRAME_MATCH_DATA, yt::serialize(framesPair.match));
            writer->AddRow(row);
        }
    }
};

REGISTER_MAPPER(TFrameFundMatMapper);

// Reads all rows (frame_id0, frame_id1, frame_match_data) from the reader
// and returns them as matched frame pairs, in reading order.
MatchedFramesPairs readFrameMatches(const NYT::TTableReaderPtr<NYT::TNode>& reader) {
    MatchedFramesPairs result;
    while (reader->IsValid()) {
        const NYT::TNode& row = reader->GetRow();
        MatchedFramesPair framesPair{
            .id0 = row[COLUMN_NAME_FRAME_ID0].AsInt64(),
            .id1 = row[COLUMN_NAME_FRAME_ID1].AsInt64(),
            .match = yt::deserialize<FramesMatchData>(
                row[COLUMN_NAME_FRAME_MATCH_DATA])};
        result.push_back(std::move(framesPair));
        reader->Next();
    }
    return result;
}

/*
    Frame similarity function.
    Two paths to YT tables must be passed as input:
        1. frameTable - table with information about the images, format
            ------------------------------------------
            | frame_id   | url_context | orientation |
            ------------------------------------------
        2. framePairsTable - table of frame pairs whose similarity has to be
        estimated
            -----------------------------
            | frame_id0   | frame_id1   |
            -----------------------------
    The result is a set of pairs <frame_id0, frame_id1> with their match data.

    Other parameters:
        frameLoader   - copied into the keypoint mapper, which uses it to
                        download images
        pairsCount    - number of rows in framePairsTable, used to choose the
                        number of jobs
        useGpu, ytPoolType, ytConcurrency - forwarded to the operation specs

    Requires SecureConfig to be initialized; it is passed to the first
    operation as its secure vault.
*/
MatchedFramesPairs matchFrames(
    const NYT::IClientBasePtr& client,
    const TString& frameTable,
    const TString& framePairsTable,
    const FrameLoader& frameLoader,
    size_t pairsCount,
    bool useGpu,
    std::optional<yt::PoolType> ytPoolType,
    size_t ytConcurrency)
{
    constexpr int PAIRS_IN_JOB = 500;

    // At least one job, otherwise roughly PAIRS_IN_JOB pairs per job.
    const size_t jobsCount = std::max(1, (int)pairsCount / PAIRS_IN_JOB);
    const NYT::TTempTable matchTable(client);

    REQUIRE(SecureConfig::isInitialized(), "Secure config was not initialized");

    // Stage 1: detect and match keypoints for every frame pair. The frame
    // table is attached to the jobs as a local yson file; the mapper opens it
    // by the file name part of the table path.
    const NYT::TTempTable keypointsTable(client);
    client->Map(
        NYT::TMapOperationSpec()
        .AddInput<NYT::TNode>(framePairsTable)
        .MapperSpec(NYT::TUserJobSpec()
            .AddFile(NYT::TRichYPath(frameTable).Format("yson"))
        )
        .AddOutput<NYT::TNode>(keypointsTable.Name())
        .JobCount(jobsCount),
        new TFrameKeypointMatcherMapper(
            std::filesystem::path(std::string(frameTable)).filename().string(),
            frameLoader),
        NYT::TOperationOptions()
            .Spec(createSuperpointOperationSpec(ytConcurrency, useGpu, ytPoolType))
            .SecureVault(SecureConfig::instance())
    );

    // Stage 2: compute frame match data from the matched keypoints.
    client->Map(
        NYT::TMapOperationSpec()
        .AddInput<NYT::TNode>(keypointsTable.Name())
        .AddOutput<NYT::TNode>(matchTable.Name())
        .JobCount(jobsCount),
        new TFrameFundMatMapper(),
        NYT::TOperationOptions()
            .Spec(createFundMatOperationSpec(useGpu, ytPoolType))
    );

    // Read the result back before the temporary tables go out of scope.
    NYT::TTableReaderPtr<NYT::TNode> reader = client->CreateTableReader<NYT::TNode>(matchTable.Name());
    return readFrameMatches(reader);
}
} //namespace


// Returns the keypoints of kpts that do not fall on non-zero mask pixels,
// ordered by descending score (stable for equal scores) and limited to
// MAX_PTS_COUNT items.
//
// kpts - keypoints; scores, coords rows and descriptors rows must have the
//        same count, otherwise REQUIRE fails
// mask - single-channel 8-bit mask; a point whose (floored) coordinates lie
//        outside of the mask is dropped as well
//
// When kpts is empty it is returned unchanged. When no point survives, the
// result has released (empty) coords and descriptors.
common::Keypoints filterPointsByMask(const common::Keypoints& kpts, const cv::Mat& mask)
{
    const int ptsCnt = kpts.scores.size();
    REQUIRE(ptsCnt == kpts.coords.rows, "Different amount of items in scores and coordinates of keypoints structure");
    REQUIRE(ptsCnt == kpts.descriptors.rows, "Different amount of items in scores and descriptors of keypoints structure");
    if (0 == ptsCnt) {
        return kpts;
    }

    const int descriptorSize = kpts.descriptors.cols;
    const int maxOutCnt = std::min(ptsCnt, MAX_PTS_COUNT);

    common::Keypoints results;
    results.imageWidth = kpts.imageWidth;
    results.imageHeight = kpts.imageHeight;
    results.coords.create(maxOutCnt, 1, CV_32FC2);
    results.descriptors.create(maxOutCnt, descriptorSize, CV_32FC1);

    // Indices of the points ordered by descending score.
    std::vector<int> idxSorted(ptsCnt);
    std::iota(idxSorted.begin(), idxSorted.end(), 0);
    std::stable_sort(idxSorted.begin(), idxSorted.end(),
        [&kpts](int i1, int i2) {return kpts.scores[i1] > kpts.scores[i2];});

    const int rows = mask.rows;
    const int cols = mask.cols;
    int outIdx = 0;
    for (int i = 0; (i < ptsCnt) && (outIdx < maxOutCnt); i++) {
        const int idx = idxSorted[i];
        const cv::Vec2f& coords2f = kpts.coords.at<cv::Vec2f>(idx, 0);

        const int col = (int)floor(coords2f[0]);
        const int row = (int)floor(coords2f[1]);
        // Negative coordinates must be rejected too, otherwise mask.at()
        // below would read outside of the mask.
        if (col < 0 || row < 0 || col >= cols || row >= rows) {
            continue;
        }
        if (mask.at<uint8_t>(row, col) != 0) {
            // the mask is non-zero on cars, such a point must be dropped
            continue;
        }
        results.scores.push_back(kpts.scores[idx]);
        results.coords.at<cv::Vec2f>(outIdx, 0) = coords2f;
        std::copy_n(kpts.descriptors.ptr<float>(idx), descriptorSize, results.descriptors.ptr<float>(outIdx));
        outIdx++;
    }
    if (0 == outIdx) {
        results.coords.release();
        results.descriptors.release();
    } else {
        results.coords = results.coords.rowRange(0, outIdx);
        results.descriptors = results.descriptors.rowRange(0, outIdx);
    }
    return results;
}

// Estimates the fundamental matrix between two frames from matched points
// (RANSAC) and returns it together with inlier statistics and the convex
// hulls of the inlier points on both frames.
//
// points0, points1 - matched points, points0[i] corresponds to points1[i];
//                    sizes must be equal, otherwise REQUIRE fails
//
// Returns std::nullopt when there are not more than FIND_FUND_MAT_PTS_CNT_MIN
// point pairs or fewer than FUND_MAT_GOOD_PTS_CNT_MIN inliers.
std::optional<FramesMatchData> getFrameMatch(
    std::vector<cv::Point2f> points0,
    std::vector<cv::Point2f> points1)
{
    constexpr size_t FIND_FUND_MAT_PTS_CNT_MIN = 7;
    constexpr int FUND_MAT_GOOD_PTS_CNT_MIN = 15;
    constexpr double RANSAC_REPROJ_THRESHOLD = 3.;
    constexpr double RANSAC_CONFIDENCE = 0.99;

    REQUIRE(points0.size() == points1.size(), "Different count keypoints on frames");

    const size_t ptsPairsCnt = points0.size();
    if (ptsPairsCnt <= FIND_FUND_MAT_PTS_CNT_MIN) {
        return std::nullopt;
    }

    std::vector<uint8_t> goodPtsMask;
    FramesMatchData match;
    match.fundMatrix = cv::findFundamentalMat(points0, points1, cv::FM_RANSAC, RANSAC_REPROJ_THRESHOLD, RANSAC_CONFIDENCE, goodPtsMask);
    match.goodPtsCnt = cv::sum(goodPtsMask)[0];
    match.ptsCnt0 = points0.size();
    match.ptsCnt1 = points1.size();
    if (match.goodPtsCnt < FUND_MAT_GOOD_PTS_CNT_MIN) {
        return std::nullopt;
    }

    // Keep only the inliers. A single compaction pass preserves the point
    // order and is linear, unlike erasing outliers one by one.
    size_t keptCnt = 0;
    for (size_t i = 0; i < ptsPairsCnt; i++) {
        if (goodPtsMask[i] == 0) {
            continue;
        }
        points0[keptCnt] = points0[i];
        points1[keptCnt] = points1[i];
        keptCnt++;
    }
    points0.resize(keptCnt);
    points1.resize(keptCnt);

    cv::convexHull(points0, match.hull0, false, true);
    cv::convexHull(points1, match.hull1, false, true);
    return match;
}


// Stores the YT client, the frame loader and the YT operation settings that
// makeMatches() later forwards to matchFrames().
FrameSuperpointsMatcher::FrameSuperpointsMatcher(
    NYT::IClientBasePtr client,
    const FrameLoader& frameLoader,
    bool useGpu,
    std::optional<yt::PoolType> ytPoolType,
    size_t ytConcurrency)
    // client is taken by value, so move it into the member instead of
    // making one more copy of the pointer.
    : client_(std::move(client))
    , frameLoader_(frameLoader)
    , useGpu_(useGpu)
    , ytPoolType_(ytPoolType)
    , ytConcurrency_(ytConcurrency) {}

// Matches the given pairs of frames on YT.
//
// Uploads the frame descriptions and the pairs into two temporary YT tables
// and runs matchFrames() on them. Returns an empty result without touching
// YT when frameIdPairs is empty.
MatchedFramesPairs FrameSuperpointsMatcher::makeMatches(
    const DetectionStore& store,
    const std::vector<std::pair<db::TId, db::TId>>& frameIdPairs) const
{
    if (frameIdPairs.empty()) {
        return {};
    }

    const std::map<db::TId, FrameImageInfo> infoById = collectFrameInfos(store, frameIdPairs);

    // Frame descriptions.
    const NYT::TTempTable framesTmpTable(client_);
    uploadFrameImageInfo(client_, framesTmpTable.Name(), infoById);

    // Pairs of frame ids.
    const NYT::TTempTable pairsTmpTable(client_);
    uploadFrameIdPairs(client_, pairsTmpTable.Name(), frameIdPairs);

    return matchFrames(
        client_,
        framesTmpTable.Name(),
        pairsTmpTable.Name(),
        frameLoader_,
        frameIdPairs.size(),
        useGpu_,
        ytPoolType_,
        ytConcurrency_);
}

namespace {

struct DetectionInfo {
    DetectionInfo(db::TId _id, const cv::Rect& _bbox)
        : id(_id)
        , bbox(_bbox) {}

    DetectionInfo(const DetectionStore& store, const db::eye::Detection& detection)
        : id(detection.id()) {
        const db::eye::Frame& frame = store.frameByDetectionId(id);
        bbox = common::transformByImageOrientation(detection.box(), frame.originalSize(), frame.orientation());
    }
    db::TId id;
    cv::Rect bbox; // координаты на картинке уже сориентированной правильно: небом вверх
};

using DetectionInfos = std::vector<DetectionInfo>;
using PairDetectionInfos = std::pair<DetectionInfos, DetectionInfos>;

// Builds a CV_32FC1 matrix of size |first| x |second| whose element (i, j) is
// the square root of the Sampson distance between the bbox centers of the
// i-th detection of the first frame and the j-th detection of the second
// frame with respect to the fundamental matrix F.
cv::Mat calculateFundamentalErrors(const cv::Mat& F, const PairDetectionInfos& pairDetectionInfos) {
    const DetectionInfos& rowInfos = pairDetectionInfos.first;
    const DetectionInfos& colInfos = pairDetectionInfos.second;

    // Writes the bbox center into the first two rows of a 3x1 point whose
    // third row stays equal to one.
    const auto setCenter = [](cv::Mat& point, const cv::Rect& box) {
        point.at<double>(0, 0) = box.x + box.width / 2.;
        point.at<double>(1, 0) = box.y + box.height / 2.;
    };

    cv::Mat errors((int)rowInfos.size(), (int)colInfos.size(), CV_32FC1);
    cv::Mat rowCenter = cv::Mat::ones(3, 1, CV_64FC1);
    cv::Mat colCenter = cv::Mat::ones(3, 1, CV_64FC1);
    for (size_t r = 0; r < rowInfos.size(); r++) {
        setCenter(rowCenter, rowInfos[r].bbox);
        float* errorsRow = errors.ptr<float>(r);
        for (size_t c = 0; c < colInfos.size(); c++) {
            setCenter(colCenter, colInfos[c].bbox);
            errorsRow[c] = sqrt(cv::sampsonDistance(rowCenter, colCenter, F));
        }
    }
    return errors;
}

// Returns the result of cv::pointPolygonTest (with distance measuring
// enabled) for the center of bbox against the polygon hull.
double getHullDistance(const cv::Rect& bbox, const std::vector<cv::Point2f>& hull)
{
    const cv::Point2f cornersSum = static_cast<cv::Point2f>(bbox.tl() + bbox.br());
    const cv::Point2f bboxCenter = cornersSum / 2.f;
    const bool measureDistance = true;
    return cv::pointPolygonTest(hull, bboxCenter, measureDistance);
}

// Confidence that two detections on two matched frames show the same object.
//
// info0, info1 - detections on the first and on the second frame
// fundErr      - error of the detection centers with respect to the
//                fundamental matrix
// hull0, hull1 - hulls of the matched keypoints on the first and on the
//                second frame
//
// The value is MIN_FUND_ERR / max(MIN_FUND_ERR, fundErr), additionally divided
// for each frame by log(|hull distance|) when the hull distance of the bbox
// center is below -e (for such distances the logarithm is greater than one);
// otherwise the divisor for that frame is one.
double confidenceFunction(
    const DetectionInfo& info0,
    const DetectionInfo& info1,
    double fundErr,
    const std::vector<cv::Point2f>& hull0,
    const std::vector<cv::Point2f>& hull1)
{
    constexpr double MIN_FUND_ERR = 0.01;

    const double hullDistance0 = getHullDistance(info0.bbox, hull0);
    const double hullDistance1 = getHullDistance(info1.bbox, hull1);

    // std::abs / std::log guarantee the floating point overloads; the
    // unqualified abs may resolve to the integer one and truncate the value.
    const double h0 = (hullDistance0 >= -M_E) ? 1. : std::log(std::abs(hullDistance0));
    const double h1 = (hullDistance1 >= -M_E) ? 1. : std::log(std::abs(hullDistance1));
    return MIN_FUND_ERR / std::max(MIN_FUND_ERR, fundErr) * (1. / h0 / h1);
}

// Builds a string key that describes the kind of a detection: the detection
// type followed, depending on the type, by the house number, by the sign type
// and its temporary/constant flag, or by the road marking type.
// Throws RuntimeError for detection types that are not handled here.
std::string detectionTypeAsString(const db::eye::Detection& detection, db::eye::DetectionType type) {
    std::string typeKey(toString(type));
    switch (type) {
    case db::eye::DetectionType::HouseNumber:
    {
        typeKey += "-" + detection.attrs<db::eye::DetectedHouseNumber>().number;
        break;
    }
    case db::eye::DetectionType::Sign:
    {
        const db::eye::DetectedSign& signAttrs = detection.attrs<db::eye::DetectedSign>();
        typeKey += "-" + traffic_signs::toString(signAttrs.type)
            + "-" + (signAttrs.temporary ? "temporary" : "constant");
        break;
    }
    case db::eye::DetectionType::TrafficLight:
    {
        // traffic lights have no additional attributes in the key
        break;
    }
    case db::eye::DetectionType::RoadMarking:
    {
        const auto& markingAttrs = detection.attrs<db::eye::DetectedRoadMarking>();
        typeKey += "-" + traffic_signs::toString(markingAttrs.type);
        break;
    }
    default:
        throw RuntimeError() << "Unsupported detection type - " << type;
    }
    return typeKey;
}

// Distributes detections over groups of the same kind.
//
// detectionIds - detections of one frame
// pairElem     - 0 to put the detections into the first element of the
//                group pair, any other value to put them into the second
// result       - groups; a new group is appended for every unseen kind
// typeToIndex  - maps the kind key (see detectionTypeAsString) to the index
//                of its group in result
void fillDetectionInfos(
    const DetectionStore& store,
    const std::set<db::TId>& detectionIds,
    int pairElem,
    std::vector<PairDetectionInfos>& result,
    std::map<std::string, size_t>& typeToIndex)
{
    for (const db::TId detectionId : detectionIds) {
        const db::eye::Detection& detection = store.detectionById(detectionId);
        const std::string kindKey = detectionTypeAsString(detection, store.getDetectionType(detectionId));

        // Register a new group when the kind is seen for the first time.
        const auto [position, isNewKind] = typeToIndex.try_emplace(kindKey, result.size());
        if (isNewKind) {
            result.emplace_back();
        }

        PairDetectionInfos& group = result[position->second];
        DetectionInfos& frameInfos = (0 == pairElem) ? group.first : group.second;
        frameInfos.emplace_back(store, detection);
    }
}

// Groups detections of two frames by kind and returns only the groups that
// contain detections from both frames.
std::vector<PairDetectionInfos> getPairDetectionInfos(
    const DetectionStore& store,
    const std::set<db::TId>& detectionIds0,
    const std::set<db::TId>& detectionIds1)
{
    std::vector<PairDetectionInfos> groups;
    std::map<std::string, size_t> groupIndexByKind;
    fillDetectionInfos(store, detectionIds0, 0, groups, groupIndexByKind);
    fillDetectionInfos(store, detectionIds1, 1, groups, groupIndexByKind);

    // A group with detections on one frame only has nothing to match.
    const auto isOneSided = [](const PairDetectionInfos& group) {
        return group.first.empty() || group.second.empty();
    };
    const auto firstRemoved = std::remove_if(groups.begin(), groups.end(), isOneSided);
    groups.erase(firstRemoved, groups.end());
    return groups;
}

// A candidate match between two detections, as produced by getAllMatched()
// and getMutualMatched().
//
// Every member has a default initializer so that a default-constructed value
// never carries indeterminate data; the type remains an aggregate.
struct MatchedDetection {
    // Id of the detection taken from the `first` side of the bucket.
    db::TId id0{};
    // Id of the detection taken from the `second` side of the bucket.
    db::TId id1{};
    // Value returned by confidenceFunction() for this pair.
    float confidence = 0.0f;
    // Entry of the fundamental-errors matrix for this pair.
    double sampsonDistance = 0.0;
    // getHullDistance() of the first detection's bbox against hull0.
    double hullDistance0 = 0.0;
    // getHullDistance() of the second detection's bbox against hull1.
    double hullDistance1 = 0.0;
};

// Collects every pair of detections whose error is strictly below
// errThreshold.
//
// pairDetectionInfos - bucket of detections; `first` corresponds to matrix
//                      rows, `second` to matrix columns
// fundErrs           - error matrix, read as float elements; must have
//                      first.size() rows and second.size() columns
//                      (assumes CV_32FC1 storage - TODO confirm against
//                      calculateFundamentalErrors)
// errThreshold       - exclusive upper bound for an accepted error
// hull0, hull1       - hulls passed to confidenceFunction()/getHullDistance()
//
// Returns an empty vector if either side of the bucket is empty.
// Throws (via REQUIRE) if the matrix size does not agree with the bucket,
// the same contract getMutualMatched() enforces.
std::vector<MatchedDetection> getAllMatched(
    const PairDetectionInfos& pairDetectionInfos,
    const cv::Mat& fundErrs,
    float errThreshold,
    const std::vector<cv::Point2f>& hull0,
    const std::vector<cv::Point2f>& hull1)
{
    const DetectionInfos& infos0 = pairDetectionInfos.first;
    const DetectionInfos& infos1 = pairDetectionInfos.second;
    if (infos0.empty() || infos1.empty()) {
        return {};
    }

    // Without this check a mismatched matrix would be read out of bounds.
    REQUIRE(fundErrs.rows == static_cast<int>(infos0.size()) &&
            fundErrs.cols == static_cast<int>(infos1.size()),
        "Invalid size of fundamental errors matrix");

    std::vector<MatchedDetection> results;
    for (size_t i = 0; i < infos0.size(); i++) {
        const float* rowErrs = fundErrs.ptr<float>(static_cast<int>(i));
        for (size_t j = 0; j < infos1.size(); j++) {
            const float err = rowErrs[j];
            if (!(err < errThreshold)) {
                continue;
            }
            MatchedDetection matchedDetection;
            matchedDetection.id0 = infos0[i].id;
            matchedDetection.id1 = infos1[j].id;
            matchedDetection.confidence = confidenceFunction(infos0[i], infos1[j], err, hull0, hull1);
            matchedDetection.sampsonDistance = err;
            matchedDetection.hullDistance0 = getHullDistance(infos0[i].bbox, hull0);
            matchedDetection.hullDistance1 = getHullDistance(infos1[j].bbox, hull1);
            results.push_back(matchedDetection);
        }
    }
    return results;
}

// Keeps only mutually nearest pairs: detection r (row) and detection c
// (column) are matched when c holds the smallest error in row r and r holds
// the smallest error in column c.
//
// Rows of fundErrs correspond to pairDetectionInfos.first, columns to
// pairDetectionInfos.second. Throws (via REQUIRE) when the matrix size does
// not agree with the bucket.
std::vector<MatchedDetection> getMutualMatched(
    const PairDetectionInfos& pairDetectionInfos,
    const cv::Mat& fundErrs,
    const std::vector<cv::Point2f>& hull0,
    const std::vector<cv::Point2f>& hull1)
{
    const int rows = fundErrs.rows;
    const int cols = fundErrs.cols;

    std::vector<int> bestColForRow(rows, -1); // index (column) of the minimum for each row
    std::vector<int> bestRowForCol(cols, -1); // index (row) of the minimum for each column
    std::vector<float> bestErrForCol(cols, FLT_MAX);
    for (int r = 0; r < rows; ++r) {
        const float* rowErrs = fundErrs.ptr<float>(r);
        float bestErrForRow = FLT_MAX;
        for (int c = 0; c < cols; ++c) {
            const float err = rowErrs[c];
            if (err < bestErrForRow) {
                bestErrForRow = err;
                bestColForRow[r] = c;
            }
            if (err < bestErrForCol[c]) {
                bestErrForCol[c] = err;
                bestRowForCol[c] = r;
            }
        }
    }

    const DetectionInfos& infos0 = pairDetectionInfos.first;
    const DetectionInfos& infos1 = pairDetectionInfos.second;
    REQUIRE(rows == (int)infos0.size() && cols == (int)infos1.size(),
        "Invalid size of fundamental errors matrix");

    std::vector<MatchedDetection> mutualMatches;
    for (int r = 0; r < rows; ++r) {
        const int c = bestColForRow[r];
        // -1 means no element of the row was below FLT_MAX.
        const bool isMutual = (c != -1) && (bestRowForCol[c] == r);
        if (!isMutual) {
            continue;
        }

        const float err = fundErrs.at<float>(r, c);

        MatchedDetection matched;
        matched.id0 = infos0[r].id;
        matched.id1 = infos1[c].id;
        matched.confidence = confidenceFunction(infos0[r], infos1[c], err, hull0, hull1);
        matched.sampsonDistance = err;
        matched.hullDistance0 = getHullDistance(infos0[r].bbox, hull0);
        matched.hullDistance1 = getHullDistance(infos1[c].bbox, hull1);
        mutualMatches.push_back(matched);
    }
    return mutualMatches;
}

std::vector<MatchedDetection> getMatchedDetections(
    const FramesMatchData& matchedFramesPair,
    const PairDetectionInfos& pairDetectionInfos)
{
    // Есть три варианта:
    //   1. Просто собрать все пары с расстояниями определенным матрицей dist
    //   2. Аналогично п.1, но отбросить пары с расстоянием больше порога
    //   3. Взять только взаимно-ближайшие
    // пока реализуем п.2, но выставив DISTANCE_THREHOLD = FLT_MAX, т.е.
    // фактически получается п.1. Для варианта 3, надо в этих циклах заполнить матрицу:
    //cv::Mat dist = cv::Mat::zeros(detections0.size(), detections1.size(), CV_32FC1);
    // и затем применить процедуру аналогичную nn_mutual, для поиска взаимно-ближайших
    constexpr bool MUTUAL_MATCHES_ONLY = true;
    constexpr float DISTANCE_THREHOLD = FLT_MAX;

    REQUIRE(0 < pairDetectionInfos.first.size() && 0 < pairDetectionInfos.second.size(),
        "One or both vector of detections is empty");

    cv::Mat fundErrs = calculateFundamentalErrors(matchedFramesPair.fundMatrix, pairDetectionInfos);
    if (MUTUAL_MATCHES_ONLY) {
        return getMutualMatched(pairDetectionInfos, fundErrs, matchedFramesPair.hull0, matchedFramesPair.hull1);
    }
    return getAllMatched(pairDetectionInfos, fundErrs, DISTANCE_THREHOLD, matchedFramesPair.hull0, matchedFramesPair.hull1);
}

std::vector<MatchedDetection> getMatchedDetections(
    const MatchedFramesPair& matchedFramesPair,
    const std::vector<PairDetectionInfos>& pairDetectionInfos)
{
    std::vector<MatchedDetection> results;
    for (size_t i = 0; i < pairDetectionInfos.size(); i++) {
        const std::vector<MatchedDetection> temp = getMatchedDetections(matchedFramesPair.match, pairDetectionInfos[i]);
        if (0 != temp.size()) {
            results.insert(results.end(), temp.begin(), temp.end());
        }
    }
    return results;
}

// Converts pairs of detection ids into the set of frame pairs they span.
//
// Each returned pair is ordered (smaller frame id first), pairs are unique
// and sorted. Detection pairs whose detections belong to the same frame are
// skipped. Throws (via REQUIRE) if a detection is paired with itself.
std::vector<std::pair<db::TId, db::TId>> getFramePairs(
    const DetectionStore& store,
    const DetectionIdPairSet& detectionPairs)
{
    std::set<std::pair<db::TId, db::TId>> uniqueFramePairs;
    for (const auto& [detectionId0, detectionId1] : detectionPairs) {
        REQUIRE(detectionId0 != detectionId1,
            "Unable to match detection with itself");

        const db::TId frameA = store.frameId(detectionId0);
        const db::TId frameB = store.frameId(detectionId1);
        if (frameA < frameB) {
            uniqueFramePairs.emplace(frameA, frameB);
        } else if (frameB < frameA) {
            uniqueFramePairs.emplace(frameB, frameA);
        }
        // Equal frame ids: both detections are on one frame, nothing to match.
    }
    return std::vector<std::pair<db::TId, db::TId>>(
        uniqueFramePairs.begin(), uniqueFramePairs.end());
}

// Groups every detection id mentioned in detectionPairs by the id of the
// frame it belongs to (as reported by store.frameId()).
std::map<db::TId, std::set<db::TId>> mapFrameIdToDetectionId(
    const DetectionStore& store,
    const DetectionIdPairSet& detectionPairs)
{
    std::map<db::TId, std::set<db::TId>> detectionsByFrame;
    const auto registerDetection = [&](const auto& detectionId) {
        detectionsByFrame[store.frameId(detectionId)].emplace(detectionId);
    };

    for (const auto& [detectionId0, detectionId1] : detectionPairs) {
        registerDetection(detectionId0);
        registerDetection(detectionId1);
    }
    return detectionsByFrame;
}

// Produces detection matches for every matched pair of frames.
//
// For each frame pair the detections referenced by detectionPairs are
// grouped into per-type buckets and matched; a resulting match is kept only
// if detectionPairs contains it in either order.
//
// The confidence of a kept match is the detection-level confidence
// multiplied by goodPtsCnt / MAX_PTS_COUNT of the frame pair. The verdict is
// left unset (std::nullopt).
//
// Frame pairs for which either frame has no requested detections are
// skipped.
MatchedFrameDetections match(
    const DetectionStore& store,
    const DetectionIdPairSet& detectionPairs,
    const MatchedFramesPairs& matchedFrames)
{
    // Kept const and queried with find(): looking up a frame without
    // detections must not insert an empty set into the map.
    const std::map<db::TId, std::set<db::TId>> frameIdToDetectionIds =
        mapFrameIdToDetectionId(store, detectionPairs);
    const auto notFound = frameIdToDetectionIds.end();

    MatchedFrameDetections result;
    for (const MatchedFramesPair& framesPair : matchedFrames) {
        const auto it0 = frameIdToDetectionIds.find(framesPair.id0);
        const auto it1 = frameIdToDetectionIds.find(framesPair.id1);
        if (it0 == notFound || it1 == notFound) {
            continue;
        }
        const std::set<db::TId>& detectionIds0 = it0->second;
        const std::set<db::TId>& detectionIds1 = it1->second;
        if (detectionIds0.empty() || detectionIds1.empty()) {
            continue;
        }

        const std::vector<PairDetectionInfos> pairDetections =
            getPairDetectionInfos(store, detectionIds0, detectionIds1);

        const float pairImagesConfidence =
            static_cast<float>(framesPair.match.goodPtsCnt) / MAX_PTS_COUNT;
        const std::vector<MatchedDetection> matched =
            getMatchedDetections(framesPair, pairDetections);
        for (const MatchedDetection& md : matched) {
            // Only pairs that were explicitly requested are reported.
            const bool requested =
                detectionPairs.count({md.id0, md.id1}) != 0 ||
                detectionPairs.count({md.id1, md.id0}) != 0;
            if (!requested) {
                continue;
            }

            result.emplace_back(
                FrameDetectionId{framesPair.id0, md.id0},
                FrameDetectionId{framesPair.id1, md.id1},
                md.confidence * pairImagesConfidence,
                std::nullopt, // verdict
                MatchedData{.goodPtsCnt = framesPair.match.goodPtsCnt,
                 .sampsonDistance = md.sampsonDistance,
                 .fundMatrix = framesPair.match.fundMatrix,
                 .hullDistance0 = md.hullDistance0,
                 .hullDistance1 = md.hullDistance1}
            );
        }
    }

    return result;
}

} // namespace

// Matches two keypoint sets and returns the coordinates of the matched
// points.
//
// For every matched index pair the coordinates are read from the `coords`
// matrix of the corresponding keypoint set and divided by that set's scale
// (scale0 for kpts0, scale1 for kpts1). points0[i] and points1[i] describe
// the same match.
MatchedKeypoints getKeypointsMatch(
    const keypoints_matcher::SuperglueMatcher& matcher,
    const common::Keypoints& kpts0, double scale0,
    const common::Keypoints& kpts1, double scale1)
{
    const keypoints_matcher::MatchedPairs indexPairs = matcher.match(kpts0, kpts1);
    const size_t pairCount = indexPairs.size();

    MatchedKeypoints result;
    result.points0.resize(pairCount);
    result.points1.resize(pairCount);

    for (size_t k = 0; k < pairCount; ++k) {
        const std::pair<size_t, size_t>& indexPair = indexPairs[k];
        const auto& coord0 = kpts0.coords.at<cv::Vec2f>(indexPair.first, 0);
        const auto& coord1 = kpts1.coords.at<cv::Vec2f>(indexPair.second, 0);
        result.points0[k] = coord0 / scale0;
        result.points1[k] = coord1 / scale1;
    }

    return result;
}


// Matches the requested detection pairs using frame-level matches supplied
// by the given frame matcher.
//
// Returns an empty result for empty input; in that case the frame matcher
// is not inspected at all. Otherwise throws (via REQUIRE) if the frame
// matcher is null.
MatchedFrameDetections DetectionSuperpointsMatcher::makeMatches(
    const DetectionStore& store,
    const DetectionIdPairSet& detectionPairs,
    const FrameMatcher* frameMatcherPtr) const
{
    if (detectionPairs.empty()) {
        return {};
    }

    REQUIRE(frameMatcherPtr, "Frames matcher must be defined");
    const FrameMatcher& frameMatcher = *frameMatcherPtr;

    const auto matchedFrames = frameMatcher.makeMatches(
        store,
        getFramePairs(store, detectionPairs));

    return match(store, detectionPairs, matchedFrames);
}

} // namespace maps::mrc::eye
