#include "common.h"
#include "params.h"
#include "verification.h"
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/object.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/clusters.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/greedy_merge.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/passage.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/merge_candidates.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/collision.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/common.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/include/object_manager.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/db.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/match_candidates.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/store_utils.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/object_gateway.h>

#include <algorithm>

namespace maps::mrc::eye {

namespace {

// Collects detection-pair match candidates for every passage.
// Candidates are generated only inside a passage (passage x passage),
// never across passages; the per-passage results are unioned together.
DetectionIdPairSet makeMatchCandidates(
    const DetectionStore& store,
    const std::vector<db::TIdSet>& passages)
{
    DetectionIdPairSet allCandidates;
    for (const auto& passage : passages) {
        const DetectionIdPairSet candidates =
            generateMatchCandidates(store, passage, passage, MATCH_CANDIDATE_PARAMS);
        allCandidates.insert(candidates.begin(), candidates.end());
    }
    return allCandidates;
}

std::vector<MatchedFrameDetections> splitMatchesByPassages(
    const MatchedFrameDetections& matches,
    const std::vector<db::TIdSet>& passages)
{
    std::vector<MatchedFrameDetections> matchesByPassages(passages.size());

    db::IdTo<std::set<size_t>> detectionIdToPassageIndices;
    for (size_t passageIndx = 0; passageIndx < passages.size(); passageIndx++) {
        for (db::TId detectionId : passages[passageIndx]) {
            detectionIdToPassageIndices[detectionId].insert(passageIndx);
        }
    }

    for (const MatchedFrameDetection& match : matches) {
        const std::set<size_t>& passageIndices0
            = detectionIdToPassageIndices.at(match.id0().detectionId);
        const std::set<size_t>& passageIndices1
            = detectionIdToPassageIndices.at(match.id1().detectionId);

        std::vector<size_t> passageIndices;
        std::set_intersection(
            passageIndices0.begin(), passageIndices0.end(),
            passageIndices1.begin(), passageIndices1.end(),
            std::back_inserter(passageIndices)
        );

        for (size_t passageIndx : passageIndices) {
            matchesByPassages[passageIndx].push_back(match);
        }
    }

    return matchesByPassages;
}

// Combines two optional verdicts into the "stronger" one:
//  - equal inputs (including both empty) are returned unchanged;
//  - a Verdict::No on either side wins;
//  - otherwise the present verdict is returned, preferring `one`;
//  - if neither is present, the result is empty.
std::optional<MatchedFrameDetection::Verdict>
selectStrongerVerdict(
    const std::optional<MatchedFrameDetection::Verdict>& one,
    const std::optional<MatchedFrameDetection::Verdict>& other
)
{
    using Verdict = MatchedFrameDetection::Verdict;

    if (one == other) {
        return one;
    }

    const bool anyNegative = (one == Verdict::No) || (other == Verdict::No);
    if (anyNegative) {
        return Verdict::No;
    }

    return one.has_value() ? one : other;
}

} // namespace

std::vector<ObjectsInPassage> makeObjectsByPassages(
    const DetectionStore& store,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    const DetectionClusterizer& clusterizer,
    const std::vector<db::TIdSet>& passages)
{
    db::TIdSet initDetectionIds;
    for (const auto& detectionIds : passages) {
        initDetectionIds.insert(detectionIds.begin(), detectionIds.end());
    };

    std::vector<MatchedFrameDetections> matchesByPassages
        = splitMatchesByPassages(
            detectionMatcher.makeMatches(store, makeMatchCandidates(store, passages), &frameMatcher),
            passages
        );
    ASSERT(passages.size() == matchesByPassages.size());

    std::vector<ObjectsInPassage> objectsByPassages;
    for (size_t i = 0; i < passages.size(); i++) {
        const MatchedFrameDetections& matches = matchesByPassages[i];

        std::vector<db::TIdSet> clusters;
        if (!matches.empty()) {
            clusters = clusterizer.clusterize(store, passages[i], matches);
        } else {
            for (db::TId detectionId : passages[i]) {
                clusters.push_back({detectionId});
            }
        }

        removeUntouchedClusters(&clusters, initDetectionIds);

        ObjectsInPassage result;
        for (db::TIdSet& cluster : clusters) {
            db::TId primaryId = choosePrimaryId(store, cluster);
            result.objectByPrimaryId.emplace(primaryId, makeObject(store, primaryId));
            result.locationByPrimaryId.emplace(primaryId, makeObjectLocation(store, cluster));
            result.detectionIdsByPrimaryId[primaryId] = std::move(cluster);
        }
        objectsByPassages.push_back(std::move(result));
    }

    return objectsByPassages;
}


namespace {

// Loads the objects neighboring the detections of detectionStore (within
// OBJECT_VISIBILITY_DISTANCE_METERS). Objects that disappeared strictly before
// actualAt are dropped. The survivors and the store's detection ids are then
// added to a fresh ObjectStore.
ObjectStore makeObjectStore(
    pqxx::transaction_base& txn,
    const DetectionStore& detectionStore,
    chrono::TimePoint actualAt)
{
    auto objects = loadNeighboringObjects(
        txn, detectionStore, OBJECT_VISIBILITY_DISTANCE_METERS);

    const auto disappearedBeforeActualAt = [&](const auto& object) {
        return object.disappearedAt().has_value() &&
            object.disappearedAt().value() < actualAt;
    };
    const auto firstRemoved =
        std::remove_if(objects.begin(), objects.end(), disappearedBeforeActualAt);
    objects.erase(firstRemoved, objects.end());

    ObjectStore store;
    store.extendByObjects(txn, std::move(objects));
    store.extendByDetectionIds(txn, detectionStore.detectionIds());
    return store;
}

// For every merge candidate (a pair of (primaryId, passageIndx) objects),
// selects the detection sets that will be compared. Each set is produced by
// selectBiggestDetectionIds from the object's detections, bounded by
// detectionSetMaxSize (presumably picking the largest detections — see
// selectBiggestDetectionIds). Duplicate candidates overwrite earlier entries.
std::map<ObjectPassageIndxPair, std::pair<db::TIdSet, db::TIdSet>>
generateDetectionIdSetsForComparison(
    const DetectionStore& detectionStore,
    const std::vector<ObjectPassageIndxPair>& mergeCandidates,
    const std::vector<ObjectsInPassage>& objectsByPassages,
    size_t detectionSetMaxSize)
{
    const auto selectForObject = [&](const auto& objectInPassage) {
        const auto& [primaryId, passageIndx] = objectInPassage;
        return selectBiggestDetectionIds(
            detectionStore,
            objectsByPassages[passageIndx].detectionIdsByPrimaryId.at(primaryId),
            detectionSetMaxSize
        );
    };

    std::map<ObjectPassageIndxPair, std::pair<db::TIdSet, db::TIdSet>> setsByCandidate;
    for (const auto& candidate : mergeCandidates) {
        setsByCandidate[candidate] = std::make_pair(
            selectForObject(candidate.first),
            selectForObject(candidate.second)
        );
    }
    return setsByCandidate;
}

// Matches the detection sets of every merge candidate and groups the resulting
// matches by candidate.
//
// For each candidate, all cross pairs (d1 from the first set, d2 from the
// second, d1 != d2) are matched in a single batch. A match whose relevance is
// <= matchRelevanceThreshold is dropped. A match reported in reverse order
// (d2, d1) is reversed before being attributed, so stored matches are always
// oriented first-object -> second-object.
//
// Passages may overlap, so the same detection can belong to objects of several
// passages. One detection pair can therefore come from several candidates, and
// every such candidate receives the match. Previously, only the last candidate
// written into the lookup map kept it.
std::map<ObjectPassageIndxPair, MatchedFrameDetections>
makeMatchesByObjectPassageIndxPair(
    const DetectionStore& store,
    const std::map<ObjectPassageIndxPair, std::pair<db::TIdSet, db::TIdSet>>& objectPassageIndxPairToDetectionIdSets,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    float matchRelevanceThreshold)
{
    std::map<DetectionIdPair, std::vector<ObjectPassageIndxPair>> detectionIdPairToObjectPassageIndxPairs;
    DetectionIdPairSet detectionPairs;
    for (const auto& [objectPassageIndxPair, detectionSets] : objectPassageIndxPairToDetectionIdSets) {
        const db::TIdSet& detectionIds1 = detectionSets.first;
        const db::TIdSet& detectionIds2 = detectionSets.second;

        for (db::TId detectionId1 : detectionIds1) {
            for (db::TId detectionId2 : detectionIds2) {
                if (detectionId1 == detectionId2) {
                    continue;
                }

                DetectionIdPair detectionIdPair{detectionId1, detectionId2};
                detectionPairs.insert(detectionIdPair);
                detectionIdPairToObjectPassageIndxPairs[detectionIdPair].push_back(objectPassageIndxPair);
            }
        }
    }

    MatchedFrameDetections matches = detectionMatcher.makeMatches(store, detectionPairs, &frameMatcher);

    std::map<ObjectPassageIndxPair, MatchedFrameDetections> objectPassageIndxPairToMatches;
    for (MatchedFrameDetection& match : matches) {
        if (match.relevance() <= matchRelevanceThreshold) {
            continue;
        }

        // The matcher may report the pair in the opposite order; attribute
        // the reversed match to the candidates that requested (id1, id0).
        auto it = detectionIdPairToObjectPassageIndxPairs.find({match.id1().detectionId, match.id0().detectionId});
        if (detectionIdPairToObjectPassageIndxPairs.end() != it) {
            const MatchedFrameDetection reversed = reverseMatch(match);
            for (const ObjectPassageIndxPair& owner : it->second) {
                objectPassageIndxPairToMatches[owner].push_back(reversed);
            }
        }

        it = detectionIdPairToObjectPassageIndxPairs.find({match.id0().detectionId, match.id1().detectionId});
        if (detectionIdPairToObjectPassageIndxPairs.end() != it) {
            for (const ObjectPassageIndxPair& owner : it->second) {
                objectPassageIndxPairToMatches[owner].push_back(match);
            }
        }
    }

    return objectPassageIndxPairToMatches;
}


std::vector<MatchedObjects> sortObjectsMatches(
    const std::map<ObjectPassageIndxPair, MatchedFrameDetections>& objectPassageIndxPairToMatches)
{
    std::vector<MatchedObjects> objectsMatches;
    for (const auto& [objectPassageIndxPair, matches] : objectPassageIndxPairToMatches) {
        float avgMatchRelevance = 0.;
        std::optional<MatchedFrameDetection::Verdict> verdict;
        for (const MatchedFrameDetection& match : matches) {
            avgMatchRelevance += match.relevance();
            verdict = selectStrongerVerdict(verdict, match.verdict());
        }
        avgMatchRelevance /= matches.size();

        objectsMatches.push_back({
            objectPassageIndxPair.first,
            objectPassageIndxPair.second,
            avgMatchRelevance,
            verdict
        });
    }

    auto cmpKey = [](const auto& objectMatch) -> auto {
        return RelevanceWithVerdict{
            .relevance = objectMatch.relevance,
            .verdict = objectMatch.verdict
        };
    };
    std::sort(objectsMatches.begin(), objectsMatches.end(),
        [&](const auto& lhs, const auto& rhs) {
            using namespace std::rel_ops;
            return cmpKey(lhs) > cmpKey(rhs);
        }
    );

    return objectsMatches;
}

// Picks a primary detection id for each object's detection set.
// isPrimaryDetectionId is forwarded to choosePrimaryId. Returns the sets keyed
// by the chosen primary id; the input sets are moved into the result, and a
// later set overwrites an earlier one with the same primary id.
db::IdTo<db::TIdSet> choosePrimaryIds(
    const DetectionStore& detectionStore,
    std::vector<db::TIdSet> objectsDetectionIds,
    std::function<bool(db::TId)> isPrimaryDetectionId)
{
    db::IdTo<db::TIdSet> setsByPrimaryId;
    for (size_t objectIndx = 0; objectIndx < objectsDetectionIds.size(); ++objectIndx) {
        db::TIdSet& objectDetectionIds = objectsDetectionIds[objectIndx];
        const db::TId primaryId =
            choosePrimaryId(detectionStore, objectDetectionIds, isPrimaryDetectionId);
        setsByPrimaryId[primaryId] = std::move(objectDetectionIds);
    }
    return setsByPrimaryId;
}

} // namespace

// Wraps every live object from objectStore into a single ObjectsInPassage,
// keyed by the object's primary detection id. Each entry carries the object,
// its stored location and its live detections (primary + related ones).
// Deleted objects, and objects whose primary detection is missing from
// detectionStore or deleted, are skipped. Related detections that are missing
// or deleted are left out.
ObjectsInPassage makeFakeObjectsInPassage(
    const DetectionStore& detectionStore,
    const ObjectStore& objectStore)
{
    // True when the detection is loaded into detectionStore and not deleted.
    const auto isLiveDetection = [&detectionStore](db::TId detectionId) {
        return detectionStore.hasDetection(detectionId)
            && !detectionStore.detectionById(detectionId).deleted();
    };

    ObjectsInPassage fakePassage;
    const auto& relationMap = objectStore.relationMap();

    for (const auto& object : objectStore.objects()) {
        if (object.deleted()) {
            continue;
        }

        const db::TId primaryId = object.primaryDetectionId();
        if (!isLiveDetection(primaryId)) {
            // The object's primary detection is deleted, but the object itself
            // is not yet. This means the object will be deleted soon and all of
            // its other detections will be re-clustered, so it can be ignored now.
            continue;
        }

        // NOTE(review): relationMap.at() presumably throws when primaryId has
        // no entry — confirm that every live object is present in relationMap.
        db::TIdSet detectionIds{primaryId};
        for (db::TId relatedId : relationMap.at(primaryId)) {
            // A deleted detection may still be linked to the object if the
            // corresponding primary_detection_relation has not been deleted
            // yet. That relation will be deleted soon, so such a detection
            // can be ignored.
            if (isLiveDetection(relatedId)) {
                detectionIds.insert(relatedId);
            }
        }

        const auto& storedLocation = objectStore.locationByObjectId(object.id());
        const Location location{
            storedLocation.mercatorPos(),
            storedLocation.rotation()
        };

        fakePassage.objectByPrimaryId.emplace(primaryId, object);
        fakePassage.locationByPrimaryId[primaryId] = location;
        fakePassage.detectionIdsByPrimaryId[primaryId] = std::move(detectionIds);
    }

    return fakePassage;
}

// Loads the existing objects relevant at actualAt into a new ObjectStore. It
// then extends detectionStore with the detections referenced by those objects'
// relations, logs the sizes of both stores, and returns the object store.
ObjectStore
loadExistingObjectsToMatch(
    pqxx::transaction_base& txn,
    DetectionStore& detectionStore,
    chrono::TimePoint actualAt
)
{
    INFO() << "Loading existing objects";
    ObjectStore objectStore = makeObjectStore(txn, detectionStore, actualAt);
    detectionStore.extendByObjectRelations(txn, objectStore.relationMap());

    const auto detectionsCount = detectionStore.detectionIds().size();
    const auto framesCount = detectionStore.frameById().size();
    INFO() << detectionsCount << " detections in detection store";
    INFO() << framesCount << " frames in detection store";
    INFO() << objectStore.objectsCount() << " objects in object store";

    return objectStore;
}


std::vector<MatchedObjects> matchObjectsByPassages(
    const DetectionStore& detectionStore,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    const std::vector<ObjectsInPassage>& objectsByPassages
)
{
    INFO() << "Generating candidates for merging";
    auto mergeCandidates = generateMergeCandidates(objectsByPassages, MERGE_CANDIDATES_PARAMS);
    INFO() << mergeCandidates.size() << " candidates has been generated";

    // Генерируем пары детекций, по которым будем сравнивать объекты
    static constexpr size_t DETECTION_SET_MAX_SIZE = 5;
    INFO() << "Generating detection sets for matching";
    const auto mergeCandidatesToDetectionIdSets
        = generateDetectionIdSetsForComparison(
            detectionStore, mergeCandidates,
            objectsByPassages,
            DETECTION_SET_MAX_SIZE
        );
    INFO() << mergeCandidatesToDetectionIdSets.size() << " candidates left";

    // Находим матчи между двумя множествами детекций
    INFO() << "Making matches for merging candidates";
    static constexpr double RELEVANCE_THRESHOLD = 0.000005;
    std::map<ObjectPassageIndxPair, MatchedFrameDetections> mergeCandidatesToMatches
        = makeMatchesByObjectPassageIndxPair(
            detectionStore,
            mergeCandidatesToDetectionIdSets,
            frameMatcher,
            detectionMatcher,
            RELEVANCE_THRESHOLD
        );

    // сортируем по некоторой логике, основанной на матчах между наборами детекций
    INFO() << "Sorting candidates for merging by matches";
    std::vector<MatchedObjects> sortedObjectsMatches
        = sortObjectsMatches(mergeCandidatesToMatches);
    return sortedObjectsMatches;
}


// Merges objects of different passages and returns the resulting objects'
// detection sets keyed by their chosen primary detection id.
// Steps: greedy merging over the sorted candidate matches (subject to
// areAllowedToMerge), then resolveCollisions, then primary detection selection.
// During selection, ObjectStore::isPrimaryDetectionId is passed to
// choosePrimaryIds.
db::IdTo<db::TIdSet> mergeObjectsByPassages(
    const std::vector<MatchedObjects>& sortedObjectsMatches,
    const DetectionStore& detectionStore,
    const ObjectStore& objectStore,
    const std::vector<ObjectsInPassage>& objectsByPassages,
    const FrameMatcher& frameMatcher,
    const DetectionMatcher& detectionMatcher,
    std::function<bool(const db::TIdSet&, const db::TIdSet&)> areAllowedToMerge)
{
    INFO() << "Merging objects";
    std::vector<db::TIdSet> objectsDetectionIds = greedyObjectsMerging(
        sortedObjectsMatches,
        objectsByPassages,
        areAllowedToMerge);

    // Each detection must belong to only one object.
    INFO() << "Resolving collisions";
    objectsDetectionIds = resolveCollisions(
        detectionStore,
        frameMatcher,
        detectionMatcher,
        std::move(objectsDetectionIds));

    // Choose the primary detections.
    INFO() << "Selecting primary detections for new objects";
    return choosePrimaryIds(
        detectionStore,
        std::move(objectsDetectionIds),
        [&objectStore](db::TId detectionId) {
            return objectStore.isPrimaryDetectionId(detectionId);
        });
}

} // namespace maps::mrc::eye
