#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/polygon.h>
#include <maps/libs/geolib/include/static_geometry_searcher.h>
#include <maps/wikimap/mapspro/services/autocart/libs/geometry/include/polygon_processing.h>
#include <maps/wikimap/mapspro/services/autocart/libs/post_processing/include/post_processing.h>
#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/detect_bld.h>
#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace autocart = maps::wiki::autocart;
namespace tf_inferencer = maps::wiki::tf_inferencer;
namespace geolib3 = maps::geolib3;

namespace {

// Default values for the command-line options declared in main().

// Default for the "items_list_path" option.
const std::string DEFAULT_ITEM_LIST_PATH = "data/data_list.txt";
// Default for the "data_path" option.
const std::string DEFAULT_DATA_PATH = "data/";
// Default for the "mode" option; createDetector() accepts
// "rectangle", "edge" and "edge_vert".
const std::string DEFAULT_RECOGNITION_MODE = "rectangle";
// Default for the "sem_segm_path" option.
const std::string DEFAULT_SEMSEGM_GDEF_PATH = "data/sem_segm.gdef";
// Default for the "edge_vert_detector_path" option.
const std::string DEFAULT_EDGE_VERT_GDEF_PATH = "data/edge_vert.gdef";
// Default for the "edge_detector_path" option.
const std::string DEFAULT_EDGE_GDEF_PATH = "data/edge_detection.gdef";
// Default for the "iou_threshold" option; comparePolygons() keeps only
// polygon pairs whose IoU is strictly greater than the threshold.
const double DEFAULT_IOU_THRESHOLD = 0.5;

// Abstract interface of a building detector: the benchmark talks to every
// detection method through this class (see createDetector()).
class BldDetector {
public:
    // Virtual destructor: instances are owned through std::unique_ptr<BldDetector>.
    virtual ~BldDetector() = default;
    // Returns the building polygons detected on the given image.
    virtual std::vector<autocart::Polygon> detectBlds(const cv::Mat& image) = 0;
};

// Detector for the "rectangle" mode: delegates to
// autocart::detectBldByMinAreaRect() using two inferencers, one constructed
// from the semantic segmentation path and one from the edge detector path.
class BldDetectorByMinAreaRect : public BldDetector {
public:
    // semSegmPath      - passed to the segmentation inferencer constructor.
    // edgeDetectorPath - passed to the edge detection inferencer constructor.
    BldDetectorByMinAreaRect(
        const std::string& semSegmPath,
        const std::string& edgeDetectorPath)
        : segmentator_(semSegmPath)
        , edgeDetector_(edgeDetectorPath)
    {
    }

    // Runs detection on the image with both inferencers.
    std::vector<autocart::Polygon> detectBlds(const cv::Mat& image) override
    {
        return autocart::detectBldByMinAreaRect(
            segmentator_, edgeDetector_, image);
    }

private:
    tf_inferencer::TensorFlowInferencer segmentator_;
    tf_inferencer::TensorFlowInferencer edgeDetector_;
};

// Detector for the "edge" mode: delegates to autocart::detectBldByEdges()
// using a single inferencer constructed from the edge detector path.
class BldDetectorByEdges : public BldDetector {
public:
    // edgeDetectorPath - passed to the edge detection inferencer constructor.
    // explicit: a plain string must never convert silently into a detector
    // (which would construct an inferencer as a side effect).
    explicit BldDetectorByEdges(const std::string& edgeDetectorPath)
        : edgeDetector_(edgeDetectorPath) {}

    // Runs detection on the image with the edge inferencer.
    std::vector<autocart::Polygon> detectBlds(const cv::Mat& image) override {
        return autocart::detectBldByEdges(edgeDetector_, image);
    }

private:
    tf_inferencer::TensorFlowInferencer edgeDetector_;
};

// Detector for the "edge_vert" mode: delegates to
// autocart::detectBldByVertsEdges() using two inferencers, one constructed
// from the semantic segmentation path and one from the edge/vertex detector
// path.
class BldDetectorByVertsEdges : public BldDetector {
public:
    // semSegmPath          - passed to the segmentation inferencer constructor.
    // edgeVertDetectorPath - passed to the edge/vertex inferencer constructor.
    BldDetectorByVertsEdges(
        const std::string& semSegmPath,
        const std::string& edgeVertDetectorPath)
        : segmentator_(semSegmPath)
        , edgeVertDetector_(edgeVertDetectorPath)
    {
    }

