#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/detect_bld.h>
#include <maps/wikimap/mapspro/services/autocart/libs/post_processing/include/polygon_regularization.h>

#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/graph_pgn_extract.h>
#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/pp_pgns.h>

#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/instance_segmentation.h>

#include <maps/libs/common/include/exception.h>

#include <algorithm>
#include <iterator>
#include <utility>

namespace maps {
namespace wiki {
namespace autocart {

namespace {
// Semantic-segmentation label value used to build the building mask in this
// file: callers compare semSegmImage() output against it
// (labels == SEMSEGM_BLD_LABEL).
// NOTE(review): presumably the "building" class index of the segmentation
// model; confirm against the model's class list.
constexpr uint8_t SEMSEGM_BLD_LABEL = 1;

/**
 * @brief One sub-iteration of Zhang-Suen thinning, applied in place.
 * @param data  single-channel 8-bit image with foreground pixels equal to 1
 *              and background pixels equal to 0 (as prepared by
 *              thinningZhangSuen)
 * @param iter  0 selects the first sub-iteration, any other value the second;
 *              it decides which neighbour products must vanish
 *
 * The outermost rows and columns are never modified. Pixels to delete are
 * first flagged in a separate image and cleared only at the end, so every
 * decision in a pass is based on the image as it was when the pass started.
 */
void thinningIteration(cv::Mat *data, int iter)
{
    ASSERT(data);
    cv::Mat toRemove = cv::Mat::zeros(data->size(), CV_8UC1);

    for (int row = 1; row + 1 < data->rows; ++row) {
        const uint8_t *above = data->ptr(row - 1);
        const uint8_t *here = data->ptr(row);
        const uint8_t *below = data->ptr(row + 1);
        uint8_t *flags = toRemove.ptr(row);

        for (int col = 1; col + 1 < data->cols; ++col) {
            // 8-neighbourhood walked clockwise starting from the pixel above:
            // N, NE, E, SE, S, SW, W, NW (P2..P9 in Zhang-Suen notation).
            const uint8_t ring[8] = {
                above[col],  above[col + 1], here[col + 1], below[col + 1],
                below[col],  below[col - 1], here[col - 1], above[col - 1]
            };

            // transitions: number of 0 -> 1 steps around the closed ring;
            // neighbours: sum of the neighbour values.
            int transitions = 0;
            int neighbours = 0;
            for (int k = 0; k < 8; ++k) {
                transitions += (ring[k] == 0 && ring[(k + 1) % 8] == 1);
                neighbours += ring[k];
            }

            const int north = ring[0];
            const int east = ring[2];
            const int south = ring[4];
            const int west = ring[6];
            const int firstProduct = (iter == 0) ? north * east * south
                                                 : north * east * west;
            const int secondProduct = (iter == 0) ? east * south * west
                                                  : north * south * west;

            const bool removable = transitions == 1
                && neighbours >= 2 && neighbours <= 6
                && firstProduct == 0 && secondProduct == 0;
            if (removable) {
                flags[col] = 1;
            }
        }
    }

    *data &= ~toRemove;
}

/**
 * @brief Zhang-Suen thinning of the low-valued regions of an 8-bit image, in place.
 * @param data       single-channel 8-bit image; overwritten with the result
 * @param threshold  pixels with value <= threshold form the foreground that
 *                   is thinned (cv::THRESH_BINARY_INV maps them to 1)
 *
 * On return *data holds 0 on the thinned foreground and 255 everywhere else
 * (the result of comparing the thinned 0/1 image with 0).
 */
void thinningZhangSuen(cv::Mat *data, uint8_t threshold)
{
    ASSERT(data);
    cv::Mat skeleton;
    cv::threshold(*data, skeleton, threshold, 1, cv::THRESH_BINARY_INV);

    cv::Mat previous = cv::Mat::zeros(skeleton.size(), CV_8UC1);
    cv::Mat delta;
    bool changed = true;
    // Run both sub-iterations until a full pass leaves the image unchanged.
    while (changed) {
        thinningIteration(&skeleton, 0);
        thinningIteration(&skeleton, 1);
        cv::absdiff(skeleton, previous, delta);
        changed = cv::countNonZero(delta) > 0;
        skeleton.copyTo(previous);
    }

    *data = (skeleton == 0);
}

/**
 * @brief Runs the edge-detection network on an image.
 * @param edgeDetector  inferencer whose graph has an "inference_input" input
 *                      and an "inference_sigmoid" output layer
 * @param image         image fed to the network
 * @return CV_8UC1 map of the same size as image: the CV_32FC1 network output
 *         multiplied by 255 (convertTo rounds and saturates)
 * @throws if the network output is not CV_32FC1 or differs in size from image
 */
cv::Mat inferenceEdgeDetector(const tf_inferencer::TensorFlowInferencer &edgeDetector,
                              const cv::Mat &image)
{
    static const TString INPUT_LAYER_NAME = "inference_input";
    static const TString OUTPUT_LAYER_NAME = "inference_sigmoid";

    cv::Mat edges = tf_inferencer::tensorToImage(
        edgeDetector.inference(INPUT_LAYER_NAME, image, OUTPUT_LAYER_NAME));

    REQUIRE(edges.type() == CV_32FC1 && edges.size() == image.size(),
            "Inavalid type or sizes of edges detection results");

    cv::Mat scaled;
    edges.convertTo(scaled, CV_8UC1, 255.);
    return scaled;
}


/**
 * @brief Converts rotated rectangles into 4-vertex polygons.
 * @param objects  rotated rectangles
 * @return one polygon per rectangle, in input order; each polygon holds the
 *         four corners written by cv::RotatedRect::points
 */
std::vector<Polygon> convertToPolygons(const std::vector<cv::RotatedRect> &objects)
{
    std::vector<Polygon> polygons;
    polygons.reserve(objects.size());
    for (const cv::RotatedRect &rect : objects) {
        polygons.emplace_back(4);
        rect.points(polygons.back().data());
    }
    return polygons;
}


/**
 * @brief Detects strong edges on a satellite image.
 * @param edgeDetector  edge-detection network (see inferenceEdgeDetector)
 * @param image         satellite image
 * @return CV_8UC1 mask of the same size as image; edge pixels have value 0,
 *         all other pixels 255
 */
cv::Mat extractStrongEdges(const tf_inferencer::TensorFlowInferencer &edgeDetector,
                           const cv::Mat &image)
{
    constexpr uint8_t EDGES_THRESHOLD = 135;
    constexpr int MAX_ITER_CNT = 20;

    cv::Mat fuse = inferenceEdgeDetector(edgeDetector, image);
    // Invert the network response: the thresholds below then select pixels
    // whose original response is low.
    fuse = 255 - fuse;

    cv::Mat binary_fuse;
    cv::threshold(fuse, binary_fuse, EDGES_THRESHOLD, 255, cv::THRESH_BINARY);
    cv::Mat edges(image.size(), CV_8UC1, cv::Scalar::all(0));
    cv::Mat kernel = cv::Mat::ones(3, 3, CV_8UC1);
    cv::Mat eroded;
    cv::Mat dilated;
    cv::Mat diff;
    for (int thresh = EDGES_THRESHOLD; thresh < 256; thresh++) {
        // fuse is not modified inside this loop, so both level masks depend
        // only on thresh: build them once per threshold rather than on every
        // inner pass.
        const cv::Mat aboveNext = (fuse > thresh + 1);
        const cv::Mat aboveCurrent = (fuse > thresh);

        // Repeatedly erode/dilate binary_fuse; pixels of binary_fuse that are
        // missing from the (masked) dilation accumulate into edges. Stop when
        // the eroded image equals binary_fuse or after MAX_ITER_CNT passes.
        for (int i = 0; i < MAX_ITER_CNT; ++i) {
            cv::erode(binary_fuse, eroded, kernel);
            cv::dilate(eroded, dilated, kernel);
            eroded  |= aboveNext;
            dilated &= aboveCurrent;

            edges  |= (binary_fuse & ~dilated);
            eroded |= edges;

            cv::absdiff(binary_fuse, eroded, diff);
            if (cv::countNonZero(diff) == 0) {
                break;
            }
            eroded.copyTo(binary_fuse);
        }

        // Next level: keep pixels above it plus every edge found so far.
        binary_fuse &= aboveNext;
        binary_fuse |= edges;
    }
    cv::dilate(edges, edges, kernel);
    return 255 - edges;
}

/**
 * @brief Builds convex-hull polygons from the regions separated by edges.
 * @param edges  8-bit mask; zero pixels separate regions (extractStrongEdges
 *               produces 0 on edges)
 * @return convex hulls of the 4-connected non-zero components whose hull area
 *         lies in [MIN_AREA, MAX_AREA], ordered by component label
 */
std::vector<Polygon> extractPolygons(const cv::Mat &edges)
{
    constexpr float MIN_AREA = 200.f;
    constexpr float MAX_AREA = 40000.f;

    cv::Mat rasterMarkup;
    const int componentsCnt = cv::connectedComponents(edges, rasterMarkup, 4);

    // Label 0 covers the zero-valued pixels of edges. It is never turned into
    // a polygon, so its (typically numerous) pixels are not collected at all.
    std::vector<std::vector<cv::Point2f>> blobs(componentsCnt);
    for (int row = 0; row < rasterMarkup.rows; row++) {
        const int *ptr = rasterMarkup.ptr<int>(row);
        for (int col = 0; col < rasterMarkup.cols; col++) {
            const int label = ptr[col];
            if (label > 0) {
                blobs[label].emplace_back(col, row);
            }
        }
    }

    std::vector<Polygon> polygons;
    for (size_t i = 1; i < blobs.size(); ++i) {
        Polygon hull;
        cv::convexHull(blobs[i], hull, false);
        const float area = cv::contourArea(hull);
        if (MIN_AREA <= area && area <= MAX_AREA) {
            polygons.push_back(std::move(hull));
        }
    }

    return polygons;
}

} // namespace

/**
 * @brief Runs the semantic segmentation network and returns per-pixel labels.
 * @param segmentator  inferencer whose graph has an "inference_input" input
 *                     and an "inference_argmax" output layer
 * @param image        image fed to the network
 * @return CV_8UC1 image of the same size as image holding each pixel's argmax
 *         value, truncated to 8 bits
 * @throws if the network output is not CV_64FC1 or differs in size from image
 */
cv::Mat semSegmImage(const tf_inferencer::TensorFlowInferencer& segmentator,
                     const cv::Mat &image)
{
    static const TString INPUT_LAYER_NAME = "inference_input";
    static const TString OUTPUT_ARGMAX_LAYER_NAME = "inference_argmax";

    cv::Mat temp = tf_inferencer::tensorToImage(
        segmentator.inference(INPUT_LAYER_NAME, image, OUTPUT_ARGMAX_LAYER_NAME));

    // The buffer is typed CV_64FC1 but its elements are read as int64_t below.
    // NOTE(review): presumably tensorToImage keeps the raw int64 argmax data
    // in a CV_64F container because OpenCV lacks a 64-bit integer depth; confirm.
    REQUIRE(temp.type() == CV_64FC1 && temp.size() == image.size(),
            "Inavalid type or sizes of semantic segmentation results");

    cv::Mat labels(temp.size(), CV_8UC1);
    for (int row = 0; row < labels.rows; ++row) {
        const int64_t *srcRow = temp.ptr<int64_t>(row);
        std::transform(srcRow, srcRow + labels.cols, labels.ptr<uint8_t>(row),
                       [](int64_t label) { return static_cast<uint8_t>(label); });
    }
    return labels;
}


/**
 * @brief Runs the edge detector and thins its response.
 * @param edgeDetector  edge-detection network (see inferenceEdgeDetector)
 * @param image         image fed to the network
 * @return thinningZhangSuen result for the 8-bit response with threshold 135:
 *         0 on the thinned pixels whose response is <= 135, 255 elsewhere
 */
cv::Mat detectEdges(const tf_inferencer::TensorFlowInferencer& edgeDetector,
                    const cv::Mat &image)
{
    constexpr uint8_t EDGES_THRESHOLD = 135;
    cv::Mat thinned = inferenceEdgeDetector(edgeDetector, image);
    thinningZhangSuen(&thinned, EDGES_THRESHOLD);
    return thinned;
}


/**
 * @brief Splits a binary mask into separate regions with the watershed transform.
 *
 * The mask is first opened (3x3 kernel, 2 iterations). The dilation of the
 * opened mask (3 iterations) bounds the "sure background". The distance
 * transform of the opened mask is then normalised to [0, 255] and
 * Otsu-thresholded to get the seeds. When edges is given, the seeds are
 * further restricted to the non-zero pixels of edges. Finally the seeds are
 * labelled with 4-connectivity and grown over image by cv::watershed.
 *
 * @param image         image passed to cv::watershed (OpenCV expects 8-bit 3-channel)
 * @param mask          8-bit binary mask of the regions to split
 * @param edges         optional 8-bit mask; seeds survive only where it is
 *                      non-zero. Pass an empty Mat to skip this step
 * @param rasterMarkup  output markup from connectedComponents, shifted by one
 *                      (background = 1, seeds = 2..N) and then processed by
 *                      watershed (-1 on boundaries). Left untouched when 0 is
 *                      returned
 * @return N, the label count reported by cv::connectedComponents (background
 *         included); 0 when the distance transform is flat
 */
int extractRaster(const cv::Mat &image,
                  const cv::Mat &mask,
                  const cv::Mat &edges,
                  cv::Mat &rasterMarkup)
{
    const cv::Mat kernel = cv::Mat::ones(3, 3, CV_8UC1);
    const cv::Point centerAnchor(-1, -1);

    cv::Mat opened;
    cv::morphologyEx(mask, opened, cv::MORPH_OPEN, kernel, centerAnchor, /*iterations=*/2);

    cv::Mat sureBG;
    cv::dilate(opened, sureBG, kernel, centerAnchor, /*iterations=*/3);

    cv::Mat distance;
    cv::distanceTransform(opened, distance, cv::DIST_L2, 5, CV_32F);

    double minDist = 0.;
    double maxDist = 0.;
    cv::minMaxLoc(distance, &minDist, &maxDist);
    const double range = maxDist - minDist;
    // A flat distance map has nothing to threshold: report no components.
    if (range < 1e-3) {
        return 0;
    }
    distance = (distance - static_cast<float>(minDist)) / static_cast<float>(range) * 255.f;

    cv::Mat sureFG;
    distance.convertTo(sureFG, CV_8U);
    // With THRESH_OTSU the threshold argument is ignored and Otsu picks it.
    cv::threshold(sureFG, sureFG, 1., 255, cv::THRESH_OTSU | cv::THRESH_BINARY);

    if (!edges.empty()) {
        // A masked copyTo into an empty Mat zero-fills the new buffer first,
        // so seed pixels where edges == 0 end up cleared.
        cv::Mat restricted;
        sureFG.copyTo(restricted, edges);
        sureFG = restricted;
    }

    const cv::Mat unknown = sureBG - sureFG;
    const int componentsCnt = cv::connectedComponents(sureFG, rasterMarkup, 4);
    // Shift labels so that 0 is free to mark the "unknown" band that
    // watershed must fill.
    rasterMarkup = rasterMarkup + 1;
    rasterMarkup.setTo(0, unknown == 255);
    cv::watershed(image, rasterMarkup);
    return componentsCnt;
}

/**
 * @brief Fits minimum-area rotated rectangles to the regions of a markup.
 * @param rasterMarkup   CV_32SC1 markup as produced by extractRaster: negative
 *                       values on boundaries, 1 for the background and
 *                       2..componentsCnt for the regions
 * @param componentsCnt  value returned by extractRaster
 * @return one rectangle per region with at least MIN_BLOB_PTS pixels and a
 *         rectangle area of at least MIN_RECT_AREA, ordered by label
 * @throws if rasterMarkup is not CV_32SC1 or holds a label greater than
 *         componentsCnt
 */
std::vector<cv::RotatedRect>
extractRectangles(const cv::Mat &rasterMarkup,
                  int componentsCnt)
{
    constexpr size_t MIN_BLOB_PTS = 200;
    constexpr int MIN_RECT_AREA = 400;

    std::vector<cv::RotatedRect> result;
    // extractRaster returns 0 when it found nothing. A negative count would
    // make the blobs vector below invalid, so treat it the same way.
    if (componentsCnt <= 0)
        return result;

    REQUIRE(rasterMarkup.type() == CV_32SC1, "Markup data should be 32SC1 type");

    // blobs[label] collects the pixels of every label 0..componentsCnt.
    std::vector<std::vector<cv::Point>> blobs(componentsCnt + 1);
    for (int row = 0; row < rasterMarkup.rows; row++) {
        const int *ptr = rasterMarkup.ptr<int>(row);
        for (int col = 0; col < rasterMarkup.cols; col++) {
            const int label = ptr[col];
            if (label < 0)  // boundary pixels belong to no region
                continue;
            // Guard the indexing below: a label inconsistent with
            // componentsCnt would otherwise write past the end of blobs.
            REQUIRE(label <= componentsCnt, "Markup label exceeds components count");
            blobs[label].emplace_back(col, row);
        }
    }

    // Labels 0 (unassigned) and 1 (background) are skipped.
    for (size_t label = 2; label < blobs.size(); label++) {
        const std::vector<cv::Point> &blob = blobs[label];
        if (blob.size() < MIN_BLOB_PTS)
            continue;
        const cv::RotatedRect rect = cv::minAreaRect(blob);
        if (rect.size.area() < MIN_RECT_AREA)
            continue;
        result.push_back(rect);
    }
    return result;
}

/**
 * @brief Detects buildings as minimum-area rotated rectangles.
 *
 * extractRaster splits the building mask from semantic segmentation into
 * regions, with seeds restricted by the detectEdges output. A rotated
 * rectangle is then fitted to every large enough region.
 *
 * @param segmentator   semantic segmentation network
 * @param edgeDetector  edge-detection network
 * @param image         satellite image
 * @return one 4-vertex polygon per detected building
 */
std::vector<Polygon>
detectBldByMinAreaRect(const tf_inferencer::TensorFlowInferencer& segmentator,
                       const tf_inferencer::TensorFlowInferencer& edgeDetector,
                       const cv::Mat &image)
{
    const cv::Mat bldMask = (semSegmImage(segmentator, image) == SEMSEGM_BLD_LABEL);
    const cv::Mat edges = detectEdges(edgeDetector, image);

    cv::Mat markup;
    const int regionsCnt = extractRaster(image, bldMask, edges, markup);
    return convertToPolygons(extractRectangles(markup, regionsCnt));
}

/**
 * @brief Detects buildings as convex hulls of the regions enclosed by strong edges.
 * @param edgeDetector  edge-detection network
 * @param image         satellite image
 * @return hull polygons that pass the area filter of extractPolygons
 */
std::vector<Polygon>
detectBldByEdges(const tf_inferencer::TensorFlowInferencer &edgeDetector,
                 const cv::Mat &image)
{
    cv::Mat strongEdges = extractStrongEdges(edgeDetector, image);
    return extractPolygons(strongEdges);
}

/**
 * @brief Detects buildings from a vertices/edges network and semantic segmentation.
 *
 * The edge detector graph is run once to fetch its "inference_edges" and
 * "inference_vertices" outputs. Graph-based polygon extraction combines them
 * with the building mask, and the polygons are then post-processed.
 *
 * @param segmentator   semantic segmentation network
 * @param edgeDetector  network with "inference_edges" and "inference_vertices" outputs
 * @param image         satellite image
 * @return extracted building polygons
 * @throws if the network returns fewer tensors than the two requested
 */
std::vector<Polygon>
detectBldByVertsEdges(const tf_inferencer::TensorFlowInferencer& segmentator,
                      const tf_inferencer::TensorFlowInferencer& edgeDetector,
                      const cv::Mat &image)
{
    static const TString INPUT_LAYER_NAME = "inference_input";
    static const TString OUTPUT_EDGES_LAYER_NAME = "inference_edges";
    static const TString OUTPUT_VERTS_LAYER_NAME = "inference_vertices";

    cv::Mat segm = (semSegmImage(segmentator, image) == SEMSEGM_BLD_LABEL);

    std::vector<tensorflow::Tensor> results = edgeDetector.inference(
        INPUT_LAYER_NAME,
        image,
        std::vector<std::string>({OUTPUT_EDGES_LAYER_NAME, OUTPUT_VERTS_LAYER_NAME}));
    // Both tensors are indexed below. Fail loudly instead of reading past the
    // end of the vector if the inferencer returned fewer outputs.
    REQUIRE(results.size() >= 2,
            "Too few edges/vertices detection results");
    cv::Mat edges = tf_inferencer::tensorToImage(results[0]);
    cv::Mat verts = tf_inferencer::tensorToImage(results[1]);

    static const maps::wiki::autocart::ExtractPolygonsParams ep_params;
    static const maps::wiki::autocart::PPPolygonsParams pp_params;

    std::vector<std::vector<cv::Point>> polygons =
        maps::wiki::autocart::extractPolygons(verts, edges, segm, ep_params);
    maps::wiki::autocart::postprocessPolygons(pp_params, polygons);

    // Convert the integer vertices into Polygon points.
    std::vector<Polygon> result(polygons.size());
    for (size_t i = 0; i < polygons.size(); i++) {
        result[i].resize(polygons[i].size());
        std::copy(polygons[i].begin(), polygons[i].end(), result[i].begin());
    }
    return result;
}

/**
 * @brief Extracts outer contours of building regions.
 *
 * extractRaster splits the building mask from semantic segmentation into
 * regions (without an edge map). For each region label, the top-level
 * contours with area >= MIN_BLD_AREA are returned.
 *
 * @param segmentator  semantic segmentation network
 * @param image        satellite image
 * @return building contours with float coordinates
 */
std::vector<std::vector<cv::Point2f>>
extractBldContours(const tf_inferencer::TensorFlowInferencer& segmentator,
                   const cv::Mat &image)
{
    constexpr double MIN_BLD_AREA = 400.;
    std::vector<std::vector<cv::Point2f>> bldContours;

    cv::Mat labels = semSegmImage(segmentator, image);
    cv::Mat rasterMarkup;
    int componentsCnt = extractRaster(image,
                                      (labels == SEMSEGM_BLD_LABEL),
                                      cv::Mat(),
                                      rasterMarkup);
    // In extractRaster's markup label 1 is the background; regions are
    // labelled 2..componentsCnt.
    for (int i = 2; i <= componentsCnt; i++) {
        cv::Mat blobImage = (rasterMarkup == i);
        std::vector<std::vector<cv::Point>> contours;
        std::vector<cv::Vec4i> hierarchy;
        // C++ API enums (same values as the legacy CV_RETR_TREE /
        // CV_CHAIN_APPROX_SIMPLE C macros, which OpenCV 4 no longer exposes
        // without the C headers).
        cv::findContours(blobImage, contours, hierarchy,
                         cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);
        for (size_t j = 0; j < contours.size(); j++) {
            // hierarchy[j][3] is the parent index: negative means not nested.
            const bool isOuter = hierarchy[j][3] < 0;
            if (isOuter && cv::contourArea(contours[j]) >= MIN_BLD_AREA) {
                std::vector<cv::Point2f> bldContour;
                bldContour.reserve(contours[j].size());
                for (const auto& point : contours[j]) {
                    bldContour.emplace_back(point.x, point.y);
                }
                bldContours.push_back(std::move(bldContour));
            }
        }
    }
    return bldContours;
}

/**
 * @brief Detects buildings and regularizes their contours with gridRegularizePolygon.
 * @param segmentator  semantic segmentation network
 * @param image        satellite image
 * @return one regularized polygon per contour returned by extractBldContours,
 *         in the same order
 */
std::vector<Polygon>
detectComplexBldByGrid(const tf_inferencer::TensorFlowInferencer& segmentator,
                       const cv::Mat &image)
{
    constexpr double TOLERANCE = 12.;
    constexpr double GRID_SIDE = 3.;

    const std::vector<std::vector<cv::Point2f>> contours = extractBldContours(segmentator,
                                                                              image);
    std::vector<Polygon> blds;
    blds.reserve(contours.size());
    for (const auto& contour : contours) {
        // Append the returned polygon directly (moved) instead of copying a
        // named local into the vector.
        blds.push_back(gridRegularizePolygon(contour, TOLERANCE, GRID_SIDE));
    }
    return blds;
}

/**
 * @brief Detects buildings and regularizes their contours with projectionRegularizePolygon.
 * @param segmentator  semantic segmentation network
 * @param image        satellite image
 * @return one regularized polygon per contour returned by extractBldContours,
 *         in the same order
 */
std::vector<Polygon>
detectComplexBldByProjection(const tf_inferencer::TensorFlowInferencer& segmentator,
                             const cv::Mat &image)
{
    constexpr double MIN_WALL_LENGTH = 15.;

    const std::vector<std::vector<cv::Point2f>> contours = extractBldContours(segmentator,
                                                                              image);
    std::vector<Polygon> blds;
    blds.reserve(contours.size());
    for (const auto& contour : contours) {
        // Append the returned polygon directly (moved) instead of copying a
        // named local into the vector.
        blds.push_back(projectionRegularizePolygon(contour, MIN_WALL_LENGTH));
    }
    return blds;
}

/**
 * @brief Detects buildings with Mask R-CNN instance segmentation.
 * @param maskrcnn   Mask R-CNN inferencer
 * @param image      satellite image
 * @param batchSize  batch size forwarded to segmentInstances
 * @return polygons from vectorizeInstances whose area is not below
 *         MIN_BLD_AREA, in their original order
 */
std::vector<Polygon>
detectComplexBldByMaskRCNN(const tf_inferencer::MaskRCNNInferencer& maskrcnn,
                           const cv::Mat &image,
                           size_t batchSize)
{
    constexpr double MIN_BLD_AREA = 100.;

    std::vector<BldInstance> instances = segmentInstances(maskrcnn, image, batchSize);
    std::vector<Polygon> candidates = vectorizeInstances(instances);

    const auto isLargeEnough = [&](const Polygon& polygon) {
        return !(contourArea(polygon) < MIN_BLD_AREA);
    };

    std::vector<Polygon> blds;
    blds.reserve(candidates.size());
    for (Polygon& candidate : candidates) {
        if (isLargeEnough(candidate)) {
            blds.push_back(std::move(candidate));
        }
    }
    return blds;
}

} //namespace autocart
} //namespace wiki
} //namespace maps
