#include "detect_road.h"

#include <maps/libs/geolib/include/polygon.h>
#include <maps/libs/geolib/include/spatial_relation.h>
#include <maps/libs/geolib/include/intersection.h>
#include <maps/libs/geolib/include/vector.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/geolib/include/convex_hull.h>
#include <maps/libs/geolib/include/segment.h>
#include <maps/libs/geolib/include/line.h>
#include <maps/libs/geolib/include/direction.h>
#include <maps/libs/geolib/include/bounding_box.h>
#include <maps/libs/geolib/include/static_geometry_searcher.h>

#include <maps/libs/common/include/exception.h>

#include <vector>
#include <cmath>
#include <utility>
#include <algorithm>
#include <limits>
#include <memory>
#include <boost/optional.hpp>

namespace maps {
namespace wiki {
namespace autocart {

namespace {

/*geolib3::Polygon2
buffer(const geolib3::Polyline2& polyline, double width)
{
    auto geosPolyline = geolib3::internal::geolib2geosGeometry(polyline);
    std::unique_ptr<geos::geom::Geometry> bufferedGeom(
        geosPolyline->buffer(width)
    );

    return geolib3::internal::geos2geolibGeometry(
        dynamic_cast<geos::geom::Polygon*>(bufferedGeom.get())
    );
}*/

/**
 * @brief Runs the road segmentation network on @p image.
 *
 * The output of the "inference_sigmoid" layer must be a single-channel float
 * map (CV_32FC1) of the same size as @p image; it is scaled by 255 and
 * converted to CV_8UC1 (cv::Mat::convertTo saturates out-of-range values).
 *
 * @param roadDetector  loaded TensorFlow model
 * @param image         input image fed to the "inference_input" layer
 * @return              CV_8UC1 response map with the size of @p image
 * Fails via REQUIRE if the network output has an unexpected type or size.
 */
cv::Mat inferenceRoadDetector(const tf_inferencer::TensorFlowInferencer &roadDetector,
                              const cv::Mat &image)
{
    static const TString INPUT_LAYER_NAME = "inference_input";
    static const TString OUTPUT_LAYER_NAME = "inference_sigmoid";

    cv::Mat roads = tf_inferencer::tensorToImage(
                roadDetector.inference(INPUT_LAYER_NAME, image,
                                       OUTPUT_LAYER_NAME)
                );

    REQUIRE(roads.type() == CV_32FC1 && roads.size() == image.size(),
            "Invalid type or sizes of edges detection results");

    // Presumably in [0, 1] (sigmoid output): map to the 0..255 range.
    roads.convertTo(roads, CV_8UC1, 255.);
    return roads;
}

/**
 * @brief One sub-iteration of Zhang-Suen thinning, applied in place.
 *
 * @p data is expected to hold a 0/1 image. For every interior pixel, let
 * P2..P9 be its neighbours in clockwise order starting from the pixel above.
 * The pixel is cleared (a no-op for background pixels) when:
 *   - the cyclic sequence P2, P3, ..., P9, P2 contains exactly one 0 -> 1
 *     transition;
 *   - the sum of P2..P9 is between 2 and 6;
 *   - pass 0: P2*P4*P6 == 0 and P4*P6*P8 == 0,
 *     pass 1: P2*P4*P8 == 0 and P2*P6*P8 == 0.
 * All decisions are made on the unmodified image and applied together at the
 * end. The one-pixel image border is never examined.
 *
 * @param data  single-channel 8-bit image, modified in place
 * @param iter  0 selects the first sub-iteration, any other value the second
 */
void thinningIteration(cv::Mat *data, int iter)
{
    ASSERT(data);
    const bool firstPass = (iter == 0);
    cv::Mat toRemove = cv::Mat::zeros(data->size(), CV_8UC1);

    for (int row = 1; row < data->rows - 1; ++row) {
        const uint8_t *above = data->ptr(row - 1);
        const uint8_t *here = data->ptr(row);
        const uint8_t *below = data->ptr(row + 1);
        uint8_t *removeRow = toRemove.ptr(row);

        for (int col = 1; col < data->cols - 1; ++col) {
            // Neighbours P2..P9, clockwise, starting from the pixel above.
            const uint8_t nb[8] = {
                above[col], above[col + 1], here[col + 1], below[col + 1],
                below[col], below[col - 1], here[col - 1], above[col - 1]
            };

            int transitions = 0;  // 0 -> 1 patterns in P2, ..., P9, P2
            int setCount = 0;     // sum of the neighbour values
            for (int k = 0; k < 8; ++k) {
                if (nb[k] == 0 && nb[(k + 1) % 8] == 1) {
                    ++transitions;
                }
                setCount += nb[k];
            }

            const int p2 = nb[0];
            const int p4 = nb[2];
            const int p6 = nb[4];
            const int p8 = nb[6];
            const int m1 = firstPass ? p2 * p4 * p6 : p2 * p4 * p8;
            const int m2 = firstPass ? p4 * p6 * p8 : p2 * p6 * p8;

            const bool removable = transitions == 1
                && setCount >= 2 && setCount <= 6
                && m1 == 0 && m2 == 0;
            if (removable) {
                removeRow[col] = 1;
            }
        }
    }
    *data &= ~toRemove;
}

/**
 * @brief Skeletonizes the dark structures of an 8-bit image in place
 *        (Zhang-Suen thinning).
 *
 * The image is mapped to a 0/1 working image with ~data / 255: 0 pixels become
 * 1 (foreground) and 255 pixels become 0. Both thinning sub-iterations are then
 * repeated until a full round leaves the working image unchanged.
 * On return @p data is 0 on the skeleton and 255 everywhere else.
 *
 * @param data  single-channel 8-bit image, replaced with the skeleton mask
 */
void thinningZhangSuen(cv::Mat *data)
{
    ASSERT(data);
    cv::Mat skeleton = ~*data / 255;

    cv::Mat previous = cv::Mat::zeros(skeleton.size(), CV_8UC1);
    cv::Mat changed;
    bool converged = false;
    while (!converged) {
        thinningIteration(&skeleton, 0);
        thinningIteration(&skeleton, 1);
        cv::absdiff(skeleton, previous, changed);
        skeleton.copyTo(previous);
        converged = (cv::countNonZero(changed) == 0);
    }

    *data = (skeleton == 0);
}

/**
 * @brief Extracts a mask of strong edges from the road detector response.
 *
 * The response is inverted (fuse = 255 - edges) and binarized: a pixel is set
 * (255) where fuse > @p edgesThreshold. Then, for every level `thresh` from
 * @p edgesThreshold up to 255, the binary mask is refined for at most
 * MAX_ITER_CNT rounds: it is opened (eroded, then dilated) with a 3x3 kernel,
 * and mask pixels absent from the opening (restricted to fuse > thresh) are
 * accumulated into strongEdges. Finally strongEdges is dilated once with the
 * same kernel and inverted.
 *
 * @param edges           detector response, e.g. the output of
 *                        inferenceRoadDetector(); presumably CV_8UC1, since it
 *                        is combined with CV_8UC1 masks below — TODO confirm
 * @param edgesThreshold  initial binarization level applied to 255 - edges
 * @return                mask of edges: pixels on edges have zero value,
 *                        all other pixels are 255
 */
cv::Mat extractStrongEdges(const cv::Mat &edges, uint8_t edgesThreshold)
{
    // Inverted response: low detector values become high fuse values.
    cv::Mat fuse = 255 - edges;
    cv::Mat binary_fuse;
    // binary_fuse = 255 where fuse > edgesThreshold, 0 elsewhere.
    cv::threshold(fuse, binary_fuse, edgesThreshold, 255, cv::THRESH_BINARY);
    // Accumulated result; it only ever grows (|= and a final dilation).
    cv::Mat strongEdges(edges.size(), CV_8UC1, cv::Scalar::all(0));
    cv::Mat kernel = cv::Mat::ones(3, 3, CV_8UC1);
    cv::Mat eroded;
    cv::Mat dilated;
    cv::Mat diff;
    // Upper bound on refinement rounds per threshold level.
    const int MAX_ITER_CNT = 20;
    for (int thresh = edgesThreshold; thresh < 256; thresh++) {
        for (int i = 0; i < MAX_ITER_CNT; ++i) {
            // Morphological opening of the current mask.
            cv::erode(binary_fuse, eroded, kernel);
            cv::dilate(eroded, dilated, kernel);
            // Pixels above the next level survive erosion unconditionally...
            eroded  |= (fuse > thresh+1);
            // ...while the opening is limited to pixels above the current level.
            dilated &= (fuse > thresh);

            // Mask pixels missing from the restricted opening become strong edges.
            strongEdges  |= (binary_fuse & ~dilated);
            // Strong edges always stay in the mask.
            eroded |= strongEdges;

            // Stop once the mask no longer changes; otherwise continue with it.
            absdiff(binary_fuse, eroded, diff);
            if (cv::countNonZero(diff) == 0) {
                break;
            } else {
                eroded.copyTo(binary_fuse);
            }
        }

        // Mask for the next level: pixels above it plus the strong edges.
        binary_fuse &= (fuse>thresh+1);
        binary_fuse |= strongEdges;
    }
    // Thicken the result by one pixel in every direction (3x3 kernel).
    cv::dilate(strongEdges, strongEdges, kernel);
    // Invert: edge pixels become 0, everything else 255.
    return 255 - strongEdges;
}

} // namespace

/// Returns the first vertex of @p polyline.
/// NOTE(review): assumes the polyline is non-empty.
geolib3::Point2 start(const geolib3::Polyline2& polyline) {
    const geolib3::Point2 first = polyline.pointAt(0);
    return first;
}

/// Returns the last vertex of @p polyline.
/// NOTE(review): assumes the polyline is non-empty.
geolib3::Point2 end(const geolib3::Polyline2& polyline) {
    const auto lastIndex = polyline.pointsNumber() - 1;
    return polyline.pointAt(lastIndex);
}

/// Which endpoints are joined by connectPolylines(): the first word names the
/// endpoint of the added line, the second the endpoint of the main line.
enum ConnectType {
    startToStart = 0,  ///< start of the added line meets start of the main line
    startToEnd = 1,    ///< start of the added line meets end of the main line
    endToStart = 2,    ///< end of the added line meets start of the main line
    endToEnd = 3       ///< end of the added line meets end of the main line
};

/**
 * @brief Joins @p addLine to the nearest end of @p mainLine.
 *
 * All four endpoint pairs are compared. On ties the earlier candidate in the
 * order startToStart, endToStart, startToEnd, endToEnd wins. If the closest
 * pair is nearer than @p epsilon, the points of @p addLine are prepended or
 * appended to the points of @p mainLine. They are oriented so that the joined
 * ends meet, and the touching endpoint of @p addLine is dropped. The result is
 * stored in @p newLine.
 *
 * @param mainLine  polyline that keeps all of its points
 * @param addLine   polyline attached to one end of @p mainLine
 * @param newLine   output; assigned only on success
 * @param epsilon   joined endpoints must be strictly closer than this
 * @return distance between the joined endpoints (>= 0) on success;
 *         -1 if the lines are too far apart or either of them is empty
 */
double
connectPolylines(const geolib3::Polyline2& mainLine,
                 const geolib3::Polyline2& addLine,
                 geolib3::Polyline2& newLine,
                 double epsilon) {
    // front()/back() below are undefined on an empty point list.
    if (mainLine.pointsNumber() == 0 || addLine.pointsNumber() == 0) {
        return -1;
    }

    std::vector<geolib3::Point2> pts(mainLine.pointsNumber());
    for (size_t i = 0; i < mainLine.pointsNumber(); i++) {
        pts[i] = mainLine.pointAt(i);
    }

    std::vector<geolib3::Point2> addPts(addLine.pointsNumber());
    for (size_t i = 0; i < addLine.pointsNumber(); i++) {
        addPts[i] = addLine.pointAt(i);
    }

    ConnectType connectType = ConnectType::startToStart;
    double minDist = geolib3::distance(pts.front(), addPts.front());

    // End of the added line vs. start of the main line.
    double distEndToStart = geolib3::distance(pts.front(), addPts.back());
    if (distEndToStart < minDist) {
        minDist = distEndToStart;
        connectType = ConnectType::endToStart;
    }

    // Start of the added line vs. end of the main line.
    double distStartToEnd = geolib3::distance(pts.back(), addPts.front());
    if (distStartToEnd < minDist) {
        minDist = distStartToEnd;
        connectType = ConnectType::startToEnd;
    }

    double distEndToEnd = geolib3::distance(pts.back(), addPts.back());
    if (distEndToEnd < minDist) {
        minDist = distEndToEnd;
        connectType = ConnectType::endToEnd;
    }

    if (minDist >= epsilon) {
        return -1;
    }

    // In every case the endpoint of addLine that touches mainLine is skipped.
    switch (connectType) {
    case ConnectType::startToStart:
        pts.insert(pts.begin(), addPts.rbegin(), addPts.rend() - 1);
        break;
    case ConnectType::startToEnd:
        pts.insert(pts.end(), addPts.begin() + 1, addPts.end());
        break;
    case ConnectType::endToStart:
        pts.insert(pts.begin(), addPts.begin(), addPts.end() - 1);
        break;
    case ConnectType::endToEnd:
        pts.insert(pts.end(), addPts.rbegin() + 1, addPts.rend());
        break;
    }

    newLine = geolib3::Polyline2(pts);
    return minDist;
}

/**
 * @brief Greedily traces a chain of zero-valued pixels starting at @p start.
 *
 * The current pixel is appended to the path and marked as visited by setting
 * it to 255 in @p image. The walk then moves to the nearest still-zero pixel
 * in the square window [col - r, col + r] x [row - r, row + r], where
 * r = (int)eps, and stops when the window contains none. On equal distances
 * the candidate with the smallest x, then the smallest y, wins.
 *
 * @param image  in/out single-channel 8-bit image; visited pixels are set to 255
 * @param start  first pixel, (x, y) = (column, row); assumed to lie inside the image
 * @param eps    search radius in pixels, truncated to an integer
 * @return       traced polyline. If no neighbour is found at all, it consists of
 *               two copies of @p start (presumably so that the polyline always
 *               has at least two points).
 */
geolib3::Polyline2 findPath(cv::Mat* image,
                            const geolib3::Point2& start,
                            double eps) {
    // Bounds are inclusive on both sides so the window is symmetric around the
    // current pixel. An exclusive upper bound of col + eps would reach one pixel
    // less to the right/bottom than to the left/top.
    const int radius = static_cast<int>(eps);

    std::vector<geolib3::Point2> pts;
    geolib3::Point2 cur = start;
    while (true) {
        pts.push_back(cur);
        const int col = static_cast<int>(cur.x());
        const int row = static_cast<int>(cur.y());
        image->at<uchar>(row, col) = 255;

        const int startX = std::max(0, col - radius);
        const int endX = std::min(image->cols, col + radius + 1);
        const int startY = std::max(0, row - radius);
        const int endY = std::min(image->rows, row + radius + 1);

        int closeX = -1;
        int closeY = -1;
        double minDist = std::numeric_limits<double>::max();
        for (int x = startX; x < endX; x++) {
            for (int y = startY; y < endY; y++) {
                if (image->at<uchar>(y, x) == 0) {
                    const double dist = geolib3::distance(geolib3::Point2(x, y), cur);
                    if (dist < minDist) {
                        minDist = dist;
                        closeX = x;
                        closeY = y;
                    }
                }
            }
        }

        if (closeX == -1) {
            break;  // no unvisited pixel nearby: the path ends here
        }
        // The next iteration appends this pixel and marks it as visited.
        cur = geolib3::Point2(closeX, closeY);
    }

    if (pts.size() == 1) {
        pts.push_back(start);
    }

    return geolib3::Polyline2(pts);
}

/**
 * @brief Converts a skeleton image into a set of polylines.
 *
 * 1. Every zero pixel of @p image, scanned column by column, seeds a greedy
 *    trace with findPath(). Traced pixels are marked in a private copy.
 * 2. Traces shorter than MIN_TRACE_LENGTH are dropped.
 * 3. Going from the last line to the first, each line is merged into the
 *    earlier line with the closest endpoint (closer than CONNECT_DISTANCE)
 *    and removed from the list.
 * 4. Lines shorter than MIN_LINE_LENGTH are dropped. The rest are passed
 *    through geolib3::unique and geolib3::simplify(SIMPLIFY_TOLERANCE).
 *
 * @param image  single-channel 8-bit image whose skeleton pixels are 0
 * @return       polylines in pixel coordinates (x = column, y = row)
 */
std::vector<geolib3::Polyline2>
extractPolylines(const cv::Mat& image) {
    constexpr double TRACE_RADIUS = 2.;
    constexpr double MIN_TRACE_LENGTH = 5.;
    constexpr double CONNECT_DISTANCE = 10.;
    constexpr double MIN_LINE_LENGTH = 4.;
    constexpr double SIMPLIFY_TOLERANCE = 5.;

    // findPath() overwrites traced pixels, so trace on a copy.
    cv::Mat tmpImage;
    image.copyTo(tmpImage);

    std::vector<geolib3::Polyline2> lines;
    for (int col = 0; col < image.cols; col++) {
        for (int row = 0; row < image.rows; row++) {
            // NOTE(review): seeds are read from the original image, so a pixel
            // already consumed by an earlier trace can still seed a new one
            // (e.g. a branch starting at a junction). Seeds without unvisited
            // neighbours yield zero-length traces, which step 2 removes.
            if (image.at<uchar>(row, col) == 0) {
                lines.push_back(findPath(&tmpImage, geolib3::Point2(col, row), TRACE_RADIUS));
            }
        }
    }

    lines.erase(
        std::remove_if(lines.begin(), lines.end(),
                       [&](const geolib3::Polyline2& line) {
                           return geolib3::length(line) < MIN_TRACE_LENGTH;
                       }),
        lines.end());

    for (int i = static_cast<int>(lines.size()) - 1; i >= 0; i--) {
        int connectedIndex = -1;
        geolib3::Polyline2 connectedLine;
        double minDist = std::numeric_limits<double>::max();
        for (int j = 0; j < i; j++) {
            geolib3::Polyline2 newLine;
            const double dist = connectPolylines(lines[i], lines[j], newLine, CONNECT_DISTANCE);
            // connectPolylines() returns -1 when the lines cannot be joined.
            // A distance of 0 (coincident endpoints) is valid and the best possible.
            if (dist >= 0 && dist < minDist) {
                minDist = dist;
                connectedLine = std::move(newLine);
                connectedIndex = j;
            }
        }

        if (connectedIndex != -1) {
            lines[connectedIndex] = std::move(connectedLine);
            lines.erase(lines.begin() + i);
        }
    }

    lines.erase(
        std::remove_if(lines.begin(), lines.end(),
                       [&](const geolib3::Polyline2& line) {
                           return geolib3::length(line) < MIN_LINE_LENGTH;
                       }),
        lines.end());
    for (geolib3::Polyline2& line : lines) {
        line = geolib3::unique(line);
        line = geolib3::simplify(line, SIMPLIFY_TOLERANCE);
    }
    return lines;
}

/**
 * @brief Detects roads on @p image and returns them as polylines.
 *
 * The pipeline is: network response -> strong-edge mask -> Zhang-Suen
 * skeleton -> polyline extraction.
 *
 * @param roadDetector  loaded road segmentation model
 * @param image         input image
 * @return              road polylines in pixel coordinates of @p image
 */
std::vector<geolib3::Polyline2>
detectRoad(const tf_inferencer::TensorFlowInferencer &roadDetector,
           const cv::Mat &image)
{
    constexpr uint8_t EDGES_THRESHOLD = 135;

    const cv::Mat edges = inferenceRoadDetector(roadDetector, image);
    const cv::Mat strongEdges = extractStrongEdges(edges, EDGES_THRESHOLD);

    // thinningZhangSuen() works in place; keep strongEdges intact for the dump below.
    cv::Mat thinningEdges = strongEdges.clone();
    thinningZhangSuen(&thinningEdges);

    // NOTE(review): these look like debug dumps. Every call writes three
    // images into the current working directory.
    cv::imwrite("edges.jpg", edges);
    cv::imwrite("strong_edges.jpg", strongEdges);
    cv::imwrite("thinning_edges.jpg", thinningEdges);

    return extractPolylines(thinningEdges);
}

} //namespace autocart
} //namespace wiki
} //namespace maps
