#include "gravity.h"

#include "utils.h"

#include <maps/libs/geolib/include/distance.h>

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

using maps::geolib3::Radians;
using maps::geolib3::PI;
using maps::geolib3::cos;
using maps::geolib3::sin;

namespace maps::mrc::pos_improvment {

namespace {

// Width of the window of gps speed points used for calculating gravity:
// FRAME_WIDTH / 2 neighboring gps points are taken on each side of the
// current one.
constexpr int FRAME_WIDTH = 9;

// Checks if the pitchAngle is valid for the provided gps point.
// Returns positive value if the angle is greater than the valid angle
// Returns negative value if the angle is less than the valid angle
//
// Lets assume that the car has pitchAngle at gpsEventIndex position.
// Lets take some time range near the gpsEventIndex and calculate the
// speed change at the neighboring gps point using accelerometer and gyroscope.
// If the angle is too big, the speed after the gpsEventIndex will be
// too low, and the speed before the gpsEventIndex will be to high.
// If the angle is too small, the speed after the gpsEventIndex will be
// too high, and the speed before the gpsEventIndex will be to low.
// Returns (calculated speed before the point - gps speed before the point
// + gps speed after the point - calcualted speed after the point)
//
// @param events should contain gps, gyroscope and accelerometer events
class SpeedDeviationFinder {
public:
    SpeedDeviationFinder(const Vector3Events& events,
                         int gpsEventIndex,
                         const UnitVector3& carFrontVector,
                         const UnitVector3& carUpVector,
                         const UnitVector3& carRightVector,
                         Radians pitchAngle)
        : events_(events)
        , gpsEventIndex_(gpsEventIndex)
        , carFrontVector_(carFrontVector)
        , carUpVector_(carUpVector)
        , carRightVector_(carRightVector)
        , pitchAngle_(pitchAngle)
    {
        defineUsefullEventsRange();
    }

    MetersPerSec result() {
        return calculateDeviation();
    }

private:
    // Finds neighboring range with FRAME_WIDTH gpsSpeed events.
    // gpsEventIndex_ is the center of the range.
    // If gpsEventIndex_ is close to events_ border, the result range is
    // not symmetrical and contains less than rangeSize eventType events.
    //
    // Fills firstGpsIndex_, lastGpsIndex_, middleTime_
    void defineUsefullEventsRange() {
        std::vector<Time> gpsPointTimes;
        gpsPointTimes.reserve(FRAME_WIDTH);

        size_t pointsNum = 0;
        firstGpsIndex_ = gpsEventIndex_;
        for (int i = gpsEventIndex_ - 1; i >= 0; i--) {
            if (events_[i].type == EventType::GpsSpeed) {
                firstGpsIndex_ = i;
                pointsNum++;
                gpsPointTimes.push_back(events_[i].time);
                if (pointsNum == FRAME_WIDTH / 2) {
                    break;
                }
            }
        }
        std::reverse(gpsPointTimes.begin(), gpsPointTimes.end());

        pointsNum = 0;
        lastGpsIndex_ = gpsEventIndex_;
        for (size_t i = gpsEventIndex_ + 1; i < events_.size(); i++) {
            if (events_[i].type == EventType::GpsSpeed) {
                lastGpsIndex_ = i;
                pointsNum++;
                gpsPointTimes.push_back(events_[i].time);
                if (pointsNum == FRAME_WIDTH / 2) {
                    break;
                }
            }
        }

        if (gpsPointTimes.size() % 2 == 1) {
            middleTime_ = gpsPointTimes[gpsPointTimes.size() / 2];
        } else {
            middleTime_ = gpsPointTimes[gpsPointTimes.size() / 2 - 1] +
                (gpsPointTimes[gpsPointTimes.size() / 2]
                 - gpsPointTimes[gpsPointTimes.size() / 2 - 1]) / 2.0;
        }
    }

    void initVars() {
        prevAccTime_ = events_[gpsEventIndex_].time;
        prevGyroTime_ = events_[gpsEventIndex_].time;
        speedBySensors_ = 0.0_mps;
        curPitchAngle_ = pitchAngle_;
        curFrontAcceleration_ = 0.0_mps2;
    }

