#include <maps/wikimap/mapspro/services/mrc/libs/birdview/include/vanishing_point.h>

#include <cmath>

namespace maps::mrc::birdview {
namespace {

/// Signed horizontal extent of a segment stored as (x1, y1, x2, y2): x2 - x1.
int deltaX(const Line& line)
{
    const int startX = line[0];
    const int endX = line[2];
    return endX - startX;
}

/// Signed vertical extent of a segment stored as (x1, y1, x2, y2): y2 - y1.
int deltaY(const Line& line)
{
    const int startY = line[1];
    const int endY = line[3];
    return endY - startY;
}

/// Smaller y coordinate of the two endpoints, i.e. the upper endpoint in
/// OpenCV image coordinates (y grows downwards).
int minY(const Line& line)
{
    const int startY = line[1];
    const int endY = line[3];
    return endY < startY ? endY : startY;
}

/// Euclidean length of the segment between its two endpoints.
float length(const Line& line)
{
    const int horizontal = deltaX(line);
    const int vertical = deltaY(line);
    const int squaredLength = horizontal * horizontal + vertical * vertical;
    return static_cast<float>(std::sqrt(squaredLength));
}

/// Tells whether a segment is steep enough to be used for voting.
///
/// Segments whose angle to the horizontal is at most 10 degrees are
/// rejected. Purely horizontal (dy == 0) and purely vertical (dx == 0)
/// segments are rejected as well.
bool valid(const Line& line)
{
    static const float MIN_ANGLE_RAD = 10. * CV_PI / 180.;

    const int horizontal = deltaX(line);
    const int vertical = deltaY(line);
    if (horizontal == 0 || vertical == 0) {
        return false;
    }

    const float slope = vertical / static_cast<float>(horizontal);
    return std::fabs(std::atan(slope)) > MIN_ANGLE_RAD;
}

std::optional<cv::Point2f> intersection(const Line& line1, const Line& line2)
{
    int dx1 = deltaX(line1);
    int dy1 = deltaY(line1);
    int dx2 = deltaX(line2);
    int dy2 = deltaY(line2);
    int denom = dx1 * dy2 - dx2 * dy1;
    if (denom == 0)
        return {}; // collinear
    int numer = dx2 * (line1[1] - line2[1]) - dy2 * (line1[0] - line2[0]);
    float ratio = numer / (float)denom;
    return cv::Point2f{line1[0] + ratio * dx1, line1[1] + ratio * dy1};
}

/// Builds a coarse voting grid for the vanishing point.
///
/// Only segments accepted by valid() whose endpoints both lie in the lower
/// half of the image take part. Every pair of such segments votes for the
/// grid cell containing their intersection, provided that:
/// - the intersection lies in the middle third of the image horizontally;
/// - the intersection is not below the upper endpoint of either segment.
/// The weight of a vote is the product of the two segment lengths.
///
/// @param lines      detected segments in image pixel coordinates
/// @param imageSize  size of the image the segments were detected in
/// @param scale      side length, in pixels, of one grid cell
/// @return CV_32FC1 matrix of size imageSize / scale holding accumulated votes
cv::Mat calculateConfidenceMatrix(const std::vector<Line>& lines,
                                  const cv::Size& imageSize,
                                  int scale)
{
    // Filtering and length computation depend on a single segment, so do
    // them once per segment instead of once per pair. Input order is kept,
    // so votes are accumulated in the same order as before.
    struct Candidate {
        const Line* line;
        float len;
        int top;
    };
    std::vector<Candidate> candidates;
    candidates.reserve(lines.size());
    for (const Line& line : lines) {
        if (valid(line) && minY(line) >= imageSize.height / 2) {
            candidates.push_back({&line, length(line), minY(line)});
        }
    }

    const int minX = imageSize.width / 3;
    const int maxX = 2 * imageSize.width / 3;

    cv::Mat result = cv::Mat::zeros(imageSize / scale, CV_32FC1);
    for (size_t i = 1; i < candidates.size(); i++) {
        const Candidate& first = candidates[i];
        for (size_t j = 0; j < i; j++) {
            const Candidate& second = candidates[j];

            auto point = intersection(*first.line, *second.line);
            if (!point || point->x > maxX || point->x < minX
                || point->y > first.top || point->y > second.top) {
                continue;
            }

            // Floor instead of truncating: truncation toward zero would map
            // points just above/left of the image (-scale < v < 0) into
            // row/column 0 instead of rejecting them below.
            const int col = static_cast<int>(std::floor(point->x / scale));
            const int row = static_cast<int>(std::floor(point->y / scale));
            if (col < 0 || col >= result.cols || row < 0
                || row >= result.rows) {
                continue;
            }

            result.at<float>(row, col) += first.len * second.len;
        }
    }
    return result;
}

} // anonymous namespace

/// Estimates the vanishing point from a set of line segments.
///
/// How it works:
/// - Votes are accumulated on a grid of SCALE x SCALE pixel cells (see
///   calculateConfidenceMatrix) and smoothed with a Gaussian.
/// - The strongest cell is taken as the answer.
/// - The weight compares that peak with the strongest response outside a
///   window of up to (2 * MARGIN + 1) cells around it:
///   (peak - background) / peak.
///
/// @param lines      detected segments in pixel coordinates
/// @param imageSize  size of the image the segments come from
/// @return pixel position of the top-left corner of the winning cell and its
///         weight. The weight is 0 when no segment pair voted or the image is
///         smaller than one cell.
std::pair<cv::Point2f, float>
findVanishingPoint(const std::vector<Line>& lines, const cv::Size& imageSize)
{
    static const int SCALE = 5;
    static const int MARGIN = 5;
    static const double GAUSSIAN_SIGMA = 1.5;

    const std::pair<cv::Point2f, float> notFound{cv::Point2f{}, 0.f};

    cv::Mat conf = calculateConfidenceMatrix(lines, imageSize, SCALE);
    if (conf.empty()) {
        // Image smaller than one cell; GaussianBlur does not accept empty input.
        return notFound;
    }

    cv::Mat blurred;
    cv::GaussianBlur(conf, blurred, cv::Size(), GAUSSIAN_SIGMA);
    double maxVal = 0.;
    cv::Point maxPoint;
    cv::minMaxLoc(blurred, nullptr, &maxVal, nullptr, &maxPoint);
    if (!(maxVal > 0.)) {
        // No votes at all: the weight below would be 0 / 0 = NaN, which
        // compares false against any threshold and would be accepted.
        return notFound;
    }

    // Clear the peak's neighbourhood (clipped to the matrix) to measure the
    // strongest competing response.
    cv::Rect rect;
    rect.x = std::max(maxPoint.x - MARGIN, 0);
    rect.y = std::max(maxPoint.y - MARGIN, 0);
    rect.width = std::min(maxPoint.x + MARGIN, blurred.cols - 1) - rect.x + 1;
    rect.height = std::min(maxPoint.y + MARGIN, blurred.rows - 1) - rect.y + 1;
    blurred(rect).setTo(0.);

    double backgroundMaxVal = 0.;
    cv::minMaxLoc(blurred, nullptr, &backgroundMaxVal, nullptr, nullptr);
    const float weight =
        static_cast<float>((maxVal - backgroundMaxVal) / maxVal);
    return {maxPoint * SCALE, weight};
}

/// Estimates the vanishing point of an image.
///
/// How it works:
/// - Line segments are detected with OpenCV's LSD detector on a grayscale
///   version of the image.
/// - Colour input is converted with COLOR_BGR2GRAY; single-channel input is
///   used as is.
/// - The segments are passed to the segment-based overload.
///
/// @return the estimated point, rounded to integer pixels, or std::nullopt
///         when its weight is below WEIGHT_THRESHOLD (or is not a number).
std::optional<cv::Point> findVanishingPoint(const cv::Mat& image)
{
    static constexpr double SCALE_TO_FIND_LINES = 0.5;
    static constexpr double WEIGHT_THRESHOLD = 0.3;

    cv::Mat gray;
    if (image.channels() == 1) {
        // Already grayscale: COLOR_BGR2GRAY would reject 1-channel input.
        gray = image;
    } else {
        cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY);
    }

    cv::Ptr<cv::LineSegmentDetector> lineDetectorPtr
        = cv::createLineSegmentDetector(
            cv::LSD_REFINE_STD, SCALE_TO_FIND_LINES);
    std::vector<cv::Vec4i> lines;
    lineDetectorPtr->detect(gray, lines);

    auto [point, weight] = birdview::findVanishingPoint(lines, gray.size());
    // Phrased as !(>=) so that a NaN weight is rejected as well.
    if (!(weight >= WEIGHT_THRESHOLD)) {
        return std::nullopt;
    }
    return cv::Point(point);
}

} // namespace maps::mrc::birdview