    // Runs detection on the image with both inferencers.
    std::vector<autocart::Polygon> detectBlds(const cv::Mat& image) override
    {
        return autocart::detectBldByVertsEdges(
            segmentator_, edgeVertDetector_, image);
    }

private:
    tf_inferencer::TensorFlowInferencer segmentator_;
    tf_inferencer::TensorFlowInferencer edgeVertDetector_;
};

// Builds the detector that corresponds to the requested mode.
//
// mode                 - "rectangle", "edge" or "edge_vert".
// semSegmPath          - required (non-empty) for "rectangle" and "edge_vert".
// edgeDetectorPath     - required (non-empty) for "rectangle" and "edge".
// edgeVertDetectorPath - required (non-empty) for "edge_vert".
//
// A failed REQUIRE reports a missing path or an unknown mode.
std::unique_ptr<BldDetector>
createDetector(const std::string& mode,
               const std::string& semSegmPath,
               const std::string& edgeDetectorPath,
               const std::string& edgeVertDetectorPath) {
    if (mode == "rectangle") {
        REQUIRE(!semSegmPath.empty() && !edgeDetectorPath.empty(),
                "Path to gdef for semantic segmentation and edge detection"
                " networks must be specified for \"rectangle\" method");
        return std::make_unique<BldDetectorByMinAreaRect>(
            semSegmPath, edgeDetectorPath);
    }

    if (mode == "edge") {
        REQUIRE(!edgeDetectorPath.empty(),
                "Path to gdef for edge detection network"
                " must be specified for \"edge\" method");
        return std::make_unique<BldDetectorByEdges>(edgeDetectorPath);
    }

    if (mode == "edge_vert") {
        REQUIRE(!semSegmPath.empty() && !edgeVertDetectorPath.empty(),
                "Path to gdef for semantic segmentation, edge and vert detection"
                " networks must be specified for \"edge_vert\" method");
        return std::make_unique<BldDetectorByVertsEdges>(
            semSegmPath, edgeVertDetectorPath);
    }

    // None of the known modes matched.
    REQUIRE(false, "Unknown detection method:" << mode);
}


// One benchmark sample: an image and the file with its ground-truth labels.
struct BenchmarkItem {
    // Path of the image; main() reads it with cv::imread().
    std::string imagePath;
    // Path of the label file; main() parses it with loadPolygons().
    std::string labelPath;
};

// The whole benchmark: list of samples in the order they appear in the
// items list file.
using BenchmarkDataset = std::vector<BenchmarkItem>;

// Reads the list of benchmark items.
//
// itemsListPath - text file; every non-blank line holds two
//                 whitespace-separated tokens: the image filename and the
//                 label filename.
// dataPath      - prefix prepended to both filenames; a trailing '/' is
//                 appended when it is missing (an empty prefix stays empty).
//
// Returns the items in file order. A failed REQUIRE reports a file that
// cannot be opened or a line that lacks the label filename.
BenchmarkDataset loadDataset(const std::string& itemsListPath,
                             const std::string& dataPath = "") {
    std::ifstream ifs(itemsListPath);
    REQUIRE(ifs.is_open(), "Failed to open file: " << itemsListPath);

    std::string normalizedDataPath = dataPath;
    if (!normalizedDataPath.empty() && normalizedDataPath.back() != '/') {
        normalizedDataPath += '/';
    }

    BenchmarkDataset dataset;
    std::string line;
    // Loop on the result of getline() itself: testing eof() beforehand never
    // terminates when a read fails without reaching end-of-file.
    while (std::getline(ifs, line)) {
        std::istringstream ss(line);

        std::string imageFilename;
        if (!(ss >> imageFilename)) {
            // Blank or whitespace-only line (e.g. a lone '\r').
            continue;
        }

        std::string labelFilename;
        const bool hasLabel = static_cast<bool>(ss >> labelFilename);
        REQUIRE(hasLabel,
                "Label filename is missing in " << itemsListPath
                << ", line: " << line);

        BenchmarkItem item;
        item.imagePath = normalizedDataPath + imageFilename;
        item.labelPath = normalizedDataPath + labelFilename;
        dataset.push_back(std::move(item));
    }

    return dataset;
}

// Reads ground-truth building polygons from a label file.
//
// Every non-blank line is expected to look like
//     <className> <ptsCnt> <x1> <y1> ... <xN> <yN>
// Lines whose class name is not "bld" are ignored.
//
// Returns the polygons in file order. A failed REQUIRE reports a file that
// cannot be opened, or a "bld" line with a missing/negative points count or
// with fewer coordinates than announced.
std::vector<geolib3::Polygon2> loadPolygons(const std::string& labelPath) {
    std::ifstream ifs(labelPath);
    REQUIRE(ifs.is_open(), "Failed to open file: " << labelPath);

    std::vector<geolib3::Polygon2> polygons;
    std::string line;
    // Loop on the result of getline() itself: testing eof() beforehand never
    // terminates when a read fails without reaching end-of-file.
    while (std::getline(ifs, line)) {
        std::istringstream ss(line);

        std::string className;
        if (!(ss >> className) || className != "bld") {
            continue;
        }

        int ptsCnt = 0;
        const bool hasCount = static_cast<bool>(ss >> ptsCnt);
        // A negative count would otherwise turn into a huge reserve() size.
        REQUIRE(hasCount && ptsCnt >= 0,
                "Invalid points count in " << labelPath
                << ", line: " << line);

        std::vector<geolib3::Point2> pts;
        pts.reserve(static_cast<size_t>(ptsCnt));
        for (int i = 0; i < ptsCnt; ++i) {
            double x = 0.;
            double y = 0.;
            // Without this check a truncated line would feed unread
            // (indeterminate) coordinates into the polygon.
            const bool hasPoint = static_cast<bool>(ss >> x >> y);
            REQUIRE(hasPoint,
                    "Not enough point coordinates in " << labelPath
                    << ", line: " << line);
            pts.emplace_back(x, y);
        }
        polygons.emplace_back(std::move(pts));
    }

    return polygons;
}

// Converts detector output into geolib3 polygons: every point's x/y pair
// becomes a geolib3::Point2, and each point list becomes one
// geolib3::Polygon2. The order of polygons and of their points is kept.
std::vector<geolib3::Polygon2>
convertToGeolibPolygons(const std::vector<autocart::Polygon> &polygons)
{
    // Copies the vertices of a single detected polygon.
    const auto collectVertices = [](const autocart::Polygon& contour) {
        std::vector<geolib3::Point2> vertices;
        vertices.reserve(contour.size());
        for (const auto& vertex : contour) {
            vertices.emplace_back(vertex.x, vertex.y);
        }
        return vertices;
    };

    std::vector<geolib3::Polygon2> converted;
    converted.reserve(polygons.size());
    for (const auto& contour : polygons) {
        converted.emplace_back(collectVertices(contour));
    }
    return converted;
}

// Replaces every polygon with the result of autocart::getBoundingRectangle()
// and then passes the whole set through autocart::alignRectangles().
std::vector<geolib3::Polygon2>
rectifyAndAlignPolygons(const std::vector<geolib3::Polygon2>& polygons)
{
    std::vector<geolib3::Polygon2> rectangles;
    rectangles.reserve(polygons.size());
    for (const geolib3::Polygon2& polygon : polygons) {
        rectangles.push_back(autocart::getBoundingRectangle(polygon));
    }

    // NOTE(review): 3, 10, 200, 5 are tuning arguments of alignRectangles();
    // their meaning is not visible in this file - see its declaration.
    return autocart::alignRectangles(rectangles, 3, 10, 200, 5);
}

// Post-processing pipeline applied to raw detections before scoring:
//   1. conversion to geolib3 polygons,
//   2. rectification and alignment (rectifyAndAlignPolygons()),
//   3. autocart::removeIntersections() on the aligned set.
std::vector<geolib3::Polygon2>
postprocessBlds(const std::vector<autocart::Polygon>& polygons) {
    std::vector<geolib3::Polygon2> processed =
            rectifyAndAlignPolygons(convertToGeolibPolygons(polygons));
    autocart::removeIntersections(processed);
    return processed;
}

// Polygon counters of a benchmark run (one image or the accumulated total).
struct BenchmarkResult {
    // Adds the counters of another result to this one.
    BenchmarkResult& operator+=(const BenchmarkResult& otherResult) {
        gtCnt += otherResult.gtCnt;
        testCnt += otherResult.testCnt;
        foundCnt += otherResult.foundCnt;
        return *this;
    }

