#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/impl/cluster_store.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/utils.h>

namespace maps::mrc::eye {

namespace {


// Returns true when `lhs` and `rhs` share at least one FrameDetectionId.
// Both sets are walked in lockstep, advancing whichever side holds the smaller
// (frameId, detectionId) pair, so the cost is linear in the total size.
// NOTE(review): this is only correct if FrameDetectionIdSet iterates in
// lexicographic (frameId, detectionId) order — confirm against its comparator.
bool overlap(const FrameDetectionIdSet& lhs, const FrameDetectionIdSet& rhs)
{
    auto left = lhs.begin();
    auto right = rhs.begin();
    const auto leftEnd = lhs.end();
    const auto rightEnd = rhs.end();

    while (left != leftEnd && right != rightEnd) {
        const bool sameFrame = left->frameId == right->frameId;
        if (sameFrame && left->detectionId == right->detectionId) {
            return true;
        }
        // Frame ids decide first; detection ids only break ties within a frame.
        const bool leftIsSmaller = sameFrame
            ? left->detectionId < right->detectionId
            : left->frameId < right->frameId;
        if (leftIsSmaller) {
            ++left;
        } else {
            ++right;
        }
    }
    return false;
}

// Decides whether two clusters can be joined. They must not share a frame (as
// reported by hasCommonFrame), and neither may contain a detection listed in
// the other's forbidden set.
bool mayBeMerged(const Cluster& cluster0, const Cluster& cluster1)
{
    if (hasCommonFrame(cluster0.detectionIds, cluster1.detectionIds)) {
        return false;
    }
    if (overlap(cluster0.detectionIds, cluster1.forbiddenDetectionIds)) {
        return false;
    }
    return !overlap(cluster1.detectionIds, cluster0.forbiddenDetectionIds);
}

} // namespace

// Looks up the index entry for `detectionId`. The entry's value is the
// iterator of the cluster holding that detection. Returns end() when the
// detection is not in any cluster.
ClusterStore::ConstIterator ClusterStore::find(db::TId detectionId) const
{
    const auto& index = clusterByDetectionId_;
    return index.find(detectionId);
}

// First entry of the detection-id -> cluster index.
ClusterStore::ConstIterator ClusterStore::begin() const
{
    const auto& index = clusterByDetectionId_;
    return index.begin();
}

// Past-the-end entry of the detection-id -> cluster index; find() returns
// this value for unknown detection ids.
ClusterStore::ConstIterator ClusterStore::end() const
{
    const auto& index = clusterByDetectionId_;
    return index.end();
}

// Creates a new cluster at the front of clusters_ that holds exactly
// `frameDetectionIds` and forbids `forbiddenDetectionIds`, then points each
// of its detections at it in the index. None of the detections may already
// belong to a cluster (enforced with ASSERT).
void ClusterStore::createNewCluster(FrameDetectionIdSet frameDetectionIds, FrameDetectionIdSet forbiddenDetectionIds)
{
    for (const auto& fdId : frameDetectionIds) {
        ASSERT(clusterByDetectionId_.find(fdId.detectionId) == clusterByDetectionId_.end());
    }
    // Both sets are owned by this call, so move them into the cluster instead
    // of copying; the index is then filled from the stored set.
    clusters_.push_front({
        .detectionIds = std::move(frameDetectionIds),
        .forbiddenDetectionIds = std::move(forbiddenDetectionIds)});
    const auto newClusterIt = clusters_.begin();
    for (const auto& fdId : newClusterIt->detectionIds) {
        clusterByDetectionId_[fdId.detectionId] = newClusterIt;
    }
}

// Adds the single detection `fdId` to the cluster at `clusterIt`. This is
// refused when hasCommonFrame reports a frame clash with the cluster, or when
// the cluster forbids `fdId`. Returns true iff the detection was added and
// indexed.
bool ClusterStore::tryAddNewElementInCluster(
    ClusterIter clusterIt,
    const FrameDetectionId& fdId)
{
    auto& cluster = *clusterIt;
    const bool accepted = !hasCommonFrame(fdId, cluster.detectionIds)
        && cluster.forbiddenDetectionIds.count(fdId) == 0;
    if (!accepted) {
        return false;
    }
    clusterByDetectionId_[fdId.detectionId] = clusterIt;
    cluster.detectionIds.insert(fdId);
    return true;
}

// Adds both detections `fdId0` and `fdId1` to the cluster at `clusterIt`, or
// neither. The pair is refused when:
//  * the two detections are on the same frame;
//  * either one clashes on frame with the cluster (per hasCommonFrame);
//  * either one is forbidden for the cluster.
// Returns true iff both detections were added and indexed.
bool ClusterStore::tryAddNewElementInCluster(
    ClusterIter clusterIt,
    const FrameDetectionId& fdId0,
    const FrameDetectionId& fdId1)
{
    // The per-cluster frame checks below only compare each new detection with
    // the cluster's existing members. Without this check, two detections of
    // one frame would both be inserted and break the "one detection per frame"
    // rule that hasCommonFrame enforces elsewhere.
    if (fdId0.frameId == fdId1.frameId) {
        return false;
    }
    if (hasCommonFrame(fdId0, clusterIt->detectionIds) ||
        hasCommonFrame(fdId1, clusterIt->detectionIds))
    {
        return false;
    }
    if (clusterIt->forbiddenDetectionIds.count(fdId0) ||
        clusterIt->forbiddenDetectionIds.count(fdId1))
    {
        return false;
    }

    clusterByDetectionId_[fdId0.detectionId] = clusterIt;
    clusterIt->detectionIds.insert(fdId0);
    clusterByDetectionId_[fdId1.detectionId] = clusterIt;
    clusterIt->detectionIds.insert(fdId1);
    return true;
}

// Unconditionally folds the cluster at `clusterIt1` into the one at
// `clusterIt0`. Every moved detection is re-pointed to `clusterIt0` in the
// index, and both the detection and forbidden sets are merged. The emptied
// cluster is then erased from clusters_, so `clusterIt1` must not be used
// afterwards.
void ClusterStore::mergeTwoClusters(
    ClusterIter clusterIt0,
    ClusterIter clusterIt1)
{
    auto& target = *clusterIt0;
    auto& source = *clusterIt1;

    // Re-point the index before merge() moves nodes out of `source`.
    for (const auto& fdId : source.detectionIds) {
        clusterByDetectionId_[fdId.detectionId] = clusterIt0;
    }
    target.detectionIds.merge(source.detectionIds);
    target.forbiddenDetectionIds.merge(source.forbiddenDetectionIds);

    clusters_.erase(clusterIt1);
}

// Merges the cluster at `clusterIt1` into the one at `clusterIt0` when
// mayBeMerged() allows it. Passing the same cluster twice is a no-op that
// counts as success. Returns false only when the merge was refused.
bool ClusterStore::tryMergeTwoClusters(ClusterIter clusterIt0, ClusterIter clusterIt1)
{
    const bool sameCluster = clusterIt0 == clusterIt1;
    if (!sameCluster) {
        if (!mayBeMerged(*clusterIt0, *clusterIt1)) {
            return false;
        }
        mergeTwoClusters(clusterIt0, clusterIt1);
    }
    return true;
}

// Merges up to three clusters into the one at `clusterIt0`. Some iterators may
// point at the same cluster. Every pair of distinct clusters must pass
// mayBeMerged(); otherwise nothing is merged and false is returned.
bool ClusterStore::tryMergeThreeClusters(
    ClusterIter clusterIt0,
    ClusterIter clusterIt1,
    ClusterIter clusterIt2)
{
    // Compare iterators up front: after the first merge, clusterIt1 is erased
    // and can no longer be compared.
    const bool same01 = clusterIt0 == clusterIt1;
    const bool same02 = clusterIt0 == clusterIt2;
    const bool same12 = clusterIt1 == clusterIt2;

    if (same01 && same02) {
        return true;
    }

    const bool allowed =
        (same01 || mayBeMerged(*clusterIt0, *clusterIt1)) &&
        (same02 || mayBeMerged(*clusterIt0, *clusterIt2)) &&
        (same12 || mayBeMerged(*clusterIt1, *clusterIt2));
    if (!allowed) {
        return false;
    }

    if (!same01) {
        mergeTwoClusters(clusterIt0, clusterIt1);
    }
    // When clusterIt2 equals clusterIt1 it was already absorbed above; when it
    // equals clusterIt0 there is nothing to do.
    if (!same02 && !same12) {
        mergeTwoClusters(clusterIt0, clusterIt2);
    }
    return true;
}

// Adds the new detection `fdId` and merges the clusters at `clusterIt0` and
// `clusterIt1` into the one at `clusterIt0`, or does nothing. The operation is
// refused when:
//  * the two clusters are not mergeable (mayBeMerged);
//  * `fdId` clashes on frame with either cluster (hasCommonFrame);
//  * `fdId` is forbidden by either cluster.
// If both iterators point at the same cluster, this reduces to adding a single
// element. Returns true iff the detection was added.
bool ClusterStore::tryMergeTwoClustersAndNewElement(
    ClusterIter clusterIt0,
    ClusterIter clusterIt1,
    const FrameDetectionId& fdId)
{
    if (clusterIt0 == clusterIt1) {
        // The original condition here was inverted (`!hasCommonFrame`): it
        // inserted detections that DO clash with the cluster's frames.
        // Delegate to the single-element path so the frame and forbidden
        // checks match the rest of ClusterStore.
        return tryAddNewElementInCluster(clusterIt0, fdId);
    }

    // clusterIt0 != clusterIt1
    if (!mayBeMerged(*clusterIt0, *clusterIt1)) {
        return false;
    }
    if (hasCommonFrame(fdId, clusterIt0->detectionIds) ||
        hasCommonFrame(fdId, clusterIt1->detectionIds))
    {
        return false;
    }
    // The merged cluster keeps both forbidden sets, so `fdId` must be allowed
    // by each of them.
    if (clusterIt0->forbiddenDetectionIds.count(fdId) ||
        clusterIt1->forbiddenDetectionIds.count(fdId))
    {
        return false;
    }
    clusterByDetectionId_[fdId.detectionId] = clusterIt0;
    clusterIt0->detectionIds.insert(fdId);
    mergeTwoClusters(clusterIt0, clusterIt1);
    return true;
}

// Converts every cluster into the set of its detection ids, dropping frame
// ids. Returns one set per cluster, in clusters_ iteration order.
std::vector<db::TIdSet> ClusterStore::makeDetectionClusters()
{
    std::vector<db::TIdSet> detectionClusters;
    detectionClusters.reserve(clusters_.size());

    for (const auto& cluster : clusters_) {
        db::TIdSet detectionCluster;
        for (const auto& [frameId, detectionId] : cluster.detectionIds) {
            detectionCluster.insert(detectionId);
        }
        // The local set is not used again; move it instead of copying.
        detectionClusters.push_back(std::move(detectionCluster));
    }

    return detectionClusters;
}

// Appends a singleton cluster to `clusters` for every id in `detectionIds`
// that does not already appear in any cluster. Existing clusters are left
// untouched. `detectionIds` is taken by value and used as scratch space.
void expandClustersByUnusedDetectionIds(
    db::TIdSet detectionIds,
    std::vector<db::TIdSet>& clusters)
{
    // Drop every id that already belongs to some cluster.
    for (const auto& cluster : clusters) {
        for (const auto& detectionId : cluster) {
            detectionIds.erase(detectionId);
        }
    }

    clusters.reserve(clusters.size() + detectionIds.size());
    for (const auto& detectionId : detectionIds) {
        // Construct the singleton in place of the former named temporary,
        // which was copied into the vector.
        clusters.push_back(db::TIdSet{detectionId});
    }
}

// Builds a ClusterStore in which every detection from `detectionIds` starts as
// a singleton cluster on its frame (taken from `detectionStore`). The match
// verdicts are then applied:
//  * Verdict::No  - each side is added to the forbidden set of the other
//                   side's cluster;
//  * Verdict::Yes - the two clusters are merged if mayBeMerged() allows it.
// All "No" verdicts are recorded before any "Yes" merge is attempted.
// Matches whose detections have no cluster (not in `detectionIds`) are
// skipped, as are matches with any other verdict.
ClusterStore makeClusterStoreConsistentWithMatchesVerdicts(
    const DetectionStore& detectionStore,
    const db::TIdSet& detectionIds,
    const MatchedFrameDetections& matches)
{
    ClusterStore clusters;

    for (db::TId detectionId: detectionIds) {
        clusters.createNewCluster(
            {{detectionStore.frameId(detectionId), detectionId}});
    }

    std::vector<std::pair<FrameDetectionId, FrameDetectionId>> forbiddenMatches;
    std::vector<std::pair<FrameDetectionId, FrameDetectionId>> approvedMatches;

    for (const auto& match: matches) {
        // Qualify both enumerators the same way (Verdict::...).
        if (match.verdict() == MatchedFrameDetection::Verdict::No) {
            forbiddenMatches.emplace_back(match.id0(), match.id1());
        } else if (match.verdict() == MatchedFrameDetection::Verdict::Yes) {
            approvedMatches.emplace_back(match.id0(), match.id1());
        }
    }

    for (const auto& [id0, id1]: forbiddenMatches) {
        auto clusterIt0 = clusters.find(id0.detectionId);
        if (clusterIt0 != clusters.end()) {
            clusterIt0->second->forbiddenDetectionIds.insert(id1);
        }

        auto clusterIt1 = clusters.find(id1.detectionId);
        if (clusterIt1 != clusters.end()) {
            clusterIt1->second->forbiddenDetectionIds.insert(id0);
        }
    }

    for (const auto& [id0, id1]: approvedMatches) {
        auto clusterIt0 = clusters.find(id0.detectionId);
        auto clusterIt1 = clusters.find(id1.detectionId);
        if (clusterIt0 != clusters.end() && clusterIt1 != clusters.end()) {
            // A refused merge (frame clash or forbidden pair) is not an
            // error; the approved match is simply not honoured.
            clusters.tryMergeTwoClusters(clusterIt0->second, clusterIt1->second);
        }
    }
    return clusters;
}

} // namespace maps::mrc::eye
