#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/location.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/camera.h>

#include <maps/libs/geolib/include/direction.h>
#include <maps/libs/geolib/include/distance.h>

#include <library/cpp/iterator/zip.h>

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>

namespace maps::mrc::eye {

namespace {

/// Offsets the camera position by the horizontal part of a solvePnP
/// translation vector.
///
/// `tvec` is the object position in camera coordinates as produced by
/// cv::solvePnP (a 3-element CV_64F vector; OpenCV's camera frame has x to
/// the right, y down, z forward). Only x and z are used, so the vertical
/// offset is ignored. The horizontal distance sqrt(x^2 + z^2) is treated as
/// meters and converted to mercator units at the camera position.
geolib3::Point2 getShiftedPosition(
    const db::eye::FrameLocation& frameLocation,
    const cv::Mat& tvec)
{
    const auto [heading, _, __] = decomposeRotation(frameLocation.rotation());
    const geolib3::Direction2 cameraDirection(heading);

    // solvePnP returns a 3x1 column vector, so two-index access with a column
    // index of 2 would be out of bounds. Single-index access is valid for both
    // column (3x1) and row (1x3) vectors.
    const double x = tvec.at<double>(0);
    const double z = tvec.at<double>(2);

    // atan2(x, z) is the object's bearing off the optical axis, positive to
    // the right. It is subtracted from the camera direction angle
    // (presumably counter-clockwise for geolib3::Direction2 — confirm).
    const geolib3::Radians angle = cameraDirection.radians() - std::atan2(x, z);
    const double distance = geolib3::toMercatorUnits(std::hypot(x, z),
                                                     frameLocation.mercatorPos());
    return frameLocation.mercatorPos() +
           geolib3::Direction2{angle}.vector() * distance;
}

/// Averages a non-empty set of locations.
///
/// Positions are averaged component-wise in mercator coordinates. Headings are
/// averaged as the sum of their unit direction vectors, which avoids the
/// wrap-around problem of a plain arithmetic mean of angles. The resulting
/// rotation is built from the mean heading with a CW_0 image orientation.
/// Fails (REQUIRE) when the direction vectors cancel out and no mean heading
/// exists.
Location averageLocations(const std::vector<Location>& locations) {
    ASSERT(locations.size() > 0);

    double sumX = 0;
    double sumY = 0;
    geolib3::Vector2 headingSum{0., 0.};

    for (const Location& location : locations) {
        const auto& position = location.mercatorPosition;
        sumX += position.x();
        sumY += position.y();

        const auto decomposed = decomposeRotation(location.rotation);
        headingSum += geolib3::Direction2(decomposed.heading).vector();
    }

    const double count = static_cast<double>(locations.size());
    const double meanX = sumX / count;
    const double meanY = sumY / count;

    // Opposite headings cancel each other out; there is no mean direction then.
    REQUIRE(geolib3::length(headingSum) > geolib3::EPS, "Invalid average vector");

    const geolib3::Heading meanHeading(geolib3::Direction2(headingSum).heading());
    const auto meanRotation =
        toRotation(meanHeading, common::ImageOrientation(common::Rotation::CW_0));

    return Location{geolib3::Point2(meanX, meanY), meanRotation};
}

} // namespace

/// Approximates a house-number location by the camera pose of the single view
/// in which its detection looks largest.
///
/// The view with the greatest getBoxToFrameSizeRatio wins. On ties the earliest
/// such view is used, and view 0 is used if no ratio is positive. The result
/// has that frame's own mercator position and a rotation derived from the
/// frame rotation via reverseRotationHeading.
///
/// Requires `frames`, `locations` and `detections` to be non-empty and of
/// equal size (index i of each describes the same view).
Location findHouseNumberLocation(
        const db::eye::Frames& frames,
        const db::eye::FrameLocations& locations,
        const db::eye::Detections& detections)
{
    ASSERT(frames.size() > 0);

    ASSERT(frames.size() == locations.size());
    ASSERT(frames.size() == detections.size());

    size_t bestView = 0;
    double bestRatio = 0;

    for (size_t view = 0; view < locations.size(); ++view) {
        const double ratio =
            getBoxToFrameSizeRatio(detections[view].box(), frames[view].originalSize());
        if (ratio > bestRatio) {
            bestRatio = ratio;
            bestView = view;
        }
    }

    const auto& best = locations[bestView];
    return Location{best.mercatorPos(), reverseRotationHeading(best.rotation())};
}

/// Estimates an object location from a single detection using cv::solvePnP.
///
/// The four corners of the detection box, transformed by the frame
/// orientation, are matched against `objectCoords`. solvePnP pairs points by
/// position, so `objectCoords` must list the corresponding 3D corners in the
/// same order: (maxX,maxY), (minX,maxY), (maxX,minY), (minX,minY).
/// Returns std::nullopt when solvePnP reports failure. Otherwise returns the
/// camera position shifted by the solved translation, with a rotation derived
/// from the frame rotation via reverseRotationHeading.
std::optional<Location> findLocationBySingleView(
    const db::eye::Device& device,
    const db::eye::Frame& frame,
    const db::eye::FrameLocation& frameLocation,
    const db::eye::Detection& detection,
    const std::vector<cv::Point3f>& objectCoords)
{
    const common::ImageBox orientedBox = common::transformByImageOrientation(
        detection.box(), frame.originalSize(), frame.orientation());

    const std::vector<cv::Point2f> imageCoords{
        cv::Point2f(orientedBox.maxX(), orientedBox.maxY()),
        cv::Point2f(orientedBox.minX(), orientedBox.maxY()),
        cv::Point2f(orientedBox.maxX(), orientedBox.minY()),
        cv::Point2f(orientedBox.minX(), orientedBox.minY())
    };

    const auto [cameraMatrix, distortionCoeffs] = getCameraParameters(
        device.attrs(), frame.orientation(), frame.originalSize());

    cv::Mat rotationVec;
    cv::Mat translationVec;
    if (!cv::solvePnP(objectCoords, imageCoords,
                      cameraMatrix, distortionCoeffs,
                      rotationVec, translationVec)) {
        return std::nullopt;
    }

    return Location{
        getShiftedPosition(frameLocation, translationVec),
        reverseRotationHeading(frameLocation.rotation())
    };
}

namespace {

/// Selects the indices of the (at most) `limit` detections with the largest
/// box area.
///
/// If there are no more than `limit` detections, every index is returned in
/// ascending order. Otherwise exactly `limit` indices are returned, ordered by
/// decreasing box area. Equal areas are ordered by ascending index, so the
/// result does not depend on the sort implementation.
std::vector<size_t> selectBiggestDetectionsIndx(
    const db::eye::Detections& detections,
    size_t limit)
{
    std::vector<size_t> indices(detections.size());
    std::iota(indices.begin(), indices.end(), size_t{0});

    if (indices.size() <= limit) {
        return indices;
    }

    // Box area for each detection index, computed once rather than inside
    // every comparison.
    std::vector<double> areas;
    areas.reserve(detections.size());
    for (size_t i = 0; i < detections.size(); ++i) {
        areas.push_back(detections[i].box().area());
    }

    std::partial_sort(
        indices.begin(), indices.begin() + limit, indices.end(),
        [&areas](size_t lhs, size_t rhs) {
            if (areas[lhs] != areas[rhs]) {
                return areas[lhs] > areas[rhs];
            }
            return lhs < rhs;
        }
    );

    indices.resize(limit);
    return indices;
}

} // namespace

/// Estimates an object location from several views of the same object.
///
/// Only the MAX_USED_DETECTIONS_COUNT views with the largest detection boxes
/// are used. Each of them is solved with findLocationBySingleView. Estimates
/// that fail, or that lie METERS_DIST_THRESHOLD meters or more from their own
/// camera, are discarded, and the remaining ones are averaged. If none remain,
/// the camera pose of the selected view with the largest box-to-frame size
/// ratio is returned instead (as in findHouseNumberLocation).
///
/// Requires all input containers to be non-empty and of equal size.
Location findLocationByMultipleViews(
    const db::eye::Devices& devices,
    const db::eye::Frames& frames,
    const db::eye::FrameLocations& locations,
    const db::eye::Detections& detections,
    const std::vector<cv::Point3f>& objectCoords)
{
    constexpr size_t MAX_USED_DETECTIONS_COUNT = 5;
    // Single-view estimates at least this far from their camera are treated
    // as unreliable solvePnP results.
    constexpr double METERS_DIST_THRESHOLD = 120.;

    ASSERT(frames.size() > 0);

    ASSERT(frames.size() == devices.size());
    ASSERT(frames.size() == locations.size());
    ASSERT(frames.size() == detections.size());

    const std::vector<size_t> indices =
        selectBiggestDetectionsIndx(detections, MAX_USED_DETECTIONS_COUNT);

    std::vector<Location> objectLocations;
    objectLocations.reserve(indices.size());

    // Fallback view. It starts at the first selected view (indices is
    // non-empty because frames is non-empty), so the fallback is always one
    // of the views actually considered.
    size_t maxRatioIndex = indices.front();
    double maxRatio = 0;

    for (const size_t i : indices) {
        const double ratio = getBoxToFrameSizeRatio(detections[i].box(), frames[i].originalSize());
        if (ratio > maxRatio) {
            maxRatioIndex = i;
            maxRatio = ratio;
        }

        const std::optional<Location> objectLocation = findLocationBySingleView(
            devices[i], frames[i], locations[i], detections[i], objectCoords
        );
        if (!objectLocation.has_value()) {
            continue;
        }

        const double mercatorDist = geolib3::distance(
            objectLocation->mercatorPosition,
            locations[i].mercatorPos()
        );
        const double metersDist = geolib3::toMeters(
            mercatorDist,
            locations[i].mercatorPos()
        );

        if (metersDist < METERS_DIST_THRESHOLD) {
            objectLocations.push_back(*objectLocation);
        }
    }

    if (!objectLocations.empty()) {
        return averageLocations(objectLocations);
    }

    const auto& fallback = locations[maxRatioIndex];
    return Location{
        fallback.mercatorPos(),
        reverseRotationHeading(fallback.rotation())
    };
}

const std::vector<cv::Point3f>& defaultTrafficLightPattern()
{
    static const std::vector<cv::Point3f> coords {
        { 0.15,  0.45, 0},
        {-0.15,  0.45, 0},
        { 0.15, -0.45, 0},
        {-0.15, -0.45, 0},
    };

    return coords;
}

const std::vector<cv::Point3f>& defaultSignPattern()
{
    static const std::vector<cv::Point3f> coords {
        { 0.35,  0.35, 0},
        {-0.35,  0.35, 0},
        { 0.35, -0.35, 0},
        {-0.35, -0.35, 0},
    };

    return coords;
}

/// Builds a rotation whose heading is the reverse of `frameRotation`'s heading.
///
/// Only the heading component of decomposeRotation(frameRotation) is used;
/// the other decomposed components are discarded. The result is a fixed base
/// rotation composed with a rotation about -Y by the angle of the reversed
/// heading's direction.
Eigen::Quaterniond reverseRotationHeading(const Eigen::Quaterniond& frameRotation)
{
    // Constant part of the result, constructed once on first use.
    static const Eigen::Quaterniond baseRotation(
        makeRotationMatrix(
            -Eigen::Vector3d::UnitY(),
            -Eigen::Vector3d::UnitZ(),
            Eigen::Vector3d::UnitX()
        )
    );

    const auto decomposed = decomposeRotation(frameRotation);
    const auto oppositeHeading = geolib3::reverse(decomposed.heading);
    const auto angle = geolib3::Direction2(oppositeHeading).radians();

    return baseRotation * Eigen::AngleAxisd(angle.value(), -Eigen::Vector3d::UnitY());
}


/// Places a road marking SHIFT_METERS ahead of the camera along the camera
/// heading. The rotation is derived from the frame rotation via
/// reverseRotationHeading.
Location findRoadMarkingLocationBySingleView(
    const db::eye::FrameLocation& frameLocation)
{
    static constexpr double SHIFT_METERS = 10.;

    const auto& cameraPosition = frameLocation.mercatorPos();
    const auto heading = decomposeRotation(frameLocation.rotation()).heading;

    // The meters-to-mercator scale depends on where it is measured.
    const double shiftMercator = geolib3::toMercatorUnits(SHIFT_METERS, cameraPosition);
    const auto shift = geolib3::Direction2{heading}.vector() * shiftMercator;

    return Location{
        cameraPosition + shift,
        reverseRotationHeading(frameLocation.rotation())
    };
}

/// Estimates a road marking location from several frames by averaging the
/// per-frame estimates of findRoadMarkingLocationBySingleView.
/// `frameLocations` must be non-empty (averageLocations asserts this).
Location findRoadMarkingLocation(
    const db::eye::FrameLocations& frameLocations)
{
    std::vector<Location> perViewLocations;
    perViewLocations.reserve(frameLocations.size());

    for (const auto& frameLocation : frameLocations) {
        perViewLocations.push_back(findRoadMarkingLocationBySingleView(frameLocation));
    }

    return averageLocations(perViewLocations);
}

} // namespace maps::mrc::eye