    // Logs the counters together with
    //   precision = foundCnt / testCnt,
    //   recall    = foundCnt / gtCnt,
    // each reported as 0 when its denominator is zero.
    // const: printing only reads the counters.
    void print() const {
        const double precision =
            testCnt > 0 ? foundCnt / static_cast<double>(testCnt) : 0.;
        const double recall =
            gtCnt > 0 ? foundCnt / static_cast<double>(gtCnt) : 0.;
        INFO() << "Result:";
        INFO() << "  Ground truth polygons number:" << gtCnt;
        INFO() << "  Found polygons number: " << foundCnt;
        INFO() << "  Extra polygons number: " << testCnt - foundCnt;
        INFO() << "  Precision: " << precision;
        INFO() << "  Recall: " << recall;
    }

    // Number of ground-truth polygons.
    size_t gtCnt = 0;
    // Number of polygons under test (detected and post-processed).
    size_t testCnt = 0;
    // Number of matched ground-truth/test pairs.
    size_t foundCnt = 0;
};

// Candidate match between one ground-truth polygon and one test polygon.
struct PolygonIntersection {
    // Stores the two indices and the IoU value as given.
    PolygonIntersection(size_t gtIndex, size_t testIndex, double iouValue)
        : gtIndx(gtIndex)
        , testIndx(testIndex)
        , IoU(iouValue)
    {
    }

