#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/objects/include/bbox.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/objects/include/area.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/objects/include/road.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/objects/include/building.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/objects/include/dwellplace.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/detection/include/state.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/detection/include/detect_blds_in_cells.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/yt_utils/include/op_wrapper.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/yt_utils/include/batch.h>
#include <maps/wikimap/mapspro/services/autocart/pipeline/libs/yt_utils/include/rows_count.h>

#include <maps/libs/log8/include/log8.h>

#include <maps/libs/geolib/include/const.h>
#include <maps/libs/geolib/include/point.h>
#include <maps/libs/geolib/include/polygon.h>
#include <maps/libs/geolib/include/polyline.h>
#include <maps/libs/geolib/include/linear_ring.h>
#include <maps/libs/geolib/include/bounding_box.h>
#include <maps/libs/geolib/include/spatial_relation.h>

#include <maps/libs/tile/include/const.h>
#include <maps/libs/tile/include/utils.h>

#include <maps/libs/common/include/base64.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/wikimap/mapspro/libs/tf_inferencer/maskrcnn_inferencer.h>

#include <maps/wikimap/mapspro/services/autocart/libs/geometry/include/hex_wkb.h>
#include <maps/wikimap/mapspro/services/autocart/libs/detection/include/detect_bld.h>
#include <maps/wikimap/mapspro/services/autocart/libs/satellite/include/load_sat_image.h>
#include <maps/wikimap/mapspro/services/autocart/libs/post_processing/include/post_processing.h>
#include <maps/wikimap/mapspro/services/autocart/libs/post_processing/include/align_along_road.h>
#include <maps/wikimap/mapspro/services/autocart/libs/utils/include/multithreading.h>

#include <library/cpp/string_utils/base64/base64.h>

#include <mapreduce/yt/util/temp_table.h>
#include <mapreduce/yt/interface/client.h>

#include <util/generic/size_literals.h>
#include <util/thread/pool.h>
#include <util/thread/lfqueue.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <memory>
#include <set>
#include <thread>
#include <utility>
#include <vector>

namespace maps::wiki::autocart::pipeline {

namespace {

// Names inside this unnamed namespace already have internal linkage, so the
// constants below need no `static`. Definition order matters: later constants
// are initialized from earlier ones.

// Column that LoadSatImageMapper fills with the base64-encoded PNG image and
// that DetectBuildingMapper reads back.
const TString SAT_IMAGE = "image";

// Detection modes understood by createDetector().
const TString RECTANGLES_MODE = "rectangles";
const TString EDGES_MODE = "edges";
const TString COMPLEX_MASKRCNN_MODE = "complex_maskrcnn";

// Used by detectBuildingsInCells() to validate the configured mode up front.
const std::set<TString> AVAILABLE_MODES = {
    RECTANGLES_MODE,
    EDGES_MODE,
    COMPLEX_MASKRCNN_MODE,
};

// Frozen TensorFlow graphs, loaded through the inferencers' fromResource().
const TString TF_MODELS_RESOURCE_FOLDER = "/maps/autocart/dwellplaces/models";
const TString TF_MODEL_SEM_SEGM_RESOURCE = TF_MODELS_RESOURCE_FOLDER + "/sem_segm.gdef";
const TString TF_MODEL_EDGE_DETECTION_RESOURCE = TF_MODELS_RESOURCE_FOLDER + "/edge_detection.gdef";
const TString TF_MODEL_MASKRCNN_RESOURCE = TF_MODELS_RESOURCE_FOLDER + "/maskrcnn.gdef";

class BldDetector {
public:
    virtual std::vector<geolib3::Polygon2> detect(const cv::Mat& image) const = 0;
    virtual ~BldDetector() = default;

protected:
    std::vector<geolib3::Polygon2>
    convertToGeolibPolygons(
        const std::vector<std::vector<cv::Point2f>>& cvPolygons) const
    {
        std::vector<geolib3::Polygon2> geolibPolygons;
        geolibPolygons.reserve(cvPolygons.size());
        for (const auto& cvPolygon : cvPolygons) {
            std::vector<geolib3::Point2> geolibPoints;
            for (const auto& cvPoint : cvPolygon) {
                geolibPoints.emplace_back(cvPoint.x, cvPoint.y);
            }
            geolibPolygons.emplace_back(geolibPoints);
        }
        return geolibPolygons;
    }
};

// Detector for RECTANGLES_MODE: runs detectBldByMinAreaRect with a
// semantic-segmentation model and an edge-detection model.
class BldDetectorByRectangles : public BldDetector {
public:
    BldDetectorByRectangles()
        : segmentator_(tf_inferencer::TensorFlowInferencer::fromResource(TF_MODEL_SEM_SEGM_RESOURCE))
        , edgeDetector_(tf_inferencer::TensorFlowInferencer::fromResource(TF_MODEL_EDGE_DETECTION_RESOURCE))
    {}

