#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/triangles_clusterizer.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/impl/cluster_store.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <map>
#include <optional>
#include <utility>
#include <vector>

namespace maps::mrc::eye {

namespace {

// Minimum goodPtsCnt a match needs to serve as an edge of a triangle
// (see extractSortedTriangles).
constexpr int TRIANGLES_GOOD_CNT_THR{40};
// Minimum goodPtsCnt a match needs to be used by the pairwise pass
// (clusterizeByMatch stops at the first weaker match).
constexpr int MATCHES_GOOD_CNT_THR{40};
// Matches whose sampsonDistance is above this value are skipped by the pairwise pass.
constexpr float SAMPSON_DIST_THR{4.f};

// A group of detections linked, directly or through a chain of matches,
// by the matches passed to splitComponents.
struct Component {
    // Ids of all detections that belong to this component.
    db::TIdSet detectionIds;
    // Maps an ordered detectionId pair to a signed 1-based position in the
    // matches vector: +(index + 1) when the match stores the pair in this
    // order, -(index + 1) when it stores it in the reverse order.
    // A detectionId pair is used instead of a FrameDetectionId pair because
    // detectionIds are unique on their own; FrameDetectionId is only needed
    // to avoid gluing together two detections from the same frame.
    std::map<DetectionIdPair, int> matchNumbers;
};
using Components = std::vector<Component>;

// Three mutually matched detections together with the fundamental matrices
// of each pair.
struct Triangle {
    // The detections (with their frames) forming the triangle.
    std::array<FrameDetectionId, 3> fdIds;
    // Fundamental matrices of the pairs (0,1), (0,2) and (1,2). Fij is meant to be
    // used with fdIds[i] as the first image and fdIds[j] as the second one.
    cv::Mat F01;
    cv::Mat F02;
    cv::Mat F12;
    // Smallest goodPtsCnt among the three pairwise matches. Default-initialized
    // to 0 so a default-constructed Triangle never carries an indeterminate value.
    int minGoodPtsCnt = 0;
};
using Triangles = std::vector<Triangle>;

/*
    Bounding box of the detection `fdId`, passed through
    transformByImageOrientation with the original size and orientation
    of the frame the detection belongs to.
*/
cv::Rect getBbox(
    const DetectionStore& store,
    const FrameDetectionId& fdId)
{
    const db::eye::Frame& ownerFrame = store.frameById(fdId.frameId);
    const db::eye::Detection& det = store.detectionById(fdId.detectionId);
    return transformByImageOrientation(det.box(), ownerFrame.originalSize(), ownerFrame.orientation());
}

// Center of the rectangle: the midpoint between its top-left and
// bottom-right (tl() / br()) corners, in floating point.
cv::Point2f getCenter(const cv::Rect& bbox)
{
    const cv::Point cornerSum = bbox.tl() + bbox.br();
    return static_cast<cv::Point2f>(cornerSum) / 2.f;
}

/*
    Single-point wrapper around cv::computeCorrespondEpilines.
    `whichImage` (1 or 2) tells which image of the pair described by F contains
    `ptf`; the returned (a, b, c) is the corresponding epiline in the other image.
*/
cv::Point3f getEpiline(const cv::Point2f& ptf, int whichImage, const cv::Mat& F)
{
    const std::vector<cv::Point2f> points{ptf};
    std::vector<cv::Point3f> epilines;
    cv::computeCorrespondEpilines(points, whichImage, F, epilines);
    return epilines.front();
}

/*
    Intersection point of two 2D lines given as (a, b, c), i.e. a*x + b*y + c = 0.
    Returns std::nullopt when |a0*b1 - b0*a1| < 1e-3 (the lines are treated as parallel).

    Note: for lines produced by cv::computeCorrespondEpilines, which are normalized
    so that a^2 + b^2 = 1, that determinant is the sine of the angle between them.
*/
std::optional<cv::Point2f> lineIntersection(
    const cv::Vec3f& line0,
    const cv::Vec3f& line1)
{
    const float det = line0[0] * line1[1] - line0[1] * line1[0];
    // std::abs (float overload from <cmath>), not unqualified abs: the latter may
    // resolve to C `int abs(int)`, truncating det to 0 whenever |det| < 1 and
    // rejecting almost every intersection.
    if (std::abs(det) < 1e-3) {
        return std::nullopt;
    }
    // Cross product of the homogeneous lines, dehomogenized by det.
    const float x = (line0[1] * line1[2] - line1[1] * line0[2]) / det;
    const float y = (line1[0] * line0[2] - line0[0] * line1[2]) / det;
    return cv::Point2f(x, y);
}

// True when `pt` lies inside `bbox` or on its border. Unlike cv::Rect::contains,
// the right (x + width) and bottom (y + height) edges are inclusive here.
bool pointBboxTest(const cv::Point2f& pt, const cv::Rect& bbox)
{
    const int right = bbox.x + bbox.width;
    const int bottom = bbox.y + bbox.height;
    const bool withinX = bbox.x <= pt.x && pt.x <= right;
    const bool withinY = bbox.y <= pt.y && pt.y <= bottom;
    return withinX && withinY;
}

std::vector<int> sortedMatchedIndicies(
    const Component& component,
    const MatchedFrameDetections& matches)
{
    std::vector<int> indices(component.matchNumbers.size());
    int i = 0;
    for (auto it = component.matchNumbers.begin();
              it != component.matchNumbers.end();
              it++) {
        indices[i] = abs(it->second) - 1;
        i++;
    }
    std::sort(indices.begin(), indices.end(),
        [&](const auto& lhs, const auto& rhs) {
            return matches[lhs].data()->goodPtsCnt > matches[rhs].data()->goodPtsCnt;
        }
    );
    return indices;
}

/*
    Returns the index in `components` of the component containing `detectionId`,
    or -1 when none contains it. The return type is int (not size_t) so the -1
    sentinel is representable and matches the int variables callers compare with.
*/
int findComponentIndex(
    db::TId detectionId,
    const Components& components)
{
    const auto it = std::find_if(components.begin(), components.end(),
        [&](const Component& component) {
            return component.detectionIds.count(detectionId) > 0;
        });
    if (it == components.end()) {
        return -1;
    }
    return static_cast<int>(it - components.begin());
}

/*
    Groups detections into components: two detections share a component when
    a chain of matches links them.
    Each match is recorded in its component's matchNumbers under both orderings
    of its detection pair: +(index + 1) for (id0, id1) and -(index + 1) for
    (id1, id0), where index is the match position in `matches`.
*/
Components splitComponents(const MatchedFrameDetections& matches)
{
    Components components;

    for (size_t matchIdx = 0; matchIdx < matches.size(); ++matchIdx) {
        const MatchedFrameDetection& match = matches[matchIdx];
        const FrameDetectionId& fdIdA = match.id0();
        const FrameDetectionId& fdIdB = match.id1();

        ASSERT(fdIdA.frameId != fdIdB.frameId);

        const int compA = findComponentIndex(fdIdA.detectionId, components);
        const int compB = findComponentIndex(fdIdB.detectionId, components);
        const bool knownA = (compA != -1);
        const bool knownB = (compB != -1);

        // Component that receives this match; a fresh one when neither
        // detection has been seen yet.
        int targetIdx = knownA ? compA : compB;
        if (!knownA && !knownB) {
            components.emplace_back();
            targetIdx = static_cast<int>(components.size()) - 1;
        }
        Component& target = components[targetIdx];

        // Inserting an id that is already in the set is a no-op.
        target.detectionIds.insert(fdIdA.detectionId);
        target.detectionIds.insert(fdIdB.detectionId);

        const int matchNumber = static_cast<int>(matchIdx + 1);
        target.matchNumbers[{fdIdA.detectionId, fdIdB.detectionId}] = matchNumber;
        target.matchNumbers[{fdIdB.detectionId, fdIdA.detectionId}] = -matchNumber;

        // The match bridges two different components: fold B's component into A's.
        if (knownA && knownB && compA != compB) {
            const Component& absorbed = components[compB];
            target.detectionIds.insert(
                absorbed.detectionIds.begin(),
                absorbed.detectionIds.end());
            target.matchNumbers.insert(
                absorbed.matchNumbers.begin(),
                absorbed.matchNumbers.end());
            components.erase(components.begin() + compB);
        }
    }
    return components;
}

/*
    Enumerates the triangles of the component: triples of detections whose three
    pairwise matches are present in component.matchNumbers and each have at least
    TRIANGLES_GOOD_CNT_THR good points.
    fdIds follow the iteration order of component.detectionIds; Fij is the
    fundamental matrix of the (i, j) match, transposed when the match stores
    the pair in the reverse order.
    Triangles are returned sorted by minGoodPtsCnt in descending order; equal
    counts keep their enumeration order (stable sort).
*/
Triangles extractSortedTriangles(
    const Component& component,
    const MatchedFrameDetections& matches)
{
    if (component.detectionIds.size() < 3) {
        return {};
    }

    // A match viewed from the first detection of a requested (from, to) pair.
    struct OrientedMatch {
        const MatchedFrameDetection* match;
        // true when match->id0() is the `from` detection.
        bool forward;
    };

    // Match between `from` and `to`, or nullopt when there is none or it has
    // fewer than TRIANGLES_GOOD_CNT_THR good points.
    const auto findStrongMatch =
        [&](db::TId from, db::TId to) -> std::optional<OrientedMatch> {
            const auto it = component.matchNumbers.find({from, to});
            if (it == component.matchNumbers.end()) {
                return std::nullopt;
            }
            const MatchedFrameDetection& match = matches[std::abs(it->second) - 1];
            // Validate data() before its first dereference.
            ASSERT(match.data());
            if (match.data()->goodPtsCnt < TRIANGLES_GOOD_CNT_THR) {
                return std::nullopt;
            }
            return OrientedMatch{&match, it->second > 0};
        };

    // Endpoints / fundamental matrix of a match, oriented as requested.
    const auto firstOf = [](const OrientedMatch& m) -> FrameDetectionId {
        return m.forward ? m.match->id0() : m.match->id1();
    };
    const auto secondOf = [](const OrientedMatch& m) -> FrameDetectionId {
        return m.forward ? m.match->id1() : m.match->id0();
    };
    const auto fundMatrixOf = [](const OrientedMatch& m) -> cv::Mat {
        return m.forward ? m.match->data()->fundMatrix : m.match->data()->fundMatrix.t();
    };

    Triangles triangles;
    const auto end = component.detectionIds.end();
    const auto lastButOne = std::prev(end);
    const auto lastButTwo = std::prev(lastButOne);
    for (auto it0 = component.detectionIds.begin(); it0 != lastButTwo; ++it0) {
        for (auto it1 = std::next(it0); it1 != lastButOne; ++it1) {
            const std::optional<OrientedMatch> m01 = findStrongMatch(*it0, *it1);
            if (!m01) {
                continue;
            }
            for (auto it2 = std::next(it1); it2 != end; ++it2) {
                const std::optional<OrientedMatch> m02 = findStrongMatch(*it0, *it2);
                if (!m02) {
                    continue;
                }
                const std::optional<OrientedMatch> m12 = findStrongMatch(*it1, *it2);
                if (!m12) {
                    continue;
                }

                Triangle triangle;

                // Every vertex is reachable through two edges; both must agree.
                triangle.fdIds[0] = firstOf(*m01);
                ASSERT(triangle.fdIds[0] == firstOf(*m02));
                ASSERT(triangle.fdIds[0].detectionId == *it0);

                triangle.fdIds[1] = secondOf(*m01);
                ASSERT(triangle.fdIds[1] == firstOf(*m12));
                ASSERT(triangle.fdIds[1].detectionId == *it1);

                triangle.fdIds[2] = secondOf(*m02);
                ASSERT(triangle.fdIds[2] == secondOf(*m12));
                ASSERT(triangle.fdIds[2].detectionId == *it2);

                triangle.F01 = fundMatrixOf(*m01);
                triangle.F02 = fundMatrixOf(*m02);
                triangle.F12 = fundMatrixOf(*m12);

                triangle.minGoodPtsCnt = std::min({
                    m01->match->data()->goodPtsCnt,
                    m02->match->data()->goodPtsCnt,
                    m12->match->data()->goodPtsCnt});
                triangles.emplace_back(std::move(triangle));
            }
        }
    }

    std::stable_sort(triangles.begin(), triangles.end(),
        [](const Triangle& lhs, const Triangle& rhs) {
            return lhs.minGoodPtsCnt > rhs.minGoodPtsCnt;
        }
    );
    return triangles;
}

void clusterizeByTriangles(
    const DetectionStore& store,
    const Component& component,
    const MatchedFrameDetections& matches,
    ClusterStore& clusters)
{
    Triangles triangles = extractSortedTriangles(component, matches);
    for (size_t idx = 0; idx < triangles.size(); idx++) {
        const Triangle& triangle = triangles[idx];
        std::array<cv::Rect, 3> arBBoxes;
        std::array<cv::Point2f, 3> arCenters;
        for (size_t i = 0; i < 3; i++) {
            arBBoxes[i] = getBbox(store, triangle.fdIds[i]);
            arCenters[i] = getCenter(arBBoxes[i]);
        }

        // linesIJ - I - индекс картинки на которой ищем линию, J - индекс картинки на которой точка
        const cv::Point3f lines10 = getEpiline(arCenters[0], 1, triangle.F01);
        const cv::Point3f lines20 = getEpiline(arCenters[0], 1, triangle.F02);
        const cv::Point3f lines01 = getEpiline(arCenters[1], 2, triangle.F01);
        const cv::Point3f lines21 = getEpiline(arCenters[1], 1, triangle.F12);
        const cv::Point3f lines02 = getEpiline(arCenters[2], 2, triangle.F02);
        const cv::Point3f lines12 = getEpiline(arCenters[2], 2, triangle.F12);

        std::array<std::optional<cv::Point2f>, 3> arInters;
        arInters[0] = lineIntersection(lines01, lines02);
        arInters[1] = lineIntersection(lines10, lines12);
        arInters[2] = lineIntersection(lines20, lines21);

        int cnt = 0;
        for (size_t i = 0; i < 3; i++) {
            if (arInters[i].has_value() && pointBboxTest(arInters[i].value(), arBBoxes[i])) {
                cnt += 1;
            }
        }
        if (cnt < 3) {
            continue;
        }

        std::array<ClusterStore::ConstIterator, 3> arClusterIt;
        for (size_t i = 0; i < 3; i++) {
            arClusterIt[i] = clusters.find(triangle.fdIds[i].detectionId);
        }

        if (clusters.end() == arClusterIt[0] &&
            clusters.end() == arClusterIt[1] &&
            clusters.end() == arClusterIt[2])
        {
            clusters.createNewCluster({triangle.fdIds[0], triangle.fdIds[1], triangle.fdIds[2]});
        } else if (clusters.end() != arClusterIt[0] &&
                   clusters.end() == arClusterIt[1] &&
                   clusters.end() == arClusterIt[2])
        {
            clusters.tryAddNewElementInCluster(arClusterIt[0]->second, triangle.fdIds[1], triangle.fdIds[2]);
        } else if (clusters.end() == arClusterIt[0] &&
                   clusters.end() != arClusterIt[1] &&
                   clusters.end() == arClusterIt[2])
        {
            clusters.tryAddNewElementInCluster(arClusterIt[1]->second, triangle.fdIds[0], triangle.fdIds[2]);
        } else if (clusters.end() == arClusterIt[0] &&
                   clusters.end() == arClusterIt[1] &&
                   clusters.end() != arClusterIt[2])
        {
            clusters.tryAddNewElementInCluster(arClusterIt[2]->second, triangle.fdIds[0], triangle.fdIds[1]);
        } else if (clusters.end() == arClusterIt[0] &&
                   clusters.end() != arClusterIt[1] &&
                   clusters.end() != arClusterIt[2])
        {
            clusters.tryMergeTwoClustersAndNewElement(arClusterIt[1]->second, arClusterIt[2]->second, triangle.fdIds[0]);
        } else if (clusters.end() != arClusterIt[0] &&
                   clusters.end() == arClusterIt[1] &&
                   clusters.end() != arClusterIt[2])
        {
            clusters.tryMergeTwoClustersAndNewElement(arClusterIt[0]->second, arClusterIt[2]->second, triangle.fdIds[1]);
        } else if (clusters.end() != arClusterIt[0] &&
                   clusters.end() != arClusterIt[1] &&
                   clusters.end() == arClusterIt[2])
        {
            clusters.tryMergeTwoClustersAndNewElement(arClusterIt[0]->second, arClusterIt[1]->second, triangle.fdIds[2]);
        } else {
            clusters.tryMergeThreeClusters(arClusterIt[0]->second, arClusterIt[1]->second, arClusterIt[2]->second);
        }
    }
}

/*
    Pairwise pass: walks the component's matches in the order returned by
    sortedMatchedIndicies (goodPtsCnt descending). It stops at the first match
    below MATCHES_GOOD_CNT_THR and skips matches whose sampsonDistance exceeds
    SAMPSON_DIST_THR. For each remaining match it links the two detections in
    `clusters` by creating, extending or merging clusters.
*/
void clusterizeByMatch(
    const Component& component,
    const MatchedFrameDetections& matches,
    ClusterStore& clusters)
{
    for (const int matchIdx : sortedMatchedIndicies(component, matches)) {
        const MatchedFrameDetection& match = matches[matchIdx];
        if (match.data()->goodPtsCnt < MATCHES_GOOD_CNT_THR) {
            break; // indices are sorted by goodPtsCnt, the rest are weaker
        }
        if (match.data()->sampsonDistance > SAMPSON_DIST_THR) {
            continue;
        }

        const FrameDetectionId& fdId0 = match.id0();
        const FrameDetectionId& fdId1 = match.id1();
        ASSERT(fdId0.frameId != fdId1.frameId);

        const auto clusterIt0 = clusters.find(fdId0.detectionId);
        const auto clusterIt1 = clusters.find(fdId1.detectionId);
        const bool inCluster0 = clusters.end() != clusterIt0;
        const bool inCluster1 = clusters.end() != clusterIt1;

        if (inCluster0 && inCluster1) {
            clusters.tryMergeTwoClusters(clusterIt0->second, clusterIt1->second);
        } else if (inCluster0) {
            clusters.tryAddNewElementInCluster(clusterIt0->second, fdId1);
        } else if (inCluster1) {
            clusters.tryAddNewElementInCluster(clusterIt1->second, fdId0);
        } else {
            clusters.createNewCluster({fdId0, fdId1});
        }
    }
}

void clusterizeComponent(
    const DetectionStore& store,
    const Component& component,
    const MatchedFrameDetections& matches,
    std::vector<db::TIdSet>& result)
{
    ClusterStore clusters = makeClusterStoreConsistentWithMatchesVerdicts(
        store, component.detectionIds, matches);

    clusterizeByTriangles(store, component, matches, clusters);
    clusterizeByMatch(component, matches, clusters);
    std::vector<db::TIdSet> temp = clusters.makeDetectionClusters();
    result.insert(result.end(), temp.begin(), temp.end());
}

} // namespace


// Default constructor; nothing needs explicit initialization here.
TrianglesDetectionClusterizer::TrianglesDetectionClusterizer() = default;

/*
    Splits the matches into linked groups of detections (splitComponents),
    clusters each group independently, and finally hands the input detectionIds
    to expandClustersByUnusedDetectionIds. Judging by its name, that step adds
    detections not covered by any cluster; confirm against its definition.
*/
std::vector<db::TIdSet> TrianglesDetectionClusterizer::clusterize(
    const DetectionStore& store,
    const db::TIdSet& detectionIds,
    const MatchedFrameDetections& matches) const
{
    std::vector<db::TIdSet> allClusters;
    for (const Component& component : splitComponents(matches)) {
        clusterizeComponent(store, component, matches, allClusters);
    }
    expandClustersByUnusedDetectionIds(detectionIds, allClusters);
    return allClusters;
}

} // namespace maps::mrc::eye