    void handleSensors(const Vector3Event& trackEvent) {
        if (trackEvent.type == EventType::Accelerometer) {
            auto& event = static_cast<const AccelerometerEvent&>(trackEvent);
            UnitVector3 horizontalFrontVector(carFrontVector_ * cos(curPitchAngle_)
                                              - carUpVector_ * sin(curPitchAngle_));
            curFrontAcceleration_ = event.values().accAlongAxis(horizontalFrontVector);
            speedBySensors_ += curFrontAcceleration_ * (event.time - prevAccTime_);
            prevAccTime_ = event.time;
        } else if (trackEvent.type == EventType::Gyroscope) {
            auto& event = static_cast<const GyroscopeEvent&>(trackEvent);
            curPitchAngle_ += event.values().rotSpeedAroundAxis(carRightVector_)
                * (event.time - prevGyroTime_);
            prevGyroTime_ = event.time;
        } else /* if gps speed event */ {
            speedBySensors_ += curFrontAcceleration_ * (trackEvent.time - prevAccTime_);
            prevAccTime_ = trackEvent.time;
        }
    }

    // returns median(gpsSpeed - speedBySensors)
    MetersPerSec calculateSpeedDelta() {
        speedDifferences_.reserve(FRAME_WIDTH + 1);

        initVars();
        for (int i = gpsEventIndex_; i <= (int)lastGpsIndex_; i++) {
            handleSensors(events_[i]);
            countSpeedDelta(events_[i]);
        }
        initVars();
        for (int i = gpsEventIndex_ - 1; i >= (int)firstGpsIndex_; i--) {
            handleSensors(events_[i]);
            countSpeedDelta(events_[i]);
        }

        std::nth_element(speedDifferences_.begin(),
                         speedDifferences_.begin() + speedDifferences_.size() / 2,
                         speedDifferences_.end());
        return speedDifferences_[speedDifferences_.size() / 2];
    }

    void countSpeedDelta(const Vector3Event& trackEvent) {
        if (trackEvent.type == EventType::GpsSpeed) {
            auto& event = static_cast<const GpsSpeedEvent&>(trackEvent);
            speedDifferences_.push_back(event.gpsSpeed() - speedBySensors_);
        }
    }

    MetersPerSec calculateDeviation() {
        // Accelereometer gives us only the speed difference between gps
        // point, but not the speed value
        // We should calcualate speed from zero base and then find the
        // median difference between calculated speed and gps speed.
        // Then we should add this difference to the calculated speed so the median
        // difference between gps and calculated speed will be zero
        MetersPerSec medianSpeedDifference = calculateSpeedDelta();
        deviations_.reserve(FRAME_WIDTH + 1);

        initVars();
        speedBySensors_ = medianSpeedDifference; // now the median difference between
                                                 // gps and sensors speed is zero
        for (int i = gpsEventIndex_; i <= (int)lastGpsIndex_; i++) {
            handleSensors(events_[i]);
            countDeviation(events_[i]);
        }

        initVars();
        speedBySensors_ = medianSpeedDifference;
        for (int i = gpsEventIndex_ - 1; i >= (int)firstGpsIndex_; i--) {
            handleSensors(events_[i]);
            countDeviation(events_[i]);
        }

        std::nth_element(deviations_.begin(),
                         deviations_.begin() + deviations_.size() / 2,
                         deviations_.end());
        return deviations_[deviations_.size() / 2];
    }

    void countDeviation(const Vector3Event& trackEvent) {
        if (trackEvent.type == EventType::GpsSpeed) {
            auto& event = static_cast<const GpsSpeedEvent&>(trackEvent);
            MetersPerSec speedDelta = event.gpsSpeed() -speedBySensors_;
            if (event.time > middleTime_) {
                deviations_.push_back(speedDelta);
            } else {
                deviations_.push_back(-speedDelta);
            }
        }
    }

private:
    const Vector3Events& events_;
    int gpsEventIndex_;
    const UnitVector3& carFrontVector_;
    const UnitVector3& carUpVector_;
    const UnitVector3& carRightVector_;
    Radians pitchAngle_;
    size_t firstGpsIndex_;
    size_t lastGpsIndex_;
    Time middleTime_;