    std::vector<geolib3::Polygon2> detect(const cv::Mat& image) const override
    {
        return convertToGeolibPolygons(
            detectBldByMinAreaRect(segmentator_, edgeDetector_, image));
    }

private:
    tf_inferencer::TensorFlowInferencer segmentator_;
    tf_inferencer::TensorFlowInferencer edgeDetector_;
};

// Detector for EDGES_MODE: runs detectBldByEdges with a single
// edge-detection model.
class BldDetectorByEdges : public BldDetector {
public:
    BldDetectorByEdges()
        : edgeDetector_(tf_inferencer::TensorFlowInferencer::fromResource(TF_MODEL_EDGE_DETECTION_RESOURCE))
    {}

    std::vector<geolib3::Polygon2> detect(const cv::Mat& image) const override
    {
        return convertToGeolibPolygons(detectBldByEdges(edgeDetector_, image));
    }

private:
    tf_inferencer::TensorFlowInferencer edgeDetector_;
};

// Detector for COMPLEX_MASKRCNN_MODE: runs detectComplexBldByMaskRCNN with a
// Mask R-CNN model.
class BldDetectorByMaskRCNN : public BldDetector {
public:
    BldDetectorByMaskRCNN()
        : maskrcnn_(tf_inferencer::MaskRCNNInferencer::fromResource(TF_MODEL_MASKRCNN_RESOURCE))
    {}

