#include "clustering.h"

#include "turn.h"

#include <maps/libs/geolib/include/point.h>
#include <maps/libs/geolib/include/vector.h>
#include <maps/libs/geolib/include/static_geometry_searcher.h>
#include <maps/libs/geolib/include/distance.h>

#include <maps/libs/log8/include/log8.h>

#include <cmath>
#include <functional>
#include <set>
#include <vector>

namespace maps {
namespace wiki {
namespace signals_graph {
namespace {

// Largest angle (radians) allowed both between the entry directions and
// between the exit directions of two turns grouped by clusterTurns.
constexpr double MAX_ANGLE_SIM_RAD{M_PI / 4.0};

// Largest geolib3::distance between mercPoints of turns grouped by clusterTurns.
constexpr double TURN_CLUSTER_DISTANCE{25.0};
// Largest geolib3::distance between mercPoints of turns grouped by clusterClusters.
constexpr double CROSS_CLUSTER_DISTANCE{40.0};

/**
 * Collects the indices of turns that are close to `elem`.
 *
 * A candidate returned by `searcher` for `bbox` is accepted when the
 * geolib3::distance between its point and `elem.mercPoint` does not exceed
 * `clusterDistance` and, if `cmp` is set, `cmp(elem, candidate)` holds.
 * Only candidates inside `bbox` are examined, so `bbox` must cover the
 * `clusterDistance` neighbourhood of `elem`. `elem` itself is not excluded:
 * if its own point was inserted into `searcher`, it is returned as well
 * whenever `cmp(elem, elem)` holds.
 *
 * @param turns           turns that the searcher's values index into.
 * @param searcher        built searcher mapping points to indices in `turns`.
 * @param bbox            area to query.
 * @param clusterDistance maximal accepted distance from `elem.mercPoint`.
 * @param elem            the turn whose neighbours are looked up.
 * @param cmp             optional extra predicate; empty means "accept all".
 *                        Taken by const reference so the (possibly
 *                        heap-allocated) callable is not copied on every call.
 * @return indices into `turns`, in searcher iteration order.
 */
std::vector<size_t> findClose(
    const std::vector<Turn>& turns,
    const geolib3::StaticGeometrySearcher<geolib3::Point2, size_t>& searcher,
    const geolib3::BoundingBox& bbox,
    double clusterDistance,
    const Turn& elem,
    const std::function<bool(const Turn&, const Turn&)>& cmp = {}
) {
    const auto found = searcher.find(bbox);

    std::vector<size_t> close;
    for (auto it = found.first; it != found.second; ++it) {
        const size_t index = it->value();

        // The bbox query is only a coarse filter; enforce the exact radius here.
        if (geolib3::distance(elem.mercPoint, it->geometry()) > clusterDistance) {
            continue;
        }

        if (!cmp || cmp(elem, turns[index])) {
            close.push_back(index);
        }
    }

    return close;
}

void clusterClose(
    std::vector<Turn>& turns,
    double clusterDistance,
    std::vector<std::set<size_t>>& clusterIndex,
    std::function<bool(const Turn&, const Turn&)> cmp = {}
) {
    geolib3::StaticGeometrySearcher<geolib3::Point2, size_t> searcher;
    int nextCluster = 0;

    for (size_t i = 0; i < turns.size(); ++i) {
        searcher.insert(&turns[i].mercPoint, i);
    }

    searcher.build();

    geolib3::Vector2 diag(2 * clusterDistance, 2 * clusterDistance);

    for (size_t elemIndex = 0; elemIndex < turns.size(); ++elemIndex) {
        Turn& elem = turns[elemIndex];
        if (elem.clusterId != -1) {
            continue;
        }

        geolib3::BoundingBox bbox(elem.mercPoint - diag, elem.mercPoint + diag);

        auto close = findClose(turns, searcher, bbox, clusterDistance, elem, cmp);

        for (const auto index : close) {
            if (turns[index].clusterId != -1) {
                elem.clusterId = turns[index].clusterId;
                clusterIndex[elem.clusterId].insert(elemIndex);

                break;
            }
        }

        if (elem.clusterId == -1 && close.size() > 0) {
            clusterIndex.emplace_back();
            clusterIndex.back().insert(elemIndex);

            elem.clusterId = nextCluster;
            for (const auto index : close) {
                turns[index].clusterId = nextCluster;
                clusterIndex.back().insert(index);
            }

            ++nextCluster;
        }
    }
}

/**
 * Builds one Turn as the weight-averaged representative of a cluster.
 *
 * The position, average speed and the entry/exit direction vectors of the
 * member turns (`turns[i]` for every `i` in `cluster`) are averaged, each
 * member weighted by its `weight`. The result carries the summed weight,
 * the angle between the averaged entry and exit directions, an empty string
 * as its first field and -1 in the field after the weight (presumably the
 * id and clusterId — confirm against Turn's definition).
 *
 * NOTE(review): if the member weights sum to zero, every division below is
 * by zero. If the weighted direction vectors cancel out, Direction2 receives
 * a zero vector. Confirm that callers rule both cases out.
 */
Turn averageTurn(
    const std::vector<Turn>& turns,
    const std::set<size_t>& cluster
) {
    size_t totalWeight = 0;
    double weightedX = 0.0;
    double weightedY = 0.0;
    double weightedSpeed = 0.0;
    geolib3::Vector2 weightedIn{0.0, 0.0};
    geolib3::Vector2 weightedOut{0.0, 0.0};

    for (auto memberIt = cluster.begin(); memberIt != cluster.end(); ++memberIt) {
        const Turn& member = turns[*memberIt];
        const auto w = member.weight;

        weightedX += member.mercPoint.x() * w;
        weightedY += member.mercPoint.y() * w;
        weightedSpeed += member.averageSpeed * w;
        weightedIn += member.directionIn.vector() * w;
        weightedOut += member.directionOut.vector() * w;
        totalWeight += w;
    }

    const geolib3::Direction2 meanIn(weightedIn / totalWeight);
    const geolib3::Direction2 meanOut(weightedOut / totalWeight);
    const geolib3::Point2 center{weightedX / totalWeight, weightedY / totalWeight};

    return Turn{
        "",
        weightedSpeed / totalWeight, geolib3::angle(meanIn, meanOut),
        meanIn, meanOut, totalWeight, -1,
        center
    };
}

} // namespace

std::vector<Turn> clusterTurns(std::vector<Turn>& turns) {
    std::vector<std::set<size_t>> clusterIndex;
    clusterClose(
        turns, TURN_CLUSTER_DISTANCE, clusterIndex,
        [](const Turn& cur, const Turn& other) {
            double inDiff = angle(cur.directionIn, other.directionIn);
            double outDiff = angle(cur.directionOut, other.directionOut);

            return inDiff <= MAX_ANGLE_SIM_RAD && outDiff <= MAX_ANGLE_SIM_RAD;
        }
    );

    std::vector<Turn> result;

    for (size_t index = 0; index < clusterIndex.size(); ++index) {
        result.push_back(averageTurn(turns, clusterIndex[index]));
    }

    return result;
}

std::vector<Turn> clusterClusters(std::vector<Turn>& turns) {
    geolib3::StaticGeometrySearcher<geolib3::Point2, size_t> searcher;
    std::vector<std::set<size_t>> clusterIndex;

    clusterClose(turns, CROSS_CLUSTER_DISTANCE, clusterIndex);

    std::vector<Turn> result;

    for (size_t index = 0; index < clusterIndex.size(); ++index) {
        result.push_back(averageTurn(turns, clusterIndex[index]));
    }

    return result;
}

} // namespace signals_graph
} // namespace wiki
} // namespace maps
