#include "cluster_store.h"
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/position_clusterizer.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/utils.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/location.h>

#include <maps/libs/geolib/include/distance.h>

#include <boost/geometry.hpp>
#include <boost/geometry/geometry.hpp>
#include <boost/geometry/index/rtree.hpp>

#include <set>

namespace maps::mrc::eye {

namespace {

/// Estimates where the object observed by the given detections is located.
///
/// The estimator is chosen by the detection type of the first group in the
/// store slice; type consistency across the detections is NOT checked.
///
/// @param store             detection store resolving ids into groups,
///                          detections, frames, locations and devices
/// @param frameDetectionIds (frame id, detection id) pairs; must not be empty
/// @return the estimated object location
/// @throws RuntimeError if the detection type has no location estimator
Location estimateObjectLocation(
    const DetectionStore& store,
    const FrameDetectionIdSet& frameDetectionIds)
{
    // groups.front() below needs at least one group; slicing an empty id set
    // presumably yields none, so fail with a clear assertion instead of
    // reading from an empty container.
    ASSERT(not frameDetectionIds.empty());

    db::TIdSet detectionIds;
    for (const auto& [frameId, detectionId] : frameDetectionIds) {
        detectionIds.insert(detectionId);
    }

    const auto [groups, detections, frames, locations, devices] = store.slice(detectionIds);

    // Don't check type consistency!
    const auto type = groups.front().type();

    switch (type) {
        case db::eye::DetectionType::HouseNumber:
            return findHouseNumberLocation(frames, locations, detections);

        case db::eye::DetectionType::Sign:
            return findSignLocation(devices, frames, locations, detections);

        case db::eye::DetectionType::TrafficLight:
            return findTrafficLightLocation(devices, frames, locations, detections);

        case db::eye::DetectionType::RoadMarking:
            return findRoadMarkingLocation(locations);

        default:
            throw RuntimeError() << "Bad detection type " << type;
    }
}

/// Weighted combination `weightX * x + weightY * y`.
///
/// No normalisation is applied: the result is a true average only when the
/// caller passes weights that sum to 1 (as mergeObjectPosition does).
///
/// @param weightX weight of `x`
/// @param x       first value
/// @param weightY weight of `y`
/// @param y       second value
/// @return the weighted sum of `x` and `y`
double avg(double weightX, double x, double weightY, double y)
{
    return weightX * x + weightY * y;
}

/// Position of the object obtained by merging two cliques.
///
/// House numbers are re-estimated from the union of both detection sets.
/// Signs, traffic lights and road markings get the average of the two
/// positions, weighted by the number of detections on each side. The type of
/// the first lhs detection decides; consistency between sides is NOT checked.
///
/// @param store                detection store
/// @param lhsFrameDetectionIds detections of the first clique (non-empty)
/// @param lhs                  mercator position of the first clique
/// @param rhsFrameDetectionIds detections of the second clique (non-empty)
/// @param rhs                  mercator position of the second clique
/// @return mercator position of the merged clique
/// @throws RuntimeError for any other detection type
geolib3::Point2 mergeObjectPosition(
        const DetectionStore& store,
        const FrameDetectionIdSet& lhsFrameDetectionIds, const geolib3::Point2& lhs,
        const FrameDetectionIdSet& rhsFrameDetectionIds, const geolib3::Point2& rhs)
{
    ASSERT(not lhsFrameDetectionIds.empty() and not rhsFrameDetectionIds.empty());

    // Don't check type consistency!
    const auto firstLhsDetectionId = lhsFrameDetectionIds.begin()->detectionId;
    const auto type = store.getDetectionType(firstLhsDetectionId);

    switch (type) {
        case db::eye::DetectionType::HouseNumber: {
            FrameDetectionIdSet merged = lhsFrameDetectionIds;
            merged.insert(rhsFrameDetectionIds.begin(), rhsFrameDetectionIds.end());
            return estimateObjectLocation(store, merged).mercatorPosition;
        }

        case db::eye::DetectionType::Sign:
        case db::eye::DetectionType::TrafficLight:
        case db::eye::DetectionType::RoadMarking: {
            // Each side contributes proportionally to its detection count.
            const double total = lhsFrameDetectionIds.size() + rhsFrameDetectionIds.size();
            const double lhsShare = lhsFrameDetectionIds.size() / total;
            const double rhsShare = rhsFrameDetectionIds.size() / total;

            const double mergedX = avg(lhsShare, lhs.x(), rhsShare, rhs.x());
            const double mergedY = avg(lhsShare, lhs.y(), rhsShare, rhs.y());
            return {mergedX, mergedY};
        }

        default:
            throw RuntimeError() << "Bad detection type " << type;
    }
}

/// A cluster as tracked by mergePseudoCliques(): the detection type of the
/// cluster, a handle to the cluster in the ClusterStore, and the cluster's
/// estimated object position.
struct PseudoClique {
    db::eye::DetectionType type;            // type of the cluster's first detection
    ClusterStore::ConstIterator clusterIt;  // the cluster inside the ClusterStore
    geolib3::Point2 mercatorPos;            // estimated position, mercator coordinates
};

// std::list keeps iterators to the remaining elements valid when others are
// erased; the spatial index stores list iterators and relies on that.
using PseudoCliqueList = std::list<PseudoClique>;

// Cartesian 2D point/box types used as R-tree keys (mercator coordinates).
using IndexPoint = boost::geometry::model::point<double, 2, boost::geometry::cs::cartesian>;
using IndexBox = boost::geometry::model::box<IndexPoint>;

/// Indexable getter for the R-tree: maps a stored list iterator to the
/// clique's mercator position.
struct PseudoCliqueIndexer {
    using result_type = IndexPoint;

