#include "speed.h"

#include "match_function.h"

#include <maps/libs/log8/include/log8.h>

namespace maps::mrc::pos_improvment {

namespace {

// Calculates speed for each gps point using accelerometer and
// zero starting value.
// @param events should contain gps, horizontalFront and acc events.
std::vector<MetersPerSec> findAccSpeedForEachGps(
    TrackEventPtrs& events,
    const GpsEvents& gpsEvents,
    const CarHorizontalFrontVecEvents& horFrontEvents,
    const AccelerometerEvents& accEvents)
{
    REQUIRE(events.size() > 0 && gpsEvents.size() > 0 &&
            horFrontEvents.size() > 0 && accEvents.size(),
            "Empty input");
    std::vector<MetersPerSec> accSpeedForEachGps;
    accSpeedForEachGps.reserve(gpsEvents.size());

    Time prevTime = events[0]->time;
    UnitVector3 curFrontVector = horFrontEvents[0].values();
    AccelerationVector curAcceleration = accEvents[0].values();
    MetersPerSec curSpeed(0);

    for (auto trackEvent : events) {
        if (trackEvent->type == EventType::Accelerometer) {
            auto& event = *static_cast<AccelerometerEvent const *>(trackEvent);
            curAcceleration = event.values();
        }
        if (trackEvent->type == EventType::CarHorizontalFrontVector) {
            auto& event = *static_cast<CarHorizontalFrontVecEvent const *>(trackEvent);
            curFrontVector = event.values();
        }
        MetersPerSec2 frontAcc = curAcceleration.accAlongAxis(curFrontVector);
        MetersPerSec speedDelta = frontAcc * (trackEvent->time - prevTime);
        curSpeed += speedDelta;
        prevTime = trackEvent->time;

        if (trackEvent->type == EventType::Gps) {
            accSpeedForEachGps.push_back(curSpeed);
        }
    }
    return accSpeedForEachGps;
}

// Calculates a precise speed for each gps point by combining two estimates:
// - gps speed has precise average values over any long time range;
// - accelerometer-integrated speed has a precise derivative over any short
//   time range.
// The actual combination is delegated to matchFunction.
//
// @param gpsEvents gps points. Every point must have `speed` set, because it
//        is dereferenced unconditionally (the caller filters with
//        filterOutEventsWithoutSpeed).
// @param accSpeedForEachGps accelerometer-integrated speed per gps point. Must
//        have the same size as gpsEvents.
// @return the combined speed per gps point, as produced by matchFunction.
std::vector<MetersPerSec> findActualSpeedForEachGps(
    const GpsEvents& gpsEvents,
    const std::vector<MetersPerSec>& accSpeedForEachGps)
{
    REQUIRE(gpsEvents.size() == accSpeedForEachGps.size(),
            "accSpeedForEachGps should contain the same number of points");
    // Window size in gps points, forwarded to matchFunction. Its exact meaning
    // (half-width or full width) is defined there.
    const int FRAME_WIDTH = 15;

    std::vector<MetersPerSec> gpsSpeedForEachGps(gpsEvents.size());
    for (size_t i = 0; i < gpsEvents.size(); i++) {
        gpsSpeedForEachGps[i] = *gpsEvents[i].speed;
    }

    // The bigger the distance between two gps points, the bigger the sum of
    // accelerometer errors between them. So the weight drops linearly by 0.01
    // per point of distance and is clamped at zero.
    auto getWeight = [](int64_t baseIndex, int64_t curIndex) {
        double distanceWeight = 1 - 0.01 * std::abs(curIndex - baseIndex);
        distanceWeight = std::max(0.0, distanceWeight);
        return distanceWeight;
    };

    return matchFunction(
        gpsSpeedForEachGps, accSpeedForEachGps, FRAME_WIDTH, getWeight);
}

// if the calculated speed looks suspiciously, it is better to use
// the unmodified speed
std::vector<MetersPerSec> fixBadSpeedPoints(
    const GpsEvents& gpsEvents,
    std::vector<MetersPerSec> improvedSpeedForEachGps)
{
    //return improvedSpeedForEachGps;
    constexpr MetersPerSec MAX_POSSIBLE_SPEED(110 / 3.6);
    constexpr MetersPerSec MAX_POSSIBLE_BACKWARD_SPEED(3);
    constexpr MetersPerSec2 MAX_POSSIBLE_ACCELERATION(3.0);
    constexpr MetersPerSec2 MAX_POSSIBLE_BREAKING(-5.0);
    for (size_t i = 0; i < gpsEvents.size()-1; i++) {
        MetersPerSec2 improvedAcc
            = (improvedSpeedForEachGps[i + 1] - improvedSpeedForEachGps[i])
            / (gpsEvents[i + 1].time - gpsEvents[i].time);
        MetersPerSec2 gpsAcc
            = (*gpsEvents[i + 1].speed - *gpsEvents[i].speed)
            / (gpsEvents[i + 1].time - gpsEvents[i].time);

        if ((improvedAcc > MAX_POSSIBLE_ACCELERATION
             || improvedAcc < MAX_POSSIBLE_BREAKING)
            && abs(improvedAcc) > abs(gpsAcc)) {
            improvedSpeedForEachGps[i] = *gpsEvents[i].speed;
        }
    }

    for (size_t i = 0; i < gpsEvents.size(); i++) {
        if ((improvedSpeedForEachGps[i] > MAX_POSSIBLE_SPEED
             || improvedSpeedForEachGps[i] < -MAX_POSSIBLE_BACKWARD_SPEED)
            && abs(improvedSpeedForEachGps[i]) > abs(*gpsEvents[i].speed))
        {
            improvedSpeedForEachGps[i] = *gpsEvents[i].speed;
        }
    }
    return improvedSpeedForEachGps;
}

// Returns a copy of gpsEvents that keeps only the points which have a speed
// value, in their original order. No other filtering is performed.
// Later stages dereference `speed` unconditionally, so this filter must run
// before them.
GpsEvents filterOutEventsWithoutSpeed(const GpsEvents& gpsEvents)
{
    GpsEvents filteredEvents;
    filteredEvents.reserve(gpsEvents.size());
    for (const auto& gpsEvent : gpsEvents) {
        if (gpsEvent.speed) {
            filteredEvents.push_back(gpsEvent);
        }
    }
    return filteredEvents;
}

} // anonymous namespace

// Calculates a CarGroundSpeed event for every accelerometer event that comes
// after the first gps point with a known speed.
//
// Gps points without speed are dropped first. At each remaining gps point the
// speed is reset to the fused estimate (findActualSpeedForEachGps followed by
// fixBadSpeedPoints). Between gps points the speed is propagated by
// integrating the accelerometer projected onto the current horizontal front
// vector.
//
// Fails the "Empty input" REQUIRE in findAccSpeedForEachGps if gpsEvents
// (after filtering), horFrontEvents or accEvents is empty.
CarGroundSpeedEvents calculateSpeedEvents(
    GpsEvents gpsEvents,
    const CarHorizontalFrontVecEvents& horFrontEvents,
    const AccelerometerEvents& accEvents)
{
    gpsEvents = filterOutEventsWithoutSpeed(gpsEvents);

    // NOTE(review): `events` presumably refers to elements of gpsEvents,
    // horFrontEvents and accEvents (TODO confirm in insert()). Do not modify
    // gpsEvents after this point.
    TrackEventPtrs events;
    events.reserve(gpsEvents.size() + horFrontEvents.size() + accEvents.size());
    insert(events, gpsEvents);
    insert(events, horFrontEvents);
    insert(events, accEvents);

    std::vector<MetersPerSec> accSpeedForEachGps = findAccSpeedForEachGps(
        events, gpsEvents, horFrontEvents, accEvents);
    std::vector<MetersPerSec> actualSpeedForEachGps = findActualSpeedForEachGps(
        gpsEvents, accSpeedForEachGps);
    actualSpeedForEachGps = fixBadSpeedPoints(gpsEvents, actualSpeedForEachGps);

    // Integrate the accelerometer between gps points, re-anchoring the speed
    // to the fused value at every gps point.

    CarGroundSpeedEvents speedEvents;
    // At most one output event per accelerometer event.
    speedEvents.reserve(accEvents.size());

    // Whatever is integrated before the first gps point is discarded, because
    // curSpeed is overwritten there. So these start values only need to be valid.
    Time prevTime = accEvents[0].time;
    UnitVector3 curFrontVector = horFrontEvents[0].values();
    AccelerationVector curAcceleration = accEvents[0].values();
    MetersPerSec curSpeed(0); // will be initialized on the first gps point
    size_t gpsIndex = 0;

    for (auto trackEvent : events) {
        if (trackEvent->type == EventType::Accelerometer) {
            auto& event = *static_cast<AccelerometerEvent const *>(trackEvent);
            curAcceleration = event.values();
        }
        if (trackEvent->type == EventType::CarHorizontalFrontVector) {
            auto& event = *static_cast<CarHorizontalFrontVecEvent const *>(trackEvent);
            curFrontVector = event.values();
        }
        MetersPerSec2 frontAcc = curAcceleration.accAlongAxis(curFrontVector);
        MetersPerSec speedDelta = frontAcc * (trackEvent->time - prevTime);
        curSpeed += speedDelta;
        prevTime = trackEvent->time;

        if (trackEvent->type == EventType::Gps) {
            curSpeed = actualSpeedForEachGps[gpsIndex];
            gpsIndex++;
        }
        // Emit only once the speed has been anchored by at least one gps point.
        if (gpsIndex > 0 && trackEvent->type == EventType::Accelerometer) {
            speedEvents.emplace_back(EventType::CarGroundSpeed, trackEvent->time, curSpeed);
        }
    }

    return speedEvents;
}

} // namespace maps::mrc::pos_improvment