    // Index of the polygon in the ground-truth list.
    size_t gtIndx;
    // Index of the polygon in the test list.
    size_t testIndx;
    // IoU value of the pair.
    double IoU;
};

// Scores test polygons against ground truth.
//
// Every (ground truth, test) pair returned by the spatial searcher whose
// autocart::IoU() is strictly greater than iouThreshold becomes a candidate.
// Candidates are then taken in order of decreasing IoU, and a pair is counted
// as found only if neither of its polygons was used by an earlier pair, so
// each polygon takes part in at most one match.
//
// Returns the sizes of both inputs (gtCnt, testCnt) and the number of
// matched pairs (foundCnt).
BenchmarkResult
comparePolygons(const std::vector<geolib3::Polygon2>& gtPolygons,
                const std::vector<geolib3::Polygon2>& testPolygons,
                double iouThreshold) {
    // The searcher is given pointers into testPolygons, keyed by index.
    geolib3::StaticGeometrySearcher<geolib3::Polygon2, size_t> testIndex;
    for (size_t testIdx = 0; testIdx < testPolygons.size(); ++testIdx) {
        testIndex.insert(&testPolygons[testIdx], testIdx);
    }
    testIndex.build();

    // Collect all pairs that pass the IoU threshold.
    std::vector<PolygonIntersection> candidates;
    for (size_t gtIdx = 0; gtIdx < gtPolygons.size(); ++gtIdx) {
        const geolib3::Polygon2& gtPolygon = gtPolygons[gtIdx];
        // Query by bounding box; presumably a cheap pre-filter, so the exact
        // IoU is still computed for every hit.
        auto hits = testIndex.find(gtPolygon.boundingBox());
        for (auto hit = hits.first; hit != hits.second; ++hit) {
            const size_t testIdx = hit->value();
            const double overlap =
                autocart::IoU(gtPolygon, testPolygons[testIdx]);
            if (overlap > iouThreshold) {
                candidates.emplace_back(gtIdx, testIdx, overlap);
            }
        }
    }

    // Best overlaps first.
    std::sort(
        candidates.begin(), candidates.end(),
        [](const auto& lhs, const auto& rhs)
        {
            return lhs.IoU > rhs.IoU;
        }
    );

    std::vector<bool> gtMatched(gtPolygons.size(), false);
    std::vector<bool> testMatched(testPolygons.size(), false);

    BenchmarkResult result;
    result.gtCnt = gtPolygons.size();
    result.testCnt = testPolygons.size();
    for (const PolygonIntersection& candidate : candidates) {
        const bool alreadyUsed =
            gtMatched[candidate.gtIndx] || testMatched[candidate.testIndx];
        if (!alreadyUsed) {
            gtMatched[candidate.gtIndx] = true;
            testMatched[candidate.testIndx] = true;
            result.foundCnt++;
        }
    }

    return result;
}

} // namespace

