#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/position_accuracy_classifier.h>
#include <maps/libs/geolib/include/distance.h>
#include <kernel/catboost/catboost_calcer.h>

#include <algorithm>

// CatBoost model used by isPositionInaccurate() to score the position feature
// vector. Only declared here; the definition lives outside this file.
extern const NCatboostCalcer::TCatboostCalcer positionAccuracyClassifier;

namespace maps::mrc::classifiers {
namespace {

// Per-track series produced by extractMovementStats().
// NOTE: the member order matters - the struct is aggregate-initialized and
// unpacked with structured bindings elsewhere in this file.
struct MovementStats {
    // One entry per track point: accuracyMeters() of the point, or -1 when
    // the point carries no accuracy value.
    std::vector<float> accuracy;
    // One entry per pair of consecutive track points: fastGeoDistance between
    // their geodetic positions (presumably meters - confirm against geolib3).
    std::vector<float> distance;
    // One entry per pair of consecutive track points: the distance above
    // divided by the timestamp difference in whole seconds.
    std::vector<float> speed;
};

// Builds the accuracy / distance / speed series for a sequence of track points.
//
// - accuracy gets one entry per point (-1 when the point has no accuracy);
// - distance and speed get one entry per pair of consecutive points, so they
//   hold trackPoints.size() - 1 entries.
//
// Returns empty series for an empty input. The series are filled directly in
// the returned object so no vector is copied on return.
MovementStats extractMovementStats(const db::TrackPoints& trackPoints)
{
    MovementStats stats;
    if (trackPoints.empty()) {
        return stats;
    }

    const size_t pointCount = trackPoints.size();
    stats.accuracy.reserve(pointCount);
    stats.distance.reserve(pointCount - 1);
    stats.speed.reserve(pointCount - 1);

    for (size_t i = 0; i != pointCount; ++i) {
        const auto& current = trackPoints[i];
        stats.accuracy.push_back(current.accuracyMeters().value_or(-1));
        if (i == 0) {
            // The first point has no predecessor, hence no distance/speed.
            continue;
        }

        const auto& previous = trackPoints[i - 1];
        // Time difference is truncated to whole seconds.
        auto time = std::chrono::duration_cast<std::chrono::seconds>(
            current.timestamp() - previous.timestamp());

        auto currentDistance = geolib3::fastGeoDistance(
            current.geodeticPos(), previous.geodeticPos());
        // The small epsilon keeps the denominator non-zero when both points
        // fall into the same second.
        auto currentSpeed = currentDistance / (time.count() + 1e-6f);
        stats.distance.push_back(currentDistance);
        stats.speed.push_back(currentSpeed);
    }

    return stats;
}

// Median of `values`, ignoring entries equal to -1 (treated as "missing").
//
// Returns `valueIfEmpty` when nothing is left after dropping the missing
// entries, or when the number of missing entries exceeds half of the original
// count (integer division). For an even number of remaining entries the result
// is the mean of the two middle elements.
//
// `values` is taken by value on purpose: it is filtered and partially
// reordered in place.
float median(std::vector<float> values, float valueIfEmpty)
{
    const size_t totalCount = values.size();
    values.erase(std::remove(values.begin(), values.end(), -1.0f), values.end());
    const size_t emptyCount = totalCount - values.size();

    if (values.empty() or emptyCount > totalCount / 2) {
        return valueIfEmpty;
    }

    const auto middle = values.begin() + values.size() / 2;
    std::nth_element(values.begin(), middle, values.end());
    if (values.size() % 2 != 0) {
        return *middle;
    }

    // nth_element only guarantees that everything before `middle` is not
    // greater than it; that prefix is NOT sorted. The lower middle element is
    // therefore the maximum of the prefix, not simply the element at
    // index size / 2 - 1.
    const float lowerMiddle = *std::max_element(values.begin(), middle);
    return (lowerMiddle + *middle) / 2;
}

// Smallest element of `values`, or `valueIfEmpty` when the vector is empty.
// No entries are filtered out: a -1 entry takes part in the comparison.
float min(const std::vector<float>& values, float valueIfEmpty) {
    // For an empty range min_element returns the end iterator.
    const auto smallest = std::min_element(values.cbegin(), values.cend());
    return smallest == values.cend() ? valueIfEmpty : *smallest;
}

// Largest element of `values`, or `valueIfEmpty` when the vector is empty.
float max(const std::vector<float>& values, float valueIfEmpty) {
    // For an empty range max_element returns the end iterator.
    const auto largest = std::max_element(values.cbegin(), values.cend());
    return largest == values.cend() ? valueIfEmpty : *largest;
}

} // anonymous namespace

// Assembles the 8-element feature vector for the given track points:
//   [0] number of track points
//   [1] median accuracy   [2] max accuracy   [3] min accuracy
//   [4] median distance   [5] max distance
//   [6] median speed      [7] max speed
// Every statistic falls back to -1 when its series is empty (see the
// individual helpers for their exact rules).
std::array<float, 8> calculatePositionFeatures(const db::TrackPoints& neighborTrackPoints)
{
    constexpr float noValue = -1.0f;

    const auto [accuracies, distances, speeds] = extractMovementStats(neighborTrackPoints);

    std::array<float, 8> result{
        static_cast<float>(neighborTrackPoints.size()),
        median(accuracies, noValue),
        max(accuracies, noValue),
        min(accuracies, noValue),
        median(distances, noValue),
        max(distances, noValue),
        median(speeds, noValue),
        max(speeds, noValue),
    };
    return result;
}

// Decides whether the position described by `neighborTrackPoints` should be
// considered inaccurate.
//
// The input must not be empty (enforced with ASSERT). The feature vector from
// calculatePositionFeatures() is passed to positionAccuracyClassifier, and the
// position is reported as inaccurate when the classifier output is strictly
// greater than 0.5.
bool isPositionInaccurate(const db::TrackPoints& neighborTrackPoints)
{
    static constexpr double badPositionThreshold = 0.5;

    ASSERT(not neighborTrackPoints.empty());

    auto featureVector = calculatePositionFeatures(neighborTrackPoints);
    return positionAccuracyClassifier.DoCalcRelev(featureVector.data()) > badPositionThreshold;
}

} // maps::mrc::classifiers
