#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/collision.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/common.h>

#include <maps/libs/log8/include/log8.h>

namespace maps::mrc::eye {

db::IdTo<std::vector<size_t>> findCollisions(
    const std::vector<db::TIdSet>& objectsDetectionIds)
{
    db::IdTo<std::vector<size_t>> detectionIdToObjectIndxs;
    for (size_t objectIndx = 0; objectIndx < objectsDetectionIds.size(); objectIndx++) {
        for (db::TId detectionId : objectsDetectionIds[objectIndx]) {
            detectionIdToObjectIndxs[detectionId].push_back(objectIndx);
        }
    }

    for (auto it = detectionIdToObjectIndxs.begin();
         it != detectionIdToObjectIndxs.end();)
    {
        if (it->second.size() > 1) {
            it++;
            continue;
        }
        it = detectionIdToObjectIndxs.erase(it);
    }

    return detectionIdToObjectIndxs;
}

namespace {

// Выбираем множество представительных детекций для объектов,
// которые участвуют в коллизиях
//
// Выход:
//   * std::unordered_map<size_t, db::TIdSet> - отображение индекса
//     объекта в множество id представительных детекций объекта.
//     В качестве представительных детекций выбирается на более
//     N (фиксированный порог) наибольших детекций объекта, которые
//     не участвуют в коллизиях
//   * db::IdTo<size_t> - отображение detectionId представительной
//     детекции в индекс объекта, который она представляет, т.е.
//     это обратное отображение к первому отображению из результата.
std::pair<std::unordered_map<size_t, db::TIdSet>, db::IdTo<size_t>>
selectObjectsDetectionSets(
    const DetectionStore& detectionStore,
    const std::vector<db::TIdSet>& objectsDetectionIds,
    const db::IdTo<std::vector<size_t>>& collisions)
{
    static const size_t DETECTION_SET_MAX_SIZE = 20;

    db::TIdSet collisionDetectionIds;
    for (const auto& [detectionId, _] : collisions) {
        collisionDetectionIds.insert(detectionId);
    }

    std::unordered_map<size_t, db::TIdSet> objectIndxToDetectionIds;
    db::IdTo<size_t> detectionIdToObjectIndx;

    for (const auto& [_, objectIndxs] : collisions) {
        for (const size_t objectIndx : objectIndxs) {
            if (objectIndxToDetectionIds.count(objectIndx) != 0) {
                continue;
            }

            db::TIdSet detectionIds = objectsDetectionIds.at(objectIndx);
            for (db::TId detectionId : collisionDetectionIds) {
                detectionIds.erase(detectionId);
            }

            if (detectionIds.empty()) {
                continue;
            }

            detectionIds = selectBiggestDetectionIds(
                detectionStore,
                detectionIds,
                DETECTION_SET_MAX_SIZE
            );

            for (db::TId detectionId : detectionIds) {
                detectionIdToObjectIndx[detectionId] = objectIndx;
            }
            objectIndxToDetectionIds[objectIndx] = std::move(detectionIds);
        }
    }

    return {objectIndxToDetectionIds, detectionIdToObjectIndx};
}

// Находит пары детекций, которые надо проверить, чтобы разрешить
// коллизии.
// Пара детекций добавляется в результат по следующему принципу:
//   * Первая детекций в паре - детекций, с которой есть коллизия
//   * Вторая детекция в паре - детекция из набора представительных
//   детекций объекта, который участвует в коллизии с первой детекцией
//   из пары детекций.
DetectionIdPairSet makeDetectionPairs(
    const db::IdTo<std::vector<size_t>>& collisions,
    const std::unordered_map<size_t, db::TIdSet>& objectIndxToDetectionIds)
{
    DetectionIdPairSet pairs;
    for (const auto& [detectionId0, objectIndxs] : collisions) {
        for (const size_t objectIndx : objectIndxs) {
            auto detectionIdsIt = objectIndxToDetectionIds.find(objectIndx);
            if (objectIndxToDetectionIds.end() == detectionIdsIt) {
                continue;
            }

            for (db::TId detectionId1 : detectionIdsIt->second) {
                pairs.emplace(detectionId0, detectionId1);
            }
        }
    }

    return pairs;
}

// Вычисляет уверенность принадлежности детекций, у которых есть
// коллизии, ко всем объектам из коллизии.
//
//   * db::IdTo<std::unordered_map<size_t, double>> - отобажение
//   со следующими значениями:
//     detectionId -> objectIndx -> score
//   Т.е. по id детекции находим отображение индекса объекта в
//   уверенность принадлежности к этому объекту.
db::IdTo<std::unordered_map<size_t, double>> makeCollisionsScores(
    const MatchedFrameDetections& matches,
    const db::IdTo<std::vector<size_t>>& collisions,
    const db::IdTo<size_t>& detectionIdToObjectIndx)
{
    db::IdTo<std::unordered_map<size_t, std::vector<double>>> collisionsScoresVec;
    for (const MatchedFrameDetection& match : matches) {
        db::TId collisionDetectionId;
        db::TId detectionId;

        if (collisions.count(match.id0().detectionId) != 0) {
            collisionDetectionId = match.id0().detectionId;
            detectionId = match.id1().detectionId;
        } else {
            collisionDetectionId = match.id1().detectionId;
            detectionId = match.id0().detectionId;
        }

        size_t objectIndx = detectionIdToObjectIndx.at(detectionId);
        collisionsScoresVec[collisionDetectionId][objectIndx].push_back(match.relevance());
    }


    db::IdTo<std::unordered_map<size_t, double>> collisionsAvgScores;
    for (const auto& [detectionId, objectIdToScores] : collisionsScoresVec) {
        for (const auto& [objectId, scores] : objectIdToScores) {
            double avgScore = 0.;
            for (double score : scores) {
                avgScore += score;
            }
            avgScore /= scores.size();

            collisionsAvgScores[detectionId][objectId] = avgScore;
        }
    }

    return collisionsAvgScores;
}

// Выбираем к какому объекту привязать детекцию из коллизии.
// Если для детекции нашлись матчи хотя бы с одним из объектов,
// то среди всех объектов выбираем тот, уверенность принадлежности
// к которому самая большая. Если матчей не нашлось, то выбираем
// объект, который включает в себя наибольшее число детекций.
db::IdTo<size_t> findCollisionSolutions(
    const std::vector<db::TIdSet>& objectsDetectionIds,
    const db::IdTo<std::vector<size_t>>& collisions,
    const db::IdTo<std::unordered_map<size_t, double>>& collisionsScores)
{
    db::IdTo<size_t> collisionSolutions;
    for (const auto& [detectionId, objectIndxs] : collisions) {
        auto scoresIt = collisionsScores.find(detectionId);
        if (collisionsScores.end() != scoresIt && !scoresIt->second.empty()) {
            const auto& collisionScores = scoresIt->second;

            size_t bestObjectIndx = collisionScores.begin()->first;
            double bestScore = collisionScores.begin()->second;

            for (const auto& [objectIndx, score] : collisionScores) {
                if (score > bestScore) {
                    bestScore = score;
                    bestObjectIndx = objectIndx;
                }
            }

            collisionSolutions[detectionId] = bestObjectIndx;
        } else {
            size_t bestObjectIndx = objectIndxs.front();
            size_t bestObjectSize = objectsDetectionIds[bestObjectIndx].size();

            for (size_t objectIndx : objectIndxs) {
                size_t objectSize = objectsDetectionIds[objectIndx].size();
                if (objectSize > bestObjectSize) {
                    bestObjectSize = objectSize;
                    bestObjectIndx = objectIndx;
                }
            }

            collisionSolutions[detectionId] = bestObjectIndx;
        }
    }

    return collisionSolutions;
}

} // namespace

db::IdTo<size_t> findCollisionSolutions(
    const DetectionStore& detectionStore,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    const std::vector<db::TIdSet>& objectsDetectionIds,
    const db::IdTo<std::vector<size_t>>& collisions)
{
    const auto [objectIndxToDetectionIds, detectionIdToObjectIndx]
        = selectObjectsDetectionSets(detectionStore, objectsDetectionIds, collisions);
    DetectionIdPairSet pairs = makeDetectionPairs(collisions, objectIndxToDetectionIds);
    MatchedFrameDetections matches = detectionMatcher.makeMatches(detectionStore, pairs, &frameMatcher);
    db::IdTo<std::unordered_map<size_t, double>> collisionsScores
        = makeCollisionsScores(matches, collisions, detectionIdToObjectIndx);

    return findCollisionSolutions(
        objectsDetectionIds,
        collisions, collisionsScores
    );
}

std::vector<db::TIdSet> applyCollisionSolutions(
    std::vector<db::TIdSet> objectsDetectionIds,
    const db::IdTo<std::vector<size_t>>& collisions,
    const db::IdTo<size_t>& collisionSolutions)
{
    for (const auto& [detectionId, bestObjectIndx] : collisionSolutions) {
        for (size_t objectIndx : collisions.at(detectionId)) {
            if (bestObjectIndx != objectIndx) {
                objectsDetectionIds[objectIndx].erase(detectionId);
            }
        }
    }

    objectsDetectionIds.erase(
        std::remove_if(objectsDetectionIds.begin(), objectsDetectionIds.end(),
            [](const auto& detectionIds) {
                return detectionIds.empty();
            }
        ),
        objectsDetectionIds.end()
    );

    return objectsDetectionIds;
}

std::vector<db::TIdSet> resolveCollisions(
    const DetectionStore& store,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    std::vector<db::TIdSet> objectsDetectionIds)
{
    db::IdTo<std::vector<size_t>> collisions
        = findCollisions(objectsDetectionIds);
    INFO() << "There are " << collisions.size() << " collisions";

    db::IdTo<size_t> collisionSolutions
        = findCollisionSolutions(
            store, frameMatcher, detectionMatcher,
            objectsDetectionIds, collisions
        );

    return applyCollisionSolutions(
        std::move(objectsDetectionIds),
        collisions, collisionSolutions
    );
}

} // namespace maps::mrc::eye