    std::vector<geolib3::Polygon2> detect(const cv::Mat& image) const override
    {
        return convertToGeolibPolygons(detectComplexBldByMaskRCNN(maskrcnn_, image));
    }

private:
    tf_inferencer::MaskRCNNInferencer maskrcnn_;
};

BldDetector* createDetector(const TString& mode) {
    if (RECTANGLES_MODE == mode) {
        return new BldDetectorByRectangles();
    } else if (EDGES_MODE == mode) {
        return new BldDetectorByEdges();
    } else if (COMPLEX_MASKRCNN_MODE == mode) {
        return new BldDetectorByMaskRCNN();
    } else {
        throw maps::RuntimeError("Unknown detection mode: " + mode);
    }
}

// Converts every vertex reported by pointsNumber()/pointAt() from image
// coordinates to Mercator, relative to `origin` at `zoom`.
// NOTE(review): if pointAt() walks only the exterior ring, holes are dropped.
// That is harmless for detector output, which is built without interior rings.
geolib3::Polygon2 imageToMercator(
    const geolib3::Polygon2& polygon,
    const geolib3::Point2& origin,
    size_t zoom)
{
    const size_t count = polygon.pointsNumber();
    std::vector<geolib3::Point2> mercatorPoints;
    mercatorPoints.reserve(count);
    for (size_t idx = 0; idx != count; ++idx) {
        // Qualified on purpose: this overload hides autocart::imageToMercator
        // from unqualified lookup.
        mercatorPoints.push_back(
            autocart::imageToMercator(polygon.pointAt(idx), origin, zoom));
    }
    return geolib3::Polygon2(std::move(mercatorPoints));
}

// Replaces each polygon, in place, with the result of getBoundingRectangle().
void replacePolygonsByRectangles(
    std::vector<geolib3::Polygon2>& polygons)
{
    std::transform(
        polygons.begin(), polygons.end(), polygons.begin(),
        [](const geolib3::Polygon2& source) {
            return getBoundingRectangle(source);
        });
}

// Runs alignRectangles over `polygons` with fixed tuning parameters and
// replaces the input with the returned polygons.
void alignPolygons(std::vector<geolib3::Polygon2>& polygons) {
    constexpr size_t neighborsNumber = 3;
    constexpr double angleDeltaDegree = 10.0;
    constexpr double angleDistDelta = 200.0;
    constexpr size_t iterCount = 5;

    auto aligned = alignRectangles(
        polygons, neighborsNumber, angleDeltaDegree, angleDistDelta, iterCount);
    polygons = std::move(aligned);
}

// Runs alignAlongRoads against `roads` with default-constructed
// AlignAlongRoadsParams and replaces `polygons` with the result.
void alignPolygonsAlongRoads(
    std::vector<geolib3::Polygon2>& polygons,
    const std::vector<geolib3::Polyline2>& roads)
{
    AlignAlongRoadsParams defaultParams;
    auto aligned = alignAlongRoads(polygons, roads, defaultParams);
    polygons = std::move(aligned);
}

/// Removes from `newObjs` every polygon that intersects at least one geometry
/// in `oldObjs` (per geolib3::spatialRelation(..., Intersects)).
/// The surviving polygons keep their relative order.
template <typename Geom>
void removeOverlapsWithObjects(
    std::vector<geolib3::Polygon2>& newObjs,
    const std::vector<Geom>& oldObjs)
{
    if (newObjs.empty() || oldObjs.empty()) {
        return;
    }
    const auto overlapsOldObject = [&oldObjs](const geolib3::Polygon2& newObj) {
        return std::any_of(
            oldObjs.begin(), oldObjs.end(),
            [&newObj](const Geom& oldObj) {
                return geolib3::spatialRelation(
                    newObj, oldObj, geolib3::SpatialRelation::Intersects);
            });
    };
    // Single compaction pass instead of erasing elements one by one.
    newObjs.erase(
        std::remove_if(newObjs.begin(), newObjs.end(), overlapsOldObject),
        newObjs.end());
}

/// Keeps only the polygons that `bbox` contains (per
/// geolib3::spatialRelation(bbox, polygon, Contains)). Order is preserved.
void removeOutPolygons(
    std::vector<geolib3::Polygon2>& objs,
    const geolib3::BoundingBox& bbox)
{
    // Single compaction pass instead of erasing elements one by one.
    objs.erase(
        std::remove_if(
            objs.begin(), objs.end(),
            [&bbox](const geolib3::Polygon2& obj) {
                return !geolib3::spatialRelation(
                    bbox, obj, geolib3::SpatialRelation::Contains);
            }),
        objs.end());
}

/// Encodes `image` as PNG and returns the bytes base64-encoded.
/// Throws (via REQUIRE) if OpenCV reports that encoding failed, instead of
/// silently base64-encoding an empty buffer.
TString encodeImage(const cv::Mat& image) {
    std::vector<uint8_t> encodedImageData;
    const bool encoded = cv::imencode(".png", image, encodedImageData);
    REQUIRE(encoded, "Failed to encode image to PNG");
    return maps::base64Encode(encodedImageData).c_str();
}

/// For every input cell, loads the satellite image of the cell's bbox at
/// `zoom_` from `tileSourceUrl_`. It stores the image as base64-encoded PNG in
/// the SAT_IMAGE column and writes the row out.
/// A row whose image cannot be loaded or encoded is logged with WARN and skipped.
class LoadSatImageMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>,
                          NYT::TTableWriter<NYT::TNode>> {
public:
    LoadSatImageMapper() = default;
    LoadSatImageMapper(size_t zoom, const TString& tileSourceUrl)
        : zoom_(zoom)
        , tileSourceUrl_(tileSourceUrl)
    {}

    Y_SAVELOAD_JOB(zoom_, tileSourceUrl_);

    void Do(NYT::TTableReader<NYT::TNode>* reader,
            NYT::TTableWriter<NYT::TNode>* writer) override
    {
        for (; reader->IsValid(); reader->Next()) {
            NYT::TNode node = reader->GetRow();
            BBox bbox = BBox::fromYTNode(node);

            TString encodedImage;
            try {
                const cv::Mat image = loadSatImage(bbox.toMercatorGeom(), zoom_, tileSourceUrl_);
                // Encoding sits inside the try so an unencodable image is
                // skipped the same way as an unloadable one.
                encodedImage = encodeImage(image);
            } catch (const std::exception& e) {
                WARN() << e.what();
                continue;
            }

            node[SAT_IMAGE] = std::move(encodedImage);

            writer->AddRow(node);
        }
    }

private:
    // Initialized so a default-constructed mapper (before Load) has a defined value.
    size_t zoom_ = 0;
    TString tileSourceUrl_;
};

REGISTER_MAPPER(LoadSatImageMapper);

// One parsed input row of DetectBuildingMapper; loadInputData() converts
// every geometry here to Mercator.
struct InputData {
    // Cell bounds, read with BBox::fromYTNode.
    BBox bbox;
    // Satellite image of the cell, decoded from the SAT_IMAGE column.
    cv::Mat image;
    // All existing roads of the cell.
    std::vector<geolib3::Polyline2> roads{};
    // Subset of `roads` used to align detections along roads.
    std::vector<geolib3::Polyline2> mainRoads{};
    // Existing buildings; detections intersecting them are dropped.
    std::vector<geolib3::Polygon2> blds{};
    // Existing areas; detections intersecting them are dropped.
    std::vector<geolib3::MultiPolygon2> areas{};
};

/// Decodes a base64 string and then the encoded image bytes into a 3-channel
/// (cv::IMREAD_COLOR) matrix.
/// Throws (via REQUIRE) if OpenCV cannot decode the bytes. cv::imdecode
/// signals that by returning an empty Mat, which would otherwise reach the
/// detector unnoticed.
cv::Mat decodeImage(const TString& encimageBase64Str) {
    std::vector<std::uint8_t> encimage(Base64DecodeBufSize(encimageBase64Str.length()));
    const size_t encimageSize = Base64Decode(
        encimage.data(), encimageBase64Str.begin(), encimageBase64Str.end());
    encimage.resize(encimageSize);

    cv::Mat image = cv::imdecode(encimage, cv::IMREAD_COLOR);
    REQUIRE(!image.empty(), "Failed to decode satellite image");
    return image;
}

// Detects buildings on each image in the input table.
// The input table should have columns:
// 'cell_bbox' - bbox coordinates in Mercator
// 'image' - base64-encoded PNG satellite image of the cell
// 'blds' - list of buildings in bbox in Mercator
// 'roads' - list of roads in bbox in Mercator
// plus the existing areas of the cell (read by areasFromYTNode)
//
// Detected buildings are saved in geodetic coordinates (lat, lon)
//
class DetectBuildingMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>,
                          NYT::TTableWriter<NYT::TNode>> {
public:
    DetectBuildingMapper() = default;
    DetectBuildingMapper(const TString& mode, size_t zoom);

    Y_SAVELOAD_JOB(mode_, zoom_);

    void Do(NYT::TTableReader<NYT::TNode>* reader,
            NYT::TTableWriter<NYT::TNode>* writer) override;

private:
    /// Parses one input row: the cell bbox, the decoded satellite image, and
    /// the existing roads, buildings and areas converted to Mercator.
    InputData loadInputData(const NYT::TNode& node) const {
        InputData data;
        data.bbox = BBox::fromYTNode(node);
        data.image = decodeImage(node[SAT_IMAGE].AsString());
        for (const Road& road : roadsFromYTNode(node)) {
            geolib3::Polyline2 mercGeom = road.toMercatorGeom();
            // Only roads under these fc/fow thresholds go to mainRoads,
            // which is the set used to align detections along roads.
            // NOTE(review): the meaning of the fc/fow values is not documented
            // here; confirm against the road classification.
            if (road.getFc() < 9 && road.getFow() < 12) {
                data.mainRoads.push_back(mercGeom);
            }
            // Last use of mercGeom, so move instead of copying.
            data.roads.push_back(std::move(mercGeom));
        }
        for (const Building& bld : bldsFromYTNode(node)) {
            data.blds.emplace_back(bld.toMercatorGeom());
        }
        for (const Area& area : areasFromYTNode(node)) {
            data.areas.emplace_back(area.toMercatorGeom());
        }
        return data;
    }

    TString mode_;
    // Initialized so a default-constructed mapper (before Load) has a defined value.
    size_t zoom_ = 0;
};

// Stores the detection mode and zoom. Both are part of the job state
// serialized by Y_SAVELOAD_JOB.
DetectBuildingMapper::DetectBuildingMapper(const TString& mode, size_t zoom)
    : mode_(mode), zoom_(zoom)
{}

// Runs the configured detector on every input cell and post-processes the
// detections in Mercator. The surviving buildings are written out as
// Building rows. The detector is created once per job and reused for all rows.
void DetectBuildingMapper::Do(
    NYT::TTableReader<NYT::TNode>* reader,
    NYT::TTableWriter<NYT::TNode>* writer)
{
    std::unique_ptr<BldDetector> detector(createDetector(mode_));

    while (reader->IsValid()) {
        InputData input = loadInputData(reader->GetRow());

        // Describe the cell being processed.
        geolib3::BoundingBox cellGeodetic = input.bbox.toGeodeticGeom();
        INFO() << "Detecting buildings in region: "
               << "[[" << cellGeodetic.minX() << ", " << cellGeodetic.minY() << "],"
               << " [" << cellGeodetic.maxX() << ", " << cellGeodetic.maxY() << "]]";
        INFO() << "Image size: " << input.image.rows << ", " << input.image.cols;

        // Detection: polygons come back in image coordinates.
        INFO() << "Detecting buildings on image";
        std::vector<geolib3::Polygon2> imagePolygons = detector->detect(input.image);
        INFO() << "Detected " << imagePolygons.size() << " buildings";

        // Project every detection into Mercator.
        geolib3::BoundingBox cellMercator = input.bbox.toMercatorGeom();
        INFO() << "Converting buildings to mercator";
        geolib3::Point2 displayOrigin = getDisplayOrigin(cellMercator, zoom_);
        std::vector<geolib3::Polygon2> candidates;
        candidates.reserve(imagePolygons.size());
        for (const geolib3::Polygon2& imagePolygon : imagePolygons) {
            candidates.push_back(imageToMercator(imagePolygon, displayOrigin, zoom_));
        }

        INFO() << "Existing roads count: " << input.roads.size();
        INFO() << "Existing main roads count: " << input.mainRoads.size();
        INFO() << "Existing buildings count: " << input.blds.size();
        INFO() << "Existing areas count: " << input.areas.size();

        // Regularize shapes.
        INFO() << "Replacing polygons with rectangles";
        replacePolygonsByRectangles(candidates);
        INFO() << "Aligning polygons";
        alignPolygons(candidates);
        INFO() << "Aligning polygons along roads";
        alignPolygonsAlongRoads(candidates, input.mainRoads);

        // Filter out unusable or conflicting candidates.
        INFO() << "Removing intersections";
        removeIntersections(candidates);

        INFO() << "Removing out-of-bounds polygons";
        removeOutPolygons(candidates, cellMercator);
        INFO() << candidates.size() << " buildings left";

        INFO() << "Removing intersections with existing buildings";
        removeOverlapsWithObjects(candidates, input.blds);
        INFO() << candidates.size() << " buildings left";

        INFO() << "Removing intersections with existing roads";
        removeOverlapsWithObjects(candidates, input.roads);
        INFO() << candidates.size() << " buildings left";

        INFO() << "Removing intersections with existing areas";
        removeOverlapsWithObjects(candidates, input.areas);
        INFO() << candidates.size() << " buildings left";

        for (const geolib3::Polygon2& candidate : candidates) {
            Building building = Building::fromMercatorGeom(candidate);
            writer->AddRow(building.toYTNode());
        }

        reader->Next();
    }
}

REGISTER_MAPPER(DetectBuildingMapper);

class BuildingDetectorProcessor : public ObjectInQueueWithData {
public:
    BuildingDetectorProcessor(
        NYT::IClientBasePtr client,
        TLockFreeQueue<TString>* imageYTPathQueue,
        NYT::TRichYPath outputRYTPath,
        DetectorConfig config)
        : client_(client)
        , imageYTPathQueue_(imageYTPathQueue)
        , outputRYTPath_(outputRYTPath)
        , config_(config)
    {}