    MetersPerSec speedBySensors_;
    Time prevAccTime_;
    Time prevGyroTime_;
    Radians curPitchAngle_;
    MetersPerSec2 curFrontAcceleration_;

    std::vector<MetersPerSec> speedDifferences_;
    std::vector<MetersPerSec> deviations_;
};

// Finds, by bisection on [-PI/4, PI/4], the car pitch angle at the gps speed
// event `gpsEventIndex` for which SpeedDeviationFinder reports zero deviation.
//
// Relies on the deviation growing with the pitch angle: negative when the
// candidate angle is too small, positive when it is too big. The search stops
// once the bracket is narrower than 0.001 rad. If the consistent angle lies
// outside [-PI/4, PI/4], the result converges to the nearest bound.
Radians findCarPitchAngle(const Vector3Events& events,
                          size_t gpsEventIndex,
                          const UnitVector3& carFrontVector,
                          const UnitVector3& carUpVector,
                          const UnitVector3& carRightVector)
{
    Radians lAngle = -PI / 4;
    Radians rAngle = PI / 4;
    while (rAngle - lAngle > 0.001_rad) {
        const Radians curPitchAngle = (rAngle + lAngle) / 2.0;
        const MetersPerSec deviation = SpeedDeviationFinder(
            events, static_cast<int>(gpsEventIndex),
            carFrontVector, carUpVector,
            carRightVector, curPitchAngle).result();
        if (deviation < 0.0_mps) {
            // The candidate angle is too small: the answer lies above it.
            lAngle = curPitchAngle;
        } else {
            rAngle = curPitchAngle;
        }
    }
    return (rAngle + lAngle) / 2.0;
}

} // anonymous namespace

std::pair<GravityEvents, CarHorizontalFrontVecEvents>
findGravityAndHorizontalFrontVector(const AccelerometerEvents& accEvents,
                                    const GyroscopeEvents& gyroEvents,
                                    const GpsEvents& gpsEvents,
                                    const UnitVector3& carUpVector,
                                    const UnitVector3& carFrontVector)
{
    GpsSpeedEvents gpsSpeedEvents;
    gpsSpeedEvents.reserve(gpsEvents.size());
    for (const auto& gpsEvent : gpsEvents) {
        if (gpsEvent.speed) {
            gpsSpeedEvents.emplace_back(gpsEvent);
        }
    }
    Vector3Events events;
    events.reserve(gpsSpeedEvents.size() + accEvents.size() + gyroEvents.size());
    insert(events, gpsSpeedEvents);
    insert(events, accEvents);
    insert(events, gyroEvents);

    UnitVector3 carRightVector(geolib3::crossProduct(carFrontVector, carUpVector));
    GravityEvents gravityEvents;
    CarHorizontalFrontVecEvents carHorizontalFrontVecEvents;
    Time prevGyroTime = gyroEvents[0].time;
    bool gpsStarted = false;
    Radians carPitchAngle(0);

    for (size_t i = 0; i < events.size(); i++) {
        if (events[i].type == EventType::GpsSpeed) {
            carPitchAngle = findCarPitchAngle(
                events, i, carFrontVector, carUpVector, carRightVector);
            gpsStarted = true;
        }
        if (events[i].type == EventType::Gyroscope && gpsStarted) {
            auto& event = static_cast<const GyroscopeEvent&>(events[i]);
            carPitchAngle += event.values().rotSpeedAroundAxis(carRightVector)
                * (event.time - prevGyroTime);
            prevGyroTime = event.time;
            UnitVector3 frontVector(carFrontVector * cos(carPitchAngle)
                                    - carUpVector * sin(carPitchAngle));
            UnitVector3 upVector(carUpVector * cos(carPitchAngle)
                                 + carFrontVector * sin(carPitchAngle));
            UnitVector3 gravityVector(-upVector);

            carHorizontalFrontVecEvents.emplace_back(event.time, frontVector);
            gravityEvents.emplace_back(event.time, gravityVector);
        }
    }

    return {gravityEvents, carHorizontalFrontVecEvents};
}

} // namespace maps::mrc::pos_improvment
