#include "azimuth.h"

#include "geom.h"
#include "match_function.h"

#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/log8/include/log8.h>

#include <algorithm>
#include <cmath>
#include <optional>
#include <utility>

using maps::geolib3::Radians;

namespace maps::mrc::pos_improvment {

// Calculates azimuth for each gps point by integrating the gyroscope rotation
// speed around the current gravity axis, starting from a zero azimuth. Since
// the start value is zero, the result describes how the heading changed
// relative to the first event, not an absolute heading.
// @param events merged gps, gravity and gyro events. The integration step is
//        the time difference between consecutive entries, so they are
//        expected to be ordered by time.
// @param gpsEvents only used for the emptiness check and to size the result.
// @param gravityEvents first sample seeds the gravity axis used until the
//        first gravity event is met in `events`.
// @param gyroEvents first sample seeds the rotation speed used until the
//        first gyroscope event is met in `events`.
// @return one azimuth per gps event found in `events`, in iteration order.
// Fails the REQUIRE check ("Empty input") if any input is empty.
std::vector<Radians> findGyroAzimuthForEachGps(const TrackEventPtrs& events,
                                               const GpsEvents& gpsEvents,
                                               const GravityEvents& gravityEvents,
                                               const GyroscopeEvents& gyroEvents) {
    // All four inputs are checked the same way: the first event time and the
    // first gravity/gyro samples are read unconditionally below.
    REQUIRE(events.size() > 0 && gpsEvents.size() > 0 &&
            gravityEvents.size() > 0 && gyroEvents.size() > 0,
            "Empty input");
    std::vector<Radians> gyroAzimuthForEachGps;
    gyroAzimuthForEachGps.reserve(gpsEvents.size());

    Time prevTime = events[0]->time;
    UnitVector3 curGravity = gravityEvents[0].values();
    RotationSpeedVector curGyroValue = gyroEvents[0].values();
    Radians curAzimuth(0);

    for (const auto& trackEvent : events) {
        // Sensor state is updated before integrating, so the interval that
        // ends at this event already uses the freshly received sample.
        if (trackEvent->type == EventType::Gyroscope) {
            curGyroValue = static_cast<GyroscopeEvent const *>(trackEvent)->values();
        } else if (trackEvent->type == EventType::Gravity) {
            curGravity = static_cast<GravityEvent const *>(trackEvent)->values();
        }
        // Rotation around the gravity axis changes the heading. The sign is
        // negated (presumably gyro rotation sense is opposite to azimuth
        // sense — confirm against rotSpeedAroundAxis).
        RadiansPerSec azimuthRotSpeed = -curGyroValue.rotSpeedAroundAxis(curGravity);
        Radians azimuthRotation = azimuthRotSpeed * (trackEvent->time - prevTime);
        curAzimuth += azimuthRotation;
        prevTime = trackEvent->time;

        if (trackEvent->type == EventType::Gps) {
            gyroAzimuthForEachGps.push_back(curAzimuth);
        }
    }
    return gyroAzimuthForEachGps;
}

namespace {

// Calculates precise azimuth for each gps point by combining two estimates:
// - gps azimuth has precise average values over any long time range;
// - gyro azimuth has a precise derivative over any short time range.
// The combination itself is done by matchFunction (match_function.h), which
// receives the per-pair weights produced by getWeight below.
// @param gpsEvents gps points; each is dereferenced for `direction`, so each
//        is expected to carry one (the caller filters them beforehand).
// @param gyroAzimuthForEachGps gyro-integrated azimuth, one per gps point.
// @return whatever matchFunction returns for these inputs — presumably one
//         fused azimuth per gps point.
std::vector<Radians> findActualAzimuzhForEachGps(
    const GpsEvents& gpsEvents,
    std::vector<Radians> gyroAzimuthForEachGps)
{
    REQUIRE(gpsEvents.size() == gyroAzimuthForEachGps.size(),
            "gyroAzimuthForEachGps should contain the same number of points");
    // Passed to matchFunction; presumably the neighbourhood size in points.
    constexpr int FRAME_WIDTH = 70;

    std::vector<Radians> gpsAzimuthForEachGps(gpsEvents.size());
    for (size_t i = 0; i < gpsEvents.size(); ++i) {
        gpsAzimuthForEachGps[i] = *gpsEvents[i].direction;
    }

    auto getWeight = [&](int64_t baseIndex, int64_t curIndex) {
        // The bigger the index distance between two gps points, the bigger
        // the accumulated gyroscope error between them: the weight decays
        // linearly and reaches zero at 100 points apart.
        double distanceWeight = 1 - 0.01 * std::abs(curIndex - baseIndex);
        distanceWeight = std::max(0.0, distanceWeight);

        // Azimuth is more precise when speed is big. A point whose speed is
        // absent is weighted like a stationary one instead of dereferencing
        // an empty value.
        const auto& speed = gpsEvents[curIndex].speed;
        const double speedValue = speed ? speed->value() : 0.0;
        // NOTE(review): min(0.1, speed) saturates at 0.1, so this factor is
        // 0.2 for practically every moving point — confirm the cap is intended.
        double speedWeight = 0.1 + std::min(0.1, speedValue);
        return speedWeight * distanceWeight;
    };

    return matchFunction(
        gpsAzimuthForEachGps, gyroAzimuthForEachGps, FRAME_WIDTH, getWeight);
}

// Keeps gps points that have a valid direction and are either more than 1 m
// away from the previously kept point or moving faster than 3 m/s. In other
// words: drops points without a valid direction, and points that are both
// slow and close to the previously kept point.
// @param gpsEvents gps points in track order. Taken by value so callers can
//        move their events in; kept events are moved, not copied.
// @return the kept points in their original order.
GpsEvents filterOutEventsWithoutAzimuth(GpsEvents gpsEvents) {
    constexpr MetersPerSec ENOUGH_SPEED_FOR_GOOD_AZIMUTH(3.0);
    GpsEvents filteredEvents;
    filteredEvents.reserve(gpsEvents.size());
    // Position of the last kept point. It stays empty until the first point
    // is kept, so the first valid point is never compared against an
    // arbitrary sentinel position.
    std::optional<geolib3::Point2> prevEventPos;

    for (auto& gpsEvent : gpsEvents) {
        if (!gpsEvent.hasValidDirection()) {
            continue;
        }
        const bool farFromPrevious = !prevEventPos
            || distanceMeters(gpsEvent.mercatorPos, *prevEventPos) > 1.0_m;
        // Speed is optional-like; a missing value never counts as fast.
        const bool fastEnough = gpsEvent.speed
            && *gpsEvent.speed > ENOUGH_SPEED_FOR_GOOD_AZIMUTH;
        if (farFromPrevious || fastEnough) {
            prevEventPos = gpsEvent.mercatorPos;  // before the event is moved
            filteredEvents.push_back(std::move(gpsEvent));
        }
    }
    return filteredEvents;
}

} // anonymous namespace

// Builds car ground direction events, one per gyroscope event that comes
// after the first kept gps point. The azimuth is snapped to the fused
// gps+gyro estimate at every kept gps point and propagated between gps points
// by integrating gyroscope rotation around the current gravity axis.
// @param gpsEvents gps points; those without a usable azimuth are dropped.
// @param gravityEvents gravity samples; the first one seeds the gravity axis.
// @param gyroEvents gyroscope samples; the first one seeds the rotation speed.
// @return the produced CarGroundDirection events in iteration order.
CarGroundDirectionEvents calculateAzimuthEvents(
    GpsEvents gpsEvents,
    const GravityEvents& gravityEvents,
    const GyroscopeEvents& gyroEvents)
{
    gpsEvents = filterOutEventsWithoutAzimuth(std::move(gpsEvents));

    // Single stream of pointers into all three event lists.
    TrackEventPtrs events;
    events.reserve(gpsEvents.size() + gravityEvents.size() + gyroEvents.size());
    insert(events, gpsEvents);
    insert(events, gravityEvents);
    insert(events, gyroEvents);

    const std::vector<Radians> fusedAzimuths = findActualAzimuzhForEachGps(
        gpsEvents,
        findGyroAzimuthForEachGps(events, gpsEvents, gravityEvents, gyroEvents));

    CarGroundDirectionEvents result;
    result.reserve(gyroEvents.size());

    UnitVector3 gravity = gravityEvents[0].values();
    RotationSpeedVector rotationSpeed = gyroEvents[0].values();
    Time lastTime = events[0]->time;
    Radians azimuth(0); // overwritten at the first gps point
    std::size_t gpsSeen = 0;

    for (const auto& trackEvent : events) {
        const auto type = trackEvent->type;

        if (type == EventType::Gyroscope) {
            rotationSpeed = static_cast<GyroscopeEvent const *>(trackEvent)->values();
        } else if (type == EventType::Gravity) {
            gravity = static_cast<GravityEvent const *>(trackEvent)->values();
        }

        const auto now = trackEvent->time;
        const RadiansPerSec yawRate = -rotationSpeed.rotSpeedAroundAxis(gravity);
        const Radians yawDelta = yawRate * (now - lastTime);
        azimuth += yawDelta;
        lastTime = now;

        if (type == EventType::Gps) {
            // Replace the integrated value with the fused estimate.
            azimuth = fusedAzimuths[gpsSeen];
            ++gpsSeen;
        } else if (gpsSeen > 0 && type == EventType::Gyroscope) {
            result.emplace_back(
                EventType::CarGroundDirection, trackEvent->time, azimuth);
        }
    }

    return result;
}

} // namespace maps::mrc::pos_improvment
