#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/sign_relations.h>

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

namespace maps::mrc::signdetect {

namespace {

// Proximity tolerance: two boxes may be related only if the gap between them along an
// axis is at most MULT_KOEF times the larger of the two boxes' sizes along that axis.
constexpr double MULT_KOEF{1.7};

// Relative placement of two boxes along both axes (in the boxes' coordinate units).
// Members are zero-initialized so a default-constructed value is never indeterminate;
// the struct remains an aggregate, so `ObjectsDistance{a, b, c, d}` still works.
struct ObjectsDistance {
    int horizontalCenterDistance = 0; // |x-offset| between the box centers
    int verticalCenterDistance = 0;   // |y-offset| between the box centers
    int horizontalBoxDistance = 0;    // gap between the x-projections, 0 when they overlap/touch
    int verticalBoxDistance = 0;      // gap between the y-projections, 0 when they overlap/touch
};

// A candidate link between two detected objects together with its score.
// Candidates are compared by objectsDistance; the smallest one wins.
struct Relation {
    Relation(size_t first, size_t second, double distance)
        : firstObjectIndex{first}
        , secondObjectIndex{second}
        , objectsDistance{distance}
    {
    }

    size_t firstObjectIndex;   // index (into the detected signs) of the first object
    size_t secondObjectIndex;  // index (into the detected signs) of the second object
    double objectsDistance;    // selection score, smaller is better
};

/*
    Splits the indices of `signs` into two groups:
    pair.first  - indices whose sign is an additional table
                  (traffic_signs::isAdditionalTable returns true);
    pair.second - indices of all remaining, ordinary signs.
    Both lists are in ascending index order.
*/
std::pair<std::vector<size_t>, std::vector<size_t>>
splitTablesAndSigns(const DetectedSigns& signs) {
    std::vector<size_t> addTables;
    std::vector<size_t> bigSigns;
    for (size_t i = 0; i < signs.size(); i++) {
        if (maps::mrc::traffic_signs::isAdditionalTable(signs[i].sign)) {
            addTables.emplace_back(i);
        } else {
            bigSigns.emplace_back(i);
        }
    }
    // A braced return list initializes the pair from lvalues, which copies them;
    // move the local vectors explicitly instead.
    return {std::move(addTables), std::move(bigSigns)};
}

// Gap between the 1-D segments [min1, max1] and [min2, max2].
// Returns 0 when they overlap or touch, otherwise the distance from the end of the
// segment that starts first to the start of the other one.
int calcSegmentDistance(int min1, int max1, int min2, int max2) {
    // On equal starts the second segment is treated as the one that starts first.
    const bool firstStartsEarlier = min1 < min2;
    const int earlierEnd = firstStartsEarlier ? max1 : max2;
    const int laterStart = firstStartsEarlier ? min2 : min1;
    // A negative difference means the later segment begins inside the earlier one.
    return std::max(0, laterStart - earlierEnd);
}

// Describes how two boxes are placed relative to each other.
// Center distances are the absolute offsets between the box centers (integer-halved);
// box distances are the gaps between the boxes' projections on each axis
// (0 when the projections overlap or touch), as computed by calcSegmentDistance.
ObjectsDistance calcObjectsDistance(const cv::Rect& bbox1, const cv::Rect& bbox2) {
    // Width/height may be negative; normalize every projection to [lo, hi].
    const int xLo1 = std::min(bbox1.x, bbox1.x + bbox1.width);
    const int xHi1 = std::max(bbox1.x, bbox1.x + bbox1.width);
    const int yLo1 = std::min(bbox1.y, bbox1.y + bbox1.height);
    const int yHi1 = std::max(bbox1.y, bbox1.y + bbox1.height);

    const int xLo2 = std::min(bbox2.x, bbox2.x + bbox2.width);
    const int xHi2 = std::max(bbox2.x, bbox2.x + bbox2.width);
    const int yLo2 = std::min(bbox2.y, bbox2.y + bbox2.height);
    const int yHi2 = std::max(bbox2.y, bbox2.y + bbox2.height);

    ObjectsDistance result;
    result.horizontalBoxDistance = calcSegmentDistance(xLo1, xHi1, xLo2, xHi2);
    result.verticalBoxDistance = calcSegmentDistance(yLo1, yHi1, yLo2, yHi2);
    // (lo + hi) is twice the center coordinate, hence the final halving.
    result.horizontalCenterDistance = abs((xLo1 + xHi1) - (xLo2 + xHi2)) / 2;
    result.verticalCenterDistance = abs((yLo1 + yHi1) - (yLo2 + yHi2)) / 2;
    return result;
}

// ищем привязку таблички к знаку
std::vector<Relation> tablesToSignsRelations(
    const DetectedSigns& signs,
    const std::vector<size_t>& tablesIndices,
    const std::vector<size_t>& signsIndices)
{
    std::vector<Relation> relations;
    for (size_t i = 0; i < tablesIndices.size(); i++) {
        const cv::Rect rcTable = signs[tablesIndices[i]].box;

        std::vector<Relation> verticalRelations;
        std::vector<Relation> horizontalRelations;
        for (size_t j = 0; j < signsIndices.size(); j++) {
            const cv::Rect rcSign = signs[signsIndices[j]].box;
            const ObjectsDistance objDist = calcObjectsDistance(rcTable, rcSign);
            const double widthMax  = MULT_KOEF * std::max(rcTable.width, rcSign.width);
            const double heightMax = MULT_KOEF * std::max(rcTable.height, rcSign.height);
            if ((0 == objDist.horizontalBoxDistance) &&
                (objDist.verticalBoxDistance <= heightMax) &&
                (rcTable.y + rcTable.height / 2 > rcSign.y + rcSign.height / 2))
            {
                verticalRelations.emplace_back(
                    tablesIndices[i],
                    signsIndices[j],
                    objDist.horizontalCenterDistance + objDist.verticalBoxDistance);
            }
            if ((0 == objDist.verticalBoxDistance) &&
                (objDist.horizontalBoxDistance <= widthMax))
            {
                horizontalRelations.emplace_back(
                    tablesIndices[i],
                    signsIndices[j],
                    objDist.verticalCenterDistance + objDist.horizontalBoxDistance);
            }
        }
        if (0 < verticalRelations.size()) {
            std::vector<Relation>::iterator it =
                std::min_element(
                    verticalRelations.begin(), verticalRelations.end(),
                    [&](const Relation& a, const Relation& b) {
                        return a.objectsDistance < b.objectsDistance;
                    }
                );
            relations.emplace_back(*it);
        }
        else if (0 < horizontalRelations.size()) {
            std::vector<Relation>::iterator it =
                std::min_element(
                    horizontalRelations.begin(), horizontalRelations.end(),
                    [&](const Relation& a, const Relation& b) {
                        return a.objectsDistance < b.objectsDistance;
                    }
                );
            relations.emplace_back(*it);
        }
    }
    return relations;
}

std::map<size_t, size_t> makeTableToSignMap(const std::vector<Relation>& relations) {
    std::map<size_t, size_t> result;
    for (size_t i = 0; i < relations.size(); i++) {
        result[relations[i].firstObjectIndex] = relations[i].secondObjectIndex;
    }
    return result;
}

// ищем связь таблички (1), которая пока не привязана к знаку,
// с табличкой (2) которая уже привязана к знаку
void updateRelationsByIntermediateTable (
    const DetectedSigns& signs,
    const std::vector<size_t>& tablesIndices,
    std::vector<Relation>& relations)
{
    std::map<size_t, size_t> tableIndexToSignIndex = makeTableToSignMap(relations);

    for (size_t i1 = 0; i1 < tablesIndices.size(); i1++) {
        const size_t idx1 = tablesIndices[i1];
        auto it1 = tableIndexToSignIndex.find(idx1);
        if (it1 != tableIndexToSignIndex.end()) {
            continue;
        }

        const cv::Rect rc1 = signs[idx1].box;
        std::vector<Relation> verticalRelations;
        std::vector<Relation> horizontalRelations;
        for (size_t i2 = 0; i2 < tablesIndices.size(); i2++) {
            if (i1 == i2) {
                continue;
            }
            const size_t idx2 = tablesIndices[i2];
            auto it2 = tableIndexToSignIndex.find(idx2);
            if (it2 == tableIndexToSignIndex.end()) {
                continue;
            }
            const cv::Rect rc2 = signs[idx2].box;
            const ObjectsDistance objDist = calcObjectsDistance(rc1, rc2);
            const double widthMax  = MULT_KOEF * std::max(rc1.width, rc2.width);
            const double heightMax = MULT_KOEF * std::max(rc1.height, rc2.height);
            if ((0 == objDist.horizontalBoxDistance) &&
                (objDist.verticalBoxDistance <= heightMax) &&
                (rc1.y + rc1.height / 2 > rc2.y + rc2.height / 2))
            {
                verticalRelations.emplace_back(
                    idx1, it2->second,
                    objDist.horizontalCenterDistance + objDist.horizontalBoxDistance);
            }
            if ((0 == objDist.verticalBoxDistance) &&
                (objDist.horizontalBoxDistance <= widthMax))
            {
                horizontalRelations.emplace_back(
                    idx1, it2->second,
                    objDist.verticalCenterDistance + objDist.verticalBoxDistance);
            }
        }
        if (0 < verticalRelations.size()) {
            std::vector<Relation>::iterator it =
                std::min_element(
                    verticalRelations.begin(), verticalRelations.end(),
                    [&](const Relation& a, const Relation& b) {
                        return a.objectsDistance < b.objectsDistance;
                    }
                );
            relations.emplace_back(*it);
        }
        else if (0 < horizontalRelations.size()) {
            std::vector<Relation>::iterator it =
                std::min_element(
                    horizontalRelations.begin(), horizontalRelations.end(),
                    [&](const Relation& a, const Relation& b) {
                        return a.objectsDistance < b.objectsDistance;
                    }
                );
            relations.emplace_back(*it);
        }
    }
}

std::vector<std::pair<size_t, size_t>> convertRelationsToPairs(const std::vector<Relation>& relations)
{
    std::vector<std::pair<size_t, size_t>> result;
    for (size_t i = 0; i < relations.size(); i++) {
        result.emplace_back(relations[i].firstObjectIndex, relations[i].secondObjectIndex);
    }
    return result;
}

} // namespace

// Tries to bind every additional table in `signs` to an ordinary sign.
// Returns (tableIndex, signIndex) pairs of indices into `signs`; tables that could not be
// bound to any sign do not appear in the result.
std::vector<std::pair<size_t, size_t>> foundRelations(const DetectedSigns& signs) {
    const auto [tablesIndices, signsIndices] = splitTablesAndSigns(signs);

    // Pass 1: bind tables directly to nearby signs.
    std::vector<Relation> relations = tablesToSignsRelations(signs, tablesIndices, signsIndices);
    // Pass 2: let still-unbound tables inherit the sign of an adjacent bound table.
    updateRelationsByIntermediateTable(signs, tablesIndices, relations);

    return convertRelationsToPairs(relations);
}

} // namespace maps::mrc::signdetect
