#include <yandex/maps/mrc/traffic_signs/signs.h>

#include <maps/libs/common/include/exception.h>
#include <maps/libs/common/include/base64.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <mapreduce/yt/interface/client.h>

#include <util/system/fs.h>
#include <util/generic/set.h>
#include <library/cpp/string_utils/base64/base64.h>

#include "opencv2/opencv.hpp"

#include <random>
#include <string>
#include <list>
#include <fstream>
#include <stdexcept>

using namespace NYT;
namespace ts = maps::mrc::traffic_signs;

// Reads a list of class names from the text file at `path`, one name per line.
//
// Empty lines and lines whose first character is '#' are skipped.
// A trailing '\r' (CRLF line ending) is removed before the name is stored.
//
// Returns the set of names; the set is empty when the file cannot be opened.
TSet<TString> loadClasses(const std::string &path) {
    TSet<TString> setClassesName;
    std::ifstream ifs(path);
    if (!ifs.is_open())
        return setClassesName;

    std::string line;
    // Loop on the result of getline rather than on eof(): a stream that fails
    // without hitting end-of-file never sets eofbit, so an eof()-driven loop
    // would spin forever.
    while (std::getline(ifs, line)) {
        if (!line.empty() && '\r' == line.back())
            line.pop_back();
        if (line.empty() || '#' == line[0])
            continue;
        setClassesName.insert(line.c_str());
    }
    return setClassesName;
}

// Builds a cv::Rect from a node holding two points: [[x1, y1], [x2, y2]].
//
// The points may come in any order; the rectangle's origin is the smaller
// coordinate on each axis and its extent includes both points (hence "+ 1").
cv::Rect NodeToRect(const TNode& node) {
    const auto& points = node.AsList();

    const auto& firstPoint = points[0].AsList();
    const int xa = static_cast<int>(firstPoint[0].AsInt64());
    const int ya = static_cast<int>(firstPoint[1].AsInt64());

    const auto& secondPoint = points[1].AsList();
    const int xb = static_cast<int>(secondPoint[0].AsInt64());
    const int yb = static_cast<int>(secondPoint[1].AsInt64());

    const int left = std::min(xa, xb);
    const int top = std::min(ya, yb);
    const int right = std::max(xa, xb);
    const int bottom = std::max(ya, yb);

    return cv::Rect(left, top, right - left + 1, bottom - top + 1);
}

// Decodes an image stored in `node` as a base64 string.
//
// The string is base64-decoded into a byte buffer, which is then handed to
// cv::imdecode. The result of cv::imdecode is returned as is, without any
// check for a failed decode.
cv::Mat loadImage(const TNode& node) {
    const TString& base64Text = node.AsString();

    // Allocate the upper bound first, then shrink to the real decoded length.
    std::vector<std::uint8_t> rawBytes(Base64DecodeBufSize(base64Text.length()));
    const size_t decodedLength =
        Base64Decode(rawBytes.data(), base64Text.begin(), base64Text.end());
    rawBytes.resize(decodedLength);

    // Flag value 1 asks imdecode for a 3-channel colour image.
    const int colorFlag = 1;
    return cv::imdecode(rawBytes, colorFlag);
}

