#include "points_clustering.h"

#include "utils.h"

#include <maps/libs/geolib/include/distance.h>

#include <algorithm>
#include <limits>
#include <optional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace maps::mrc::sensors_feature_positioner {

namespace {

// Clusters are keyed by the PointId of the point that founded them
// (a selected center, or a point that could not join any center).
using ClusterId = size_t;
// A point id paired with its neighboring density
// (as computed by calculateNeighboringDensity).
using PointDensity = std::pair<PointId, double>;

// Returns the distance in meters between two points.
// geolib3::distance measures the length in Mercator space. toMeters converts
// that length using point1's position as the reference (presumably for the
// latitude-dependent Mercator scale). Because only point1 is used as the
// reference, the result is not guaranteed to be symmetric in its arguments.
double metersBetween(const Point& point1,
                     const Point& point2)
{
    const auto& anchor = point1.mercatorPos;
    const double mercatorLength = geolib3::distance(anchor, point2.mercatorPos);
    return geolib3::toMeters(mercatorLength, anchor);
}

double calculateNeighboringDensity(const Point& curPoint,
                                   const Points& allPoints,
                                   double maxClusterRadiusMeters)
{
    double density = 0;
    for (PointId neighborId : curPoint.neighbors) {
        // if distance_between_points == 0, density += 2
        // if distance_between_points == maxClusterRadiusMeters, density += 1
        density += 2 - (metersBetween(curPoint, allPoints.at(neighborId))
                        / maxClusterRadiusMeters);
    }
    return density;
}

// For each point calculates points density near the point
std::vector<PointDensity> calculateNeighboringDensity(
    const Points& points,
    double maxClusterRadiusMeters)
{
    std::vector<PointDensity> neighboringDensity;
    neighboringDensity.reserve(points.size());
    for (auto& [id, point] : points) {
        neighboringDensity.push_back(
            {id, calculateNeighboringDensity(point, points, maxClusterRadiusMeters)});
    }
    return neighboringDensity;
}

// Reorders every point's `neighbors` so that the closest neighbor comes first.
// Neighbors at equal distance are ordered by ascending PointId.
//
// Each distance is computed once per (point, neighbor) pair and then sorted,
// instead of being recomputed inside the comparator on every comparison.
// Neighbors are looked up with at(). operator[] would silently insert a
// default Point for an unknown id while `points` is being iterated.
// NOTE(review): an unknown neighbor id now makes at() throw (std::out_of_range
// for standard associative containers) — confirm Points' at() semantics.
void sortNeighborsByDistance(Points& points) {
    // Reused across points so its capacity is allocated only once.
    std::vector<std::pair<double, PointId>> byDistance;
    for (auto& entry : points) {
        auto& point = entry.second;
        byDistance.clear();
        byDistance.reserve(point.neighbors.size());
        for (PointId neighborId : point.neighbors) {
            byDistance.emplace_back(
                metersBetween(point, points.at(neighborId)), neighborId);
        }
        // Lexicographic pair order: by distance, then by id.
        std::sort(byDistance.begin(), byDistance.end());
        auto out = point.neighbors.begin();
        for (const auto& item : byDistance) {
            *out++ = item.second;
        }
    }
}

std::unordered_set<PointId> findClustersCenters(
    const Points& points,
    std::vector<PointDensity> neighboringDensity)
{
    std::sort(neighboringDensity.begin(), neighboringDensity.end(),
              [](const PointDensity& lhs, const PointDensity& rhs) {
                  return lhs.second > rhs.second;
              });

    std::unordered_set<PointId> usedPoints;
    std::unordered_set<PointId> clustersCenters;

    for (auto& [pointId, density] : neighboringDensity) {
        if (usedPoints.count(pointId)) {
            continue;
        }
        clustersCenters.insert(pointId);
        usedPoints.insert(pointId);
        const auto& neighbors = points.at(pointId).neighbors;
        usedPoints.insert(neighbors.begin(), neighbors.end());
    }
    return clustersCenters;
}

} // anonymous namespace

// Splits `points` into clusters around density-selected centers.
//
//  1. Neighbors are sorted by distance, each point's neighboring density is
//     computed, and centers are picked greedily (see findClustersCenters).
//  2. Every non-center point joins the nearest center among its neighbors,
//     unless that cluster prohibits it. A cluster prohibits every id listed in
//     `prohibitedNeighbors` of any of its members, including its center.
//  3. A point that cannot join any neighboring center forms its own cluster.
//     That id is not added to the set of centers, so later points cannot join
//     it.
//
// NOTE(review): a joining point's own prohibitedNeighbors are not checked
// against the cluster's existing members. This presumably relies on
// prohibitions being symmetric — confirm.
std::vector<Cluster> getClusters(Points points, double maxClusterRadiusMeters)
{
    sortNeighborsByDistance(points);

    std::vector<PointDensity> neighboringDensity
        = calculateNeighboringDensity(points, maxClusterRadiusMeters);

    std::unordered_set<PointId> clustersCenters
        = findClustersCenters(points, std::move(neighboringDensity));

    std::unordered_map<ClusterId, std::vector<PointId>> clusterPoints;
    std::unordered_map<ClusterId, std::set<PointId>> prohibitedPoints;

    // A center is a member of its own cluster, so its prohibitions must be in
    // place before any other point tries to join. Previously centers skipped
    // this registration, which let a point prohibited by the center itself
    // join the center's cluster.
    for (PointId centerId : clustersCenters) {
        const auto& prohibited = points.at(centerId).prohibitedNeighbors;
        prohibitedPoints[centerId].insert(prohibited.begin(), prohibited.end());
    }

    for (auto& [id, point] : points) {
        if (clustersCenters.count(id)) {
            clusterPoints[id].push_back(id);
            continue;
        }
        std::optional<PointId> nearestClusterCenter;
        double distanceToNearestCenter = std::numeric_limits<double>::max();
        for (PointId neighborId : point.neighbors) {
            ClusterId clusterId = neighborId;
            if (!clustersCenters.count(clusterId)) {
                continue;
            }
            // find() instead of operator[] so the lookup never inserts.
            const auto prohibitedIt = prohibitedPoints.find(clusterId);
            if (prohibitedIt != prohibitedPoints.end()
                && prohibitedIt->second.count(id)) {
                continue;
            }
            // at(): never insert into `points` while iterating over it.
            double curDistance = metersBetween(point, points.at(clusterId));
            if (curDistance < distanceToNearestCenter) {
                nearestClusterCenter = clusterId;
                distanceToNearestCenter = curDistance;
            }
        }
        if (!nearestClusterCenter) {
            // No neighboring center accepts this point: either none of its
            // neighbors is a center, or every such cluster prohibits it.
            // Create an additional cluster for it.
            nearestClusterCenter = id;
        }
        clusterPoints[*nearestClusterCenter].push_back(id);
        prohibitedPoints[*nearestClusterCenter].insert(
            point.prohibitedNeighbors.begin(),
            point.prohibitedNeighbors.end());
    }

    std::vector<Cluster> clusters;
    clusters.reserve(clusterPoints.size());
    for (auto& [clusterId, pointIds] : clusterPoints) {
        // clusterPoints is discarded afterwards, so move instead of copying.
        clusters.push_back(std::move(pointIds));
    }
    return clusters;
}

} // namespace maps::mrc::sensors_feature_positioner