    void Process(void* /*threadSpecificResource*/) override {
        for (;;) {
            TString imageYTPath;
            if (imageYTPathQueue_->Dequeue(&imageYTPath)) {

                size_t rowsCount = getRowsCount(client_, imageYTPath);
                size_t jobCount = std::max(rowsCount / config_.jobSize(), 1ul);

                YTOpExecutor::Map(
                    client_,
                    YTOpExecutor::MapSpec()
                        .AddInput(imageYTPath)
                        .AddOutput(outputRYTPath_),
                    new DetectBuildingMapper(
                        config_.mode(),
                        config_.zoom()
                    ),
                    YTOpExecutor::Options()
                        .Title("[Buildings detector] Detecting buildings")
                        .UseGPU(1)
                        .JobCount(jobCount)
                        .MemoryLimit(16_GB)
                );
                client_->Remove(imageYTPath, NYT::TRemoveOptions().Recursive(true).Force(true));
            } else if (!isWaitData()) {
                break;
            } else {
                std::this_thread::sleep_for(std::chrono::seconds(1));
            }
        }
        Stop();
    }
private:
    NYT::IClientBasePtr client_;
    TLockFreeQueue<TString>* imageYTPathQueue_;
    NYT::TRichYPath outputRYTPath_;
    DetectorConfig config_;
};

} // namespace

/// Detects buildings in every cell of `inputYTTablePath` and appends them to
/// `outputYTTablePath`, which is created up front with Recursive/Force.
///
/// Input rows are processed in batches of at most jobSize * MAX_JOB_COUNT rows.
/// For each batch, the calling thread loads satellite images into a temporary
/// table and queues it for a single BuildingDetectorProcessor worker. The
/// worker runs detection while the next batch is being loaded.
///
/// Throws (via REQUIRE) on an unknown mode or a zero job size.
void detectBuildingsInCells(
    NYT::IClientBasePtr client,
    const TString& inputYTTablePath,
    const DetectorConfig& config,
    const TString& outputYTTablePath)
{
    constexpr size_t MAX_JOB_COUNT = 1000;

    REQUIRE(AVAILABLE_MODES.find(config.mode()) != AVAILABLE_MODES.end(),
            "Unknown detection algorithm: " + config.mode());
    // jobSize is a divisor below. Zero would also make batchSize zero, so the
    // batching loop would never advance.
    REQUIRE(config.jobSize() > 0, "Job size must be positive");

    size_t rowsCount = getRowsCount(client, inputYTTablePath);
    size_t batchSize = std::min(config.jobSize() * MAX_JOB_COUNT, rowsCount);

    NYT::TRichYPath outputRYTPath(outputYTTablePath);
    outputRYTPath.Append(true);
    client->Create(
        outputRYTPath.Path_,
        NYT::NT_TABLE,
        NYT::TCreateOptions().Recursive(true).Force(true)
    );

    TLockFreeQueue<TString> imageYTPathQueue;

    // Declared before the pool so it is destroyed after the pool. The worker
    // thread uses the processor until the pool has stopped.
    THolder<BuildingDetectorProcessor> processor(
        new BuildingDetectorProcessor(
            client,
            &imageYTPathQueue,
            outputRYTPath,
            config
        )
    );

    THolder<IThreadPool> threadPool(new TThreadPool());
    threadPool->Start(1);
    threadPool->SafeAdd(processor.Get());

    try {
        for (size_t begin = 0; begin < rowsCount; begin += batchSize) {
            size_t end = std::min(begin + batchSize, rowsCount);
            NYT::TRichYPath batchRYTPath(inputYTTablePath);
            batchRYTPath.AddRange(NYT::TReadRange().FromRowIndices(begin, end));

            TString imageYTPath = State::getTempTable(client).Release();

            size_t jobCount = std::max<size_t>((end - begin) / config.jobSize(), 1);
            YTOpExecutor::Map(
                client,
                YTOpExecutor::MapSpec()
                    .AddInput(batchRYTPath)
                    .AddOutput(imageYTPath),
                new LoadSatImageMapper(
                    config.zoom(),
                    config.tileSourceURL()
                ),
                YTOpExecutor::Options()
                    .Title("[Buildings detector] Loading satellite images")
                    .JobCount(jobCount)
                    .RunningJobCount(config.runningJobCount())
                    .MemoryLimit(2_GB)
                    .MaxRowWeight(128_MB)
            );

            imageYTPathQueue.Enqueue(imageYTPath);

            // Wait until the worker has picked the batch up, so at most one
            // loaded batch is waiting at a time.
            while (!imageYTPathQueue.IsEmpty()) {
                std::this_thread::sleep_for(std::chrono::seconds(1));
            }
        }
    } catch (...) {
        // Tell the worker no more data is coming so that Process() can return.
        // Otherwise stopping the pool during unwinding would wait on a loop
        // that never ends. Assumes, as the normal path below does with Stop(),
        // that the pool waits for its worker before it goes away.
        processor->DataEnded();
        throw;
    }

    processor->DataEnded();
    while (!imageYTPathQueue.IsEmpty()) {
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }

    while (processor->isRunning()) {
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }

    threadPool->Stop();
}

} // namespace maps::wiki::autocart::pipeline