// Entry point: exports cropped object images from a YT table.
//
// For every row of the input table the "objects" list is scanned; objects whose
// "type" is in the requested class list and whose bbox is at least `minsize`
// on its larger side are cut out of the row's "image" (with an extra margin),
// optionally squared and resized to `outsize`, and saved as
// "<feature_id>_<counter>.jpg" into `outfolder`. A "labels.txt" file in the same
// folder receives one "./<file name> <type>" line per saved image.
//
// Returns 0 on success and EXIT_FAILURE when an exception is caught.
int main(int argc, const char** argv) try {
    Initialize(argc, argv);

    maps::cmdline::Parser parser("Download traffic signs dataset from YT table and cut off image of objects from some classes");

    auto inputTableName = parser.string("input")
        .required()
        .help("Input YT table name");

    auto inputClassesPath = parser.string("classes")
        .required()
        .help("Input list of classes name, which get from dataset");

    auto outputSize = parser.num("outsize")
        .defaultValue(-1)
        .help("Size of output image. If value is negative, then save object without resize (default: -1)");

    auto outputFolder = parser.string("outfolder")
        .required()
        .help("Path to output file with images");

    auto margin = parser.num("margin")
        .defaultValue(20)
        .help("Additional margin around bboxes (default: 20)");

    auto minSize = parser.num("minsize")
        .defaultValue(40)
        .help("Minimal value of size (width or height) of signs (default: 40)");

    parser.parse(argc, const_cast<char**>(argv));

    const std::string classesPath = inputClassesPath;
    TSet<TString> setClasses = loadClasses(classesPath);
    // loadClasses reports an unreadable file as an empty set; without any class
    // nothing can be exported, so fail before touching YT.
    if (setClasses.empty()) {
        throw std::runtime_error("No class names loaded from '" + classesPath + "'");
    }

    INFO() << "Connecting to yt::hahn";
    IClientPtr client = CreateClient("hahn");

    std::string outLabelPath = outputFolder + "/labels.txt";

    std::ofstream ofs(outLabelPath);
    if (!ofs.is_open()) {
        throw std::runtime_error("Unable to open labels file '" + outLabelPath + "' for writing");
    }

    auto reader = client->CreateTableReader<TNode>(inputTableName.c_str());
    int processedObjects = 0;
    int processedItems = 0;
    for (; reader->IsValid(); reader->Next(), processedItems++) {
        const auto& inpRow = reader->GetRow();

        // Decoded lazily: only rows with at least one matching object pay for it.
        cv::Mat image;
        const int64_t feature_id = inpRow["feature_id"].AsInt64();
        const auto& objectList = inpRow["objects"].AsList();
        for (const auto& objectNode : objectList) {
            const TString type = objectNode["type"].AsString();
            if (!setClasses.contains(type))
                continue;
            cv::Rect rc = NodeToRect(objectNode["bbox"]);
            if (std::max(rc.width, rc.height) < minSize)
                continue;
            if (image.empty()) {
                image = loadImage(inpRow["image"]);
                if (image.empty()) {
                    // Nothing can be cropped from an undecodable image.
                    INFO() << "Unable to decode image of feature " << feature_id << ", row skipped";
                    break;
                }
            }

            // Bbox extended by the margin and clamped to the image. The right and
            // bottom bounds are inclusive pixel coordinates, hence the "+ 1".
            cv::Rect rcBig;
            rcBig.x = std::max(0, rc.x - margin);
            rcBig.y = std::max(0, rc.y - margin);
            rcBig.width  = std::min(image.cols - 1, rc.x + rc.width  - 1 + margin) - rcBig.x + 1;
            rcBig.height = std::min(image.rows - 1, rc.y + rc.height - 1 + margin) - rcBig.y + 1;
            // The bbox lies completely outside of the image.
            if (rcBig.width <= 0 || rcBig.height <= 0)
                continue;

            cv::Mat object;
            if (outputSize < 0) {
                object = image(rcBig);
            } else {
                // Grow the shorter side so the crop becomes a square, keeping it
                // centred when possible and shifting it when it hits an image edge.
                if (rcBig.width < rcBig.height) {
                    const int delta = rcBig.height - rcBig.width;
                    if (2 * rcBig.x < delta)
                        rcBig.x = 0;
                    else if (2 * (rcBig.x + rcBig.width) + delta > 2 * image.cols)
                        rcBig.x = image.cols - rcBig.height;
                    else
                        rcBig.x -= delta / 2;
                    rcBig.width = rcBig.height;
                } else {
                    const int delta = rcBig.width - rcBig.height;
                    if (2 * rcBig.y < delta)
                        rcBig.y = 0;
                    else if (2 * (rcBig.y + rcBig.height) + delta > 2 * image.rows)
                        rcBig.y = image.rows - rcBig.width;
                    else
                        rcBig.y -= delta / 2;
                    rcBig.height = rcBig.width;
                }
                if (rcBig.width < outputSize || rcBig.height < outputSize)
                    continue;
                // The square side may exceed the image dimension; such a crop
                // cannot be taken, so the object is skipped.
                if (rcBig.x < 0 || rcBig.y < 0 ||
                    rcBig.x + rcBig.width > image.cols ||
                    rcBig.y + rcBig.height > image.rows)
                    continue;

                cv::resize(image(rcBig), object, cv::Size(outputSize, outputSize));
            }
            // feature_id is 64-bit, so it must be printed with a matching specifier.
            std::string outImageName = cv::format("%lld_%06d.jpg",
                                                  static_cast<long long>(feature_id),
                                                  processedObjects);
            std::string outImagePath = outputFolder + "/" + outImageName;
            if (!cv::imwrite(outImagePath, object)) {
                throw std::runtime_error("Unable to write image '" + outImagePath + "'");
            }
            ofs << "./" << outImageName << " " << type << std::endl;
            ++processedObjects;
        }
        if ((processedItems + 1) % 1000 == 0) {
            INFO() << "Processed " << (processedItems + 1) << " images. Objects saved " << processedObjects;
        }
    }
    return 0;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
