#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/instance_segmentation.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/wikimap/mapspro/services/autocart/libs/post_processing/include/polygon_regularization.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <maps/libs/geolib/include/bounding_box.h>
#include <maps/libs/geolib/include/static_geometry_searcher.h>
#include <maps/libs/geolib/include/spatial_relation.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <optional>
#include <stack>
#include <string>
#include <utility>
#include <vector>

namespace tf = tensorflow;

namespace maps {
namespace wiki {
namespace autocart {

namespace {

/**
 * @brief Converts an OpenCV rectangle to a geolib3 bounding box.
 * @param rect - rectangle in pixel coordinates
 * @return bounding box with corners (x, y) and (x + width, y + height)
 */
geolib3::BoundingBox cvRectToBbox(const cv::Rect& rect) {
    const geolib3::Point2 lower(rect.x, rect.y);
    const geolib3::Point2 upper(rect.x + rect.width, rect.y + rect.height);
    return geolib3::BoundingBox(lower, upper);
}

/**
 * @brief Segment different building instances on images.
 *     Using Mask R-CNN for instance segmentation:
 *     https://arxiv.org/abs/1703.06870
 * @param maskrcnn - Mask R-CNN inferencer used to run the model
 * @param images   - batch of images fed to the network input layer
 * @return one vector of building instances per inference result; each
 *     instance holds the detected bounding box and a binary 8-bit mask
 *     obtained by comparing the predicted mask with MASK_THRESHOLD.
 *     NOTE(review): results are presumably in the same order as `images`
 *     (segmentInstances relies on this) — confirm with the inferencer.
 */
std::vector<std::vector<BldInstance>>
inferenceMaskRCNN(const tf_inferencer::MaskRCNNInferencer& maskrcnn,
                  const std::vector<cv::Mat>& images) {
    // Passed to the inferencer; presumably detections scoring below it are dropped.
    const float SCORE_THRESHOLD = 0.5f;
    // Pixels whose predicted mask value exceeds this threshold form the binary mask.
    const float MASK_THRESHOLD  = 0.5f;

    const std::vector<tf_inferencer::MaskRCNNResults> resultsBatches
        = maskrcnn.inference(images, SCORE_THRESHOLD);

    std::vector<std::vector<BldInstance>> instancesBatches;
    instancesBatches.reserve(resultsBatches.size());
    for (const tf_inferencer::MaskRCNNResults& results : resultsBatches) {
        std::vector<BldInstance> instances;
        instances.reserve(results.size());
        for (const tf_inferencer::MaskRCNNResult& result : results) {
            instances.emplace_back(
                cvRectToBbox(result.bbox), result.mask > MASK_THRESHOLD);
        }
        // Move the per-image vector instead of copying it.
        instancesBatches.push_back(std::move(instances));
    }
    return instancesBatches;
}

/**
 * @brief Adds black (zero) pixels to the right and bottom of an image so that
 *     it can be covered exactly by overlapping cells of the neural network
 *     input size.
 * @param image    - source image; the padded image has the same type
 * @param cellSize - size of one cell (size of the neural network input)
 * @param overlap  - overlap of two neighbouring cells in pixels;
 *     must be smaller than both cell dimensions
 * @return padded image whose top-left part is a copy of the source and whose
 *     dimensions are:
 *     width  = cellSize.width  + (cellSize.width  - overlap) * Kx
 *     height = cellSize.height + (cellSize.height - overlap) * Ky
 *     where Kx, Ky are the smallest non-negative integers for which the
 *     padded image is not smaller than the source
 */
cv::Mat padImage(const cv::Mat& image, const cv::Size& cellSize, int overlap) {
    REQUIRE(cellSize.width > overlap && cellSize.height > overlap,
            "Overlap should be smaller than cell size");

    // Smallest length of the form cellLength + step * K (K >= 0, step is the
    // non-overlapping part of a cell) that is not smaller than imageLength.
    const auto paddedLength = [overlap](int imageLength, int cellLength) {
        const int step = cellLength - overlap;
        if (imageLength <= cellLength) {
            return cellLength;
        }
        const int extraCellsCnt = (imageLength - cellLength + step - 1) / step;
        return cellLength + step * extraCellsCnt;
    };
    const int paddedWidth = paddedLength(image.cols, cellSize.width);
    const int paddedHeight = paddedLength(image.rows, cellSize.height);

    // Use the source type instead of assuming CV_8UC3: copying into a
    // fixed-size ROI of a different type would fail.
    cv::Mat paddedImage(paddedHeight, paddedWidth, image.type(), cv::Scalar::all(0));
    image.copyTo(paddedImage(cv::Rect(0, 0, image.cols, image.rows)));
    return paddedImage;
}

/**
 * @brief Merges two bounding boxes.
 * @param lhs - bounding box
 * @param rhs - bounding box
 * @return the smallest axis-aligned bounding box containing both lhs and rhs
 */
geolib3::BoundingBox
mergeBBoxes(const geolib3::BoundingBox& lhs, const geolib3::BoundingBox& rhs) {
    const geolib3::Point2 lower(
        std::min(lhs.lowerCorner().x(), rhs.lowerCorner().x()),
        std::min(lhs.lowerCorner().y(), rhs.lowerCorner().y()));
    const geolib3::Point2 upper(
        std::max(lhs.upperCorner().x(), rhs.upperCorner().x()),
        std::max(lhs.upperCorner().y(), rhs.upperCorner().y()));
    return geolib3::BoundingBox(lower, upper);
}

/**
 * @brief Merges two building instances.
 *     1) The bounding box of the merged instance is the minimum bounding box
 *        containing the bounding boxes of both source instances.
 *     2) The mask of the merged instance is the pixel-wise union (bitwise OR)
 *        of the source masks, each placed at its position inside that box.
 * @param lhs - building instance
 * @param rhs - building instance
 * @return merged building instance
 */
BldInstance mergeInstances(const BldInstance& lhs, const BldInstance& rhs) {
    const geolib3::BoundingBox unionBBox = mergeBBoxes(lhs.bbox, rhs.bbox);
    const double originX = unionBBox.lowerCorner().x();
    const double originY = unionBBox.lowerCorner().y();

    // Region an instance's mask occupies inside the merged mask.
    const auto placeOf = [originX, originY](const BldInstance& instance) {
        return cv::Rect(
            cv::Point(instance.bbox.lowerCorner().x() - originX,
                      instance.bbox.lowerCorner().y() - originY),
            cv::Size(instance.bbox.width(), instance.bbox.height()));
    };

    cv::Mat unionMask(unionBBox.height(), unionBBox.width(), CV_8UC1, cv::Scalar(0));
    lhs.mask.copyTo(unionMask(placeOf(lhs)));
    const cv::Rect rhsPlace = placeOf(rhs);
    cv::bitwise_or(rhs.mask, unionMask(rhsPlace), unionMask(rhsPlace));
    return BldInstance(unionBBox, unionMask);
}

/**
 * @brief Checks whether two instances describe one building.
 *     Two instances describe one building if:
 *     1) Number of common pixels is greater than some value
 *     2) IoU is greater than some value
 *     3) Number of common pixels is greater than some percentage
 *        of pixels of one of two buildings
 * @param lhs - building instance
 * @param rhs - building instance
 * @return true if instances describe one building, otherwise false
 */
bool isSameBuilding(const BldInstance& lhs, const BldInstance& rhs) {
    const int PIXELS_CNT_THRESHOLD = 800;
    const double IOU_THRESHOLD = 0.4;
    const double INTERSECT_AREA_RATIO_THRESHOLD = 0.4;

    if (!geolib3::spatialRelation(lhs.bbox, rhs.bbox,
                                  geolib3::SpatialRelation::Intersects)) {
        return false;
    }
    geolib3::BoundingBox bbox = mergeBBoxes(lhs.bbox, rhs.bbox);
    double minX = bbox.lowerCorner().x();
    double minY = bbox.lowerCorner().y();
    cv::Mat lhsMask(bbox.height(), bbox.width(), CV_8UC1, cv::Scalar::all(0));
    cv::Rect lhsMaskRect(cv::Point(lhs.bbox.lowerCorner().x() - minX,
                                   lhs.bbox.lowerCorner().y() - minY),
                         cv::Size(lhs.bbox.width(), lhs.bbox.height()));
    lhs.mask.copyTo(lhsMask(lhsMaskRect));
    int lhsArea = cv::countNonZero(lhsMask);
    cv::Mat rhsMask(bbox.height(), bbox.width(), CV_8UC1, cv::Scalar::all(0));
    cv::Rect rhsMaskRect(cv::Point(rhs.bbox.lowerCorner().x() - minX,
                                   rhs.bbox.lowerCorner().y() - minY),
                         cv::Size(rhs.bbox.width(), rhs.bbox.height()));
    rhs.mask.copyTo(rhsMask(rhsMaskRect));
    int rhsArea = cv::countNonZero(rhsMask);
    cv::Mat intersectMask;
    cv::bitwise_and(lhsMask, rhsMask, intersectMask);
    int intersectArea = cv::countNonZero(intersectMask);
    double iou = intersectArea / float(lhsArea + rhsArea - intersectArea);

    if (intersectArea > PIXELS_CNT_THRESHOLD) {
        return true;
    } else if (iou > IOU_THRESHOLD) {
        return true;
    } else if (intersectArea / float(lhsArea) > INTERSECT_AREA_RATIO_THRESHOLD ||
               intersectArea / float(rhsArea) > INTERSECT_AREA_RATIO_THRESHOLD) {
        return true;
    } else {
        return false;
    }
}

/**
 * @brief Finds the outer contour of an instance mask.
 *     Among top-level contours (contours without a parent), the one with the
 *     largest area is chosen; nested contours are ignored.
 * @param instance - building instance
 * @return points of the chosen contour, shifted by the lower corner of the
 *     instance bounding box, or std::nullopt if the mask has no top-level
 *     contour with positive area
 */
std::optional<std::vector<cv::Point2f>> findContour(const BldInstance& instance) {
    std::vector<std::vector<cv::Point>> contours;
    std::vector<cv::Vec4i> hierarchy;
    // Use the cv:: enum constants: the legacy C macros CV_RETR_TREE /
    // CV_CHAIN_APPROX_SIMPLE are not available in OpenCV 4 without the C API
    // headers.
    // NOTE(review): before OpenCV 3.2 findContours modified its input image;
    // confirm the OpenCV version if masks are shared.
    cv::findContours(instance.mask, contours, hierarchy,
                     cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);

    std::optional<size_t> bestContourIndx;
    double maxArea = 0;
    for (size_t i = 0; i < contours.size(); ++i) {
        // hierarchy[i][3] is the parent index; it is negative for top-level contours.
        if (hierarchy[i][3] >= 0) {
            continue;
        }
        const double contourArea = cv::contourArea(contours[i]);
        if (contourArea > maxArea) {
            maxArea = contourArea;
            bestContourIndx = i;
        }
    }
    if (!bestContourIndx.has_value()) {
        return std::nullopt;
    }

    const std::vector<cv::Point>& bestContour = contours[*bestContourIndx];
    const int offsetX = instance.bbox.lowerCorner().x();
    const int offsetY = instance.bbox.lowerCorner().y();
    std::vector<cv::Point2f> contour;
    contour.reserve(bestContour.size());
    for (const cv::Point& cvPt : bestContour) {
        contour.emplace_back(cvPt.x + offsetX, cvPt.y + offsetY);
    }
    return contour;
}

/**
 * @brief Merges each instance group into one instance.
 * @param instances - building instances
 * @param groups - sets of instance indices corresponding to one group
 * @return merged instances
 */
std::vector<BldInstance>
mergeInstancesGroups(std::vector<BldInstance> instances,
                     const std::vector<std::vector<size_t>>& groups) {
    std::vector<BldInstance> mergedInstances;
    mergedInstances.reserve(groups.size());
    for (const auto& group : groups) {
        if (group.size() == 1) {
            mergedInstances.push_back(std::move(instances[group[0]]));
        } else if (group.size() > 1) {
            BldInstance mergedInstance = instances[group[0]];
            for (size_t i = 1; i < group.size(); i++) {
                mergedInstance = mergeInstances(mergedInstance, instances[group[i]]);
            }
            mergedInstances.push_back(mergedInstance);
        }
    }
    return mergedInstances;
}

} // anonymous namespace

/**
 * @brief Splits a grid of cells into batches of cell coordinates.
 *     For every horizontal index i in [0, horizCellsCnt), all vertical
 *     indices j in [0, vertCellsCnt) are visited in increasing order.
 * @param horizCellsCnt - number of cells along the horizontal axis
 * @param vertCellsCnt  - number of cells along the vertical axis
 * @param batchSize     - maximum number of cells in one batch; 0 behaves like 1
 * @return batches of (i, j) pairs; every batch except possibly the last one
 *     is full; empty if the grid has no cells
 */
std::vector<std::vector<std::pair<int, int>>> makeCropBatches(
    int horizCellsCnt, int vertCellsCnt,
    size_t batchSize)
{
    std::vector<std::vector<std::pair<int, int>>> batches;

    std::vector<std::pair<int, int>> tmpBatch;
    for (int i = 0; i < horizCellsCnt; i++) {
        for (int j = 0; j < vertCellsCnt; j++) {
            tmpBatch.emplace_back(i, j);

            if (tmpBatch.size() >= batchSize) {
                // Move the full batch out instead of copying it; clear()
                // returns the moved-from vector to a known empty state.
                batches.push_back(std::move(tmpBatch));
                tmpBatch.clear();
            }
        }
    }

    if (!tmpBatch.empty()) {
        batches.push_back(std::move(tmpBatch));
    }

    return batches;
}

/**
 * @brief Detects building instances on a BGR image of arbitrary size.
 *     The image is converted to RGB, zero-padded (see padImage) and cut into
 *     overlapping INPUT_SIZE cells. Cells are sent to Mask R-CNN in batches of
 *     at most batchSize. Detections are shifted back into padded-image
 *     coordinates, and detections of the same building are merged.
 * @param maskrcnn  - Mask R-CNN inferencer
 * @param imageBGR  - source image in BGR channel order
 * @param batchSize - maximum number of cells per inference call
 * @return merged building instances in image pixel coordinates
 */
std::vector<BldInstance> segmentInstances(
    const tf_inferencer::MaskRCNNInferencer& maskrcnn,
    const cv::Mat& imageBGR,
    size_t batchSize)
{
    // Size of one cell cut from the padded image and fed to Mask R-CNN.
    const cv::Size INPUT_SIZE(800, 800);
    // Overlap of neighbouring cells in pixels.
    const int OVERLAP = 200;
    const int stepX = INPUT_SIZE.width - OVERLAP;
    const int stepY = INPUT_SIZE.height - OVERLAP;

    cv::Mat imageRGB;
    cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
    const cv::Mat paddedImage = padImage(imageRGB, INPUT_SIZE, OVERLAP);
    // padImage makes each side equal to INPUT_SIZE plus a whole number of steps.
    const int horizCellsCnt = (paddedImage.cols - INPUT_SIZE.width) / stepX + 1;
    const int vertCellsCnt = (paddedImage.rows - INPUT_SIZE.height) / stepY + 1;

    std::vector<BldInstance> bldInstances;
    const auto batches = makeCropBatches(horizCellsCnt, vertCellsCnt, batchSize);
    for (const std::vector<std::pair<int, int>>& batch : batches) {
        std::vector<cv::Point> cellOrigins;
        std::vector<cv::Mat> cellImages;
        cellOrigins.reserve(batch.size());
        cellImages.reserve(batch.size());
        for (const auto& [col, row] : batch) {
            const cv::Point origin(stepX * col, stepY * row);
            cellOrigins.push_back(origin);
            cellImages.push_back(paddedImage(cv::Rect(origin, INPUT_SIZE)));
        }

        const std::vector<std::vector<BldInstance>> detections =
            inferenceMaskRCNN(maskrcnn, cellImages);
        REQUIRE(detections.size() == cellImages.size(),
                "Incorrect Mask RCNN output");

        for (size_t cell = 0; cell < detections.size(); ++cell) {
            const cv::Point& origin = cellOrigins[cell];
            for (const BldInstance& instance : detections[cell]) {
                // Shift the cell-local bounding box into padded-image coordinates.
                const int minX = origin.x + instance.bbox.lowerCorner().x();
                const int minY = origin.y + instance.bbox.lowerCorner().y();
                const int maxX = minX + instance.bbox.width();
                const int maxY = minY + instance.bbox.height();
                bldInstances.emplace_back(
                    geolib3::BoundingBox(geolib3::Point2(minX, minY),
                                         geolib3::Point2(maxX, maxY)),
                    instance.mask);
            }
        }
    }

    return mergeInstances(std::move(bldInstances));
}

std::vector<BldInstance> mergeInstances(std::vector<BldInstance> instances) {
    std::vector<std::vector<size_t>> groups;
    std::vector<bool> isUsed(instances.size(), false);

    geolib3::StaticGeometrySearcher<geolib3::BoundingBox, size_t> searcher;
    for (size_t i = 0; i < instances.size(); i++) {
        searcher.insert(&(instances[i].bbox), i);
    }
    searcher.build();

    for (size_t i = 0; i < instances.size(); i++) {
        if (isUsed[i]) {
            continue;
        }
        std::vector<size_t> newGroup;
        std::stack<size_t> sameBldInstancesIndx({i});
        while (!sameBldInstancesIndx.empty()) {
            size_t indx = sameBldInstancesIndx.top();
            sameBldInstancesIndx.pop();
            isUsed[indx] = true;
            newGroup.push_back(indx);
            auto searchResult = searcher.find(instances[indx].bbox);
            for (auto it = searchResult.first; it != searchResult.second; it++) {
                size_t candidatIndx = it->value();
                if (isUsed[candidatIndx]) {
                    continue;
                }
                if (isSameBuilding(instances[indx], instances[candidatIndx])) {
                    sameBldInstancesIndx.push(candidatIndx);
                }
            }
        }
        groups.push_back(std::move(newGroup));
    }

    return mergeInstancesGroups(std::move(instances), groups);
}

/**
 * @brief Converts building instances to polygons.
 *     For every instance, the outer contour of its mask is found (see
 *     findContour) and regularized with gridRegularizePolygon. Instances
 *     without a contour are skipped.
 * @param instances - building instances
 * @return regularized polygons, one per instance that has a contour,
 *     in instance order
 */
std::vector<std::vector<cv::Point2f>>
vectorizeInstances(const std::vector<BldInstance>& instances) {
    // Polygon regularization parameters in pixels
    const double TOLERANCE = 12.;
    const double GRID_SIDE = 3.;

    std::vector<std::vector<cv::Point2f>> blds;
    blds.reserve(instances.size());
    for (const auto& instance : instances) {
        const std::optional<std::vector<cv::Point2f>> contour = findContour(instance);
        if (!contour.has_value()) {
            continue;
        }
        // Push the temporary directly so it is moved, not copied.
        blds.push_back(gridRegularizePolygon(*contour, TOLERANCE, GRID_SIDE));
    }
    return blds;
}

} // namespace autocart
} // namespace wiki
} // namespace maps
