#include "undistort.h"

namespace maps {
namespace mrc {
namespace browser {
namespace {

class FishEyesUndistort {
public:
    FishEyesUndistort() : imgSize_(-1, -1) {}

    cv::Mat undistortImage(const cv::Mat& distorted)
    {
        calcUndistortMaps(distorted.size());
        cv::Mat undistorted;
        cv::remap(distorted, undistorted, map1_, map2_, cv::INTER_LINEAR,
                  cv::BORDER_CONSTANT);
        return undistorted;
    }

private:
    cv::Size imgSize_;
    cv::Mat map1_;
    cv::Mat map2_;

    void undistortPoints(const std::vector<cv::Point2f>& distPts,
                         std::vector<cv::Point2f>& undistPts,
                         const cv::Mat& K)
    {
        undistPts.resize(distPts.size());

        float fx = K.at<float>(0, 0);
        float fy = K.at<float>(1, 1);

        float cx = K.at<float>(0, 2);
        float cy = K.at<float>(1, 2);

        for (size_t i = 0; i < distPts.size(); i++) {
            const cv::Point2f pw((distPts[i].x - cx) / fx,
                                 (distPts[i].y - cy) / fy); // world point

            const float theta = sqrtf(pw.dot(pw));
            const float scale = (theta < 1e-8) ? 1.f : tanf(theta) / theta;

            undistPts[i] = pw * scale;
        }
    }

    void calcUndistortMaps(const cv::Size& imgSize)
    {
        if (imgSize_ == imgSize)
            return;
        imgSize_ = imgSize;

        cv::Mat K = cv::Mat::eye(3, 3, CV_32FC1);
        K.at<float>(0, 0) = (float)imgSize_.width / 2.f;
        K.at<float>(0, 2) = (float)imgSize_.width / 2.f;
        K.at<float>(1, 1) = (float)imgSize_.width / 2.f;
        K.at<float>(1, 2) = (float)imgSize_.height / 2.f;
        float scale_x = 1.f;
        float scale_y = 1.f;
        std::vector<cv::Point2f> distorted(4);
        distorted[0].x = (float)imgSize_.width / 2.f;
        distorted[0].y = 0.f;
        distorted[1].x = (float)imgSize_.width;
        distorted[1].y = (float)imgSize_.height / 2.f;
        distorted[2].x = (float)imgSize_.width / 2.f;
        distorted[2].y = (float)imgSize_.height;
        distorted[3].x = 0.f;
        distorted[3].y = (float)imgSize_.height / 2.f;

        std::vector<cv::Point2f> undistorted(4);
        undistortPoints(distorted, undistorted, K);

        scale_x = (float)imgSize_.width
                  / (undistorted[1].x - undistorted[3].x);
        scale_y = (float)imgSize_.height
                  / (undistorted[2].y - undistorted[0].y);

        scale_x = std::max(scale_x, scale_y);
        scale_y = scale_x;

        cv::Mat Knew = cv::Mat::eye(3, 3, CV_32FC1);
        Knew.at<float>(0, 0) = scale_x;
        Knew.at<float>(0, 2) = (float)(imgSize_.width / 2);
        Knew.at<float>(1, 1) = scale_y;
        Knew.at<float>(1, 2) = (float)(imgSize_.height / 2);

        calcUndistortMaps(K, Knew);
    }

    void calcUndistortMaps(const cv::Mat& K, const cv::Mat& Knew)
    {
        map1_.create(imgSize_, CV_16SC2);
        map2_.create(imgSize_, CV_16UC1);

        float fx = K.at<float>(0, 0);
        float fy = K.at<float>(1, 1);

        float cx = K.at<float>(0, 2);
        float cy = K.at<float>(1, 2);

        cv::Matx33f iR = Knew;
        iR = iR.inv(cv::DECOMP_SVD);

        for (int i = 0; i < imgSize_.height; ++i) {
            short* m1 = (short*)map1_.ptr<float>(i);
            ushort* m2 = (ushort*)map2_.ptr<float>(i);

            float _x = i * iR(0, 1) + iR(0, 2), _y = i * iR(1, 1) + iR(1, 2),
                  _w = i * iR(2, 1) + iR(2, 2);

            for (int j = 0; j < imgSize_.width; ++j) {
                float x = _x / _w, y = _y / _w;

                float r = sqrtf(x * x + y * y);
                float theta = atanf(r);

                float scale = (r == 0) ? 1.f : theta / r;
                float u = fx * x * scale + cx;
                float v = fy * y * scale + cy;

                int iu = cv::saturate_cast<int>(u * cv::INTER_TAB_SIZE);
                int iv = cv::saturate_cast<int>(v * cv::INTER_TAB_SIZE);
                m1[j * 2 + 0] = (short)(iu >> cv::INTER_BITS);
                m1[j * 2 + 1] = (short)(iv >> cv::INTER_BITS);
                m2[j] = (ushort)((iv & (cv::INTER_TAB_SIZE - 1))
                                     * cv::INTER_TAB_SIZE
                                 + (iu & (cv::INTER_TAB_SIZE - 1)));

                _x += iR(0, 0);
                _y += iR(1, 0);
                _w += iR(2, 0);
            }
        }
    }
};

} // anonymous namespace

/// Returns @p image with its fisheye distortion removed.
///
/// A new FishEyesUndistort is created for every call, so its per-size remap
/// tables are rebuilt each time rather than reused across calls.
cv::Mat undistort(const cv::Mat& image)
{
    FishEyesUndistort undistorter{};
    return undistorter.undistortImage(image);
}

} // browser
} // mrc
} // maps