int main(int argc, char** argv)
try {
    maps::cmdline::Parser parser;
    auto dataPathParam = parser.string("data_path")\
            .defaultValue(DEFAULT_DATA_PATH)\
            .help("path to folder with benchmark dataset");

    auto itemsListPathParam = parser.string("items_list_path")\
            .defaultValue(DEFAULT_ITEM_LIST_PATH)\
            .help("path to txt file with list of files in benchmark dataset");

    auto modeParam = parser.string("mode")\
            .defaultValue(DEFAULT_RECOGNITION_MODE)\
            .help("recognition mode");

    auto semSegmPathParam = parser.string("sem_segm_path")\
            .defaultValue(DEFAULT_SEMSEGM_GDEF_PATH)\
            .help("path to the gdef for semantic segmentation network");

    auto edgeDetectPathParam = parser.string("edge_detector_path")\
            .defaultValue(DEFAULT_EDGE_GDEF_PATH)\
            .help("path to the gdef for edge detection network");

    auto edgeVertDetectPathParam = parser.string("edge_vert_detector_path")\
            .defaultValue(DEFAULT_EDGE_VERT_GDEF_PATH)\
            .help("path to the gdef for edge and vertex detection network");

    auto iouThresholdParam = parser.real("iou_threshold")\
            .defaultValue(DEFAULT_IOU_THRESHOLD)\
            .help("threshold for intersection over union");
    parser.parse(argc, argv);

    BenchmarkDataset dataset = loadDataset(itemsListPathParam, dataPathParam);

    std::unique_ptr<BldDetector> detector =
            createDetector(modeParam, semSegmPathParam,
                           edgeDetectPathParam, edgeVertDetectPathParam);
    BenchmarkResult totalCompareResult;
    for (const auto& item : dataset) {
        INFO() << "Image: " << item.imagePath;
        cv::Mat image = cv::imread(item.imagePath);
        REQUIRE(!image.empty(), "Failed to open image: " << item.imagePath);
        std::vector<geolib3::Polygon2> gtBlds = loadPolygons(item.labelPath);

        std::vector<autocart::Polygon> detectedBlds = detector->detectBlds(image);
        std::vector<geolib3::Polygon2> processedBlds = postprocessBlds(detectedBlds);

        BenchmarkResult compareResult = comparePolygons(gtBlds, processedBlds,
                                                        iouThresholdParam);
        compareResult.print();
        totalCompareResult += compareResult;
    }

    INFO() << "Total:";
    totalCompareResult.print();

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    INFO() << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    INFO() << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    INFO() << "Caught unknown exception";
    return EXIT_FAILURE;
}