    result_type operator()(PseudoCliqueList::iterator it) const
    {
        const geolib3::Point2& pos = it->mercatorPos;
        return IndexPoint(pos.x(), pos.y());
    }
};

/// Axis-aligned square query box centred at `point`.
///
/// @param point        centre, in mercator coordinates
/// @param radiusMeters half side of the square in METERS; it is converted to
///                     mercator units at `point` here, so callers must not
///                     pre-convert it
/// @return the box [point - r, point + r] on both axes, r in mercator units
IndexBox makeIndexBox(const geolib3::Point2& point, double radiusMeters)
{
    const double halfSide = geolib3::toMercatorUnits(radiusMeters, point);

    const IndexPoint minCorner(point.x() - halfSide, point.y() - halfSide);
    const IndexPoint maxCorner(point.x() + halfSide, point.y() + halfSide);

    return IndexBox(minCorner, maxCorner);
}

/// R-tree over pseudo cliques, keyed by their mercator position through
/// PseudoCliqueIndexer. It stores list iterators, so an entry must be removed
/// before its clique's mercatorPos changes or the clique is erased from the
/// list: the tree looks entries up by their indexable.
using SpatialIndex = boost::geometry::index::rtree<
    PseudoCliqueList::iterator,
    boost::geometry::index::quadratic<16>,
    PseudoCliqueIndexer
>;

/// Maximum distance, in meters, at which two cliques of `type` may be merged.
///
/// @throws RuntimeError for detection types without a configured distance
double cliqueMergeDistanceMeters(db::eye::DetectionType type)
{
    switch (type) {
        case db::eye::DetectionType::Sign:
            return 20.0;

        case db::eye::DetectionType::TrafficLight:
        case db::eye::DetectionType::RoadMarking:
            return 30.0;

        case db::eye::DetectionType::HouseNumber:
            return 50.0;

        default:
            throw RuntimeError() << "Invalid detection type " << type;
    }
}

/// Whether `matchWeights` has an entry lhs -> rhs whose verdict is not
/// `No` and whose relevance is strictly positive. Missing entries count as
/// "no positive match".
bool havePositiveMatchWeight(
    const db::IdTo<db::IdTo<RelevanceWithVerdict>>& matchWeights,
    db::TId lhsDetectionId, db::TId rhsDetectionId)
{
    const auto rowIt = matchWeights.find(lhsDetectionId);
    if (rowIt == matchWeights.end()) {
        return false;
    }

    const auto& row = rowIt->second;
    const auto cellIt = row.find(rhsDetectionId);
    if (cellIt == row.end()) {
        return false;
    }

    const auto& weight = cellIt->second;
    return weight.verdict != MatchedFrameDetection::Verdict::No
        && weight.relevance > 0;
}

/// Whether two detection sets may belong to the same clique: they share no
/// frame, and at least one cross pair has a positive match weight in either
/// direction (see havePositiveMatchWeight()).
bool areConnected(
    const db::IdTo<db::IdTo<RelevanceWithVerdict>>& matchWeights,
    const FrameDetectionIdSet& lhsFrameDetectionIds,
    const FrameDetectionIdSet& rhsFrameDetectionIds)
{
    // Sets sharing a frame are never considered connected.
    if (hasCommonFrame(lhsFrameDetectionIds, rhsFrameDetectionIds)) {
        return false;
    }

    const auto positiveEitherWay = [&](db::TId a, db::TId b) {
        return havePositiveMatchWeight(matchWeights, a, b)
            || havePositiveMatchWeight(matchWeights, b, a);
    };

    for (const auto& [lhsFrame, lhsDetection] : lhsFrameDetectionIds) {
        for (const auto& [rhsFrame, rhsDetection] : rhsFrameDetectionIds) {
            if (positiveEitherWay(lhsDetection, rhsDetection)) {
                return true;
            }
        }
    }
    return false;
}

} // namespace

void mergePseudoCliques(
    const DetectionStore& store,
    const MatchedFrameDetections& matches,
    ClusterStore& clusters)
{

    PseudoCliqueList cliques;
    SpatialIndex index;

    for (auto clusterIt = clusters.begin(); clusterIt != clusters.end(); clusterIt++) {
        const auto type = store.getDetectionType(clusterIt->second->detectionIds.begin()->detectionId);
        auto [position, _] = estimateObjectLocation(store, clusterIt->second->detectionIds);

        cliques.push_front({type, clusterIt, position});
        index.insert(cliques.begin());
    }

    db::IdTo<db::IdTo<RelevanceWithVerdict>> matchWeights;
    for (size_t i = 0; i < matches.size(); i++) {
        const auto& [frameId0, detectionId0] = matches[i].id0();
        const auto& [frameId1, detectionId1] = matches[i].id1();
        auto relevanceWithVerdict = RelevanceWithVerdict{
            .relevance = matches[i].relevance(),
            .verdict = matches[i].verdict()
        };

        matchWeights[detectionId0][detectionId1] = relevanceWithVerdict;
        matchWeights[detectionId1][detectionId0] = std::move(relevanceWithVerdict);
    }

    auto getNearest = [&](PseudoCliqueList::iterator it) {
        const auto type = it->type;
        const geolib3::Point2& position = it->mercatorPos;

        const double epsilon = geolib3::toMercatorUnits(cliqueMergeDistanceMeters(type), position);
        auto query = boost::geometry::index::intersects(makeIndexBox(position, epsilon));

        auto result = cliques.end();
        double minDistance = epsilon;

        for (auto queryIt = index.qbegin(query); queryIt != index.qend(); ++queryIt) {
            auto other = *queryIt;
            if (other == it or other->type != type) {
                continue;
            }

            const double distance = geolib3::distance(position, other->mercatorPos);
            if (distance >= minDistance) {
                continue;
            }

            if (areConnected(matchWeights, it->clusterIt->second->detectionIds,
                    other->clusterIt->second->detectionIds))
            {
                minDistance = distance;
                result = other;
            }
        }

        return result;
    };

    for (bool mergedAnyClique = true; mergedAnyClique; ) {
        mergedAnyClique = false;

        for (auto it = cliques.begin(); it != cliques.end(); ) {
            auto other = getNearest(it);

            if (other == cliques.end()) {
                ++it;
                continue;
            }

            index.remove(it);
            index.remove(other);

            // const auto [newPosition, _] = estimateObjectLocation(store, it->detectionIds);
            // it->mercatorPos = newPosition;

            // One more temporary hack
            // Speed up reposition of two merged cliques
            it->mercatorPos = mergeObjectPosition(
                store,
                it->clusterIt->second->detectionIds,
                it->mercatorPos,
                other->clusterIt->second->detectionIds,
                other->mercatorPos
            );

            if (clusters.tryMergeTwoClusters(it->clusterIt->second, other->clusterIt->second)) {
                cliques.erase(other);
                index.insert(it);
                mergedAnyClique = true;
            }
        }
    }
}

/// Groups `detectionIds` into clusters.
///
/// Initial clusters come from makeClusterStoreConsistentWithMatchesVerdicts().
/// Nearby clusters are then merged by mergePseudoCliques().
///
/// @return one set of detection ids per cluster; empty for empty input
std::vector<db::TIdSet> PositionDetectionClusterizer::clusterize(
    const DetectionStore& store,
    const db::TIdSet& detectionIds,
    const MatchedFrameDetections& matches) const
{
    if (detectionIds.empty()) {
        return {};  // nothing to cluster
    }

    ClusterStore clusterStore =
        makeClusterStoreConsistentWithMatchesVerdicts(store, detectionIds, matches);

    mergePseudoCliques(store, matches, clusterStore);

    return clusterStore.makeDetectionClusters();
}

} // namespace maps::mrc::eye
