#include <maps/wikimap/mapspro/libs/tfrecord_writer/include/tfrecord_writer.h>

#include <maps/libs/common/include/exception.h>
#include <maps/libs/common/include/base64.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>

#include <mapreduce/yt/interface/client.h>

#include <library/cpp/string_utils/base64/base64.h>

#include <util/charset/utf8.h>
#include <util/charset/wide.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <limits>
#include <list>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace NYT;
namespace tfw = maps::wiki::tfrecord_writer;

std::map<wchar16, int64_t>
extractSymbolToIndex(const std::string& symbolsParam)
{
    std::map<wchar16, int64_t> result;
    std::stringstream ss(symbolsParam);
    std::string code;
    for (int64_t idx = 0; std::getline(ss, code, ','); idx++) {
        result[(wchar16)std::stoi(code)] = idx;
    }
    return result;
}

/// One annotated sign read from the input table.
struct HouseNumber {
    cv::Rect bbox;        // object bounding box in image coordinates
    TUtf16String number;  // sign text, decoded from UTF-8
};

/// Objects of a single image, in input order.
using HouseNumberList = std::list<HouseNumber>;

/// Converts a bbox node shaped as [[x1, y1], [x2, y2]] (two opposite corners,
/// in either order) into a cv::Rect.
///
/// Corner coordinates are treated as inclusive, so width and height get +1.
/// Coordinates are narrowed from int64 to int before any arithmetic.
cv::Rect NodeToRect(const TNode& node) {
    const auto& corners = node.AsList();

    const auto& firstCorner = corners[0].AsList();
    const int ax = static_cast<int>(firstCorner[0].AsInt64());
    const int ay = static_cast<int>(firstCorner[1].AsInt64());

    const auto& secondCorner = corners[1].AsList();
    const int bx = static_cast<int>(secondCorner[0].AsInt64());
    const int by = static_cast<int>(secondCorner[1].AsInt64());

    return cv::Rect(std::min(ax, bx),
                    std::min(ay, by),
                    abs(bx - ax) + 1,
                    abs(by - ay) + 1);
}

/// Extracts annotated objects from a row's "objects" list.
///
/// Every element is indexed by the keys "type" (string), "bbox" (two corner
/// lists, see NodeToRect) and "num" (UTF-8 text). An object is dropped when:
///  - `houseNumbersOnly` is set and its type is not "house_number_sign";
///  - both its width and its height are below `minSize` (it is kept if at
///    least one dimension reaches `minSize`);
///  - its text is empty.
HouseNumberList ReadObjects(const TNode& node, int minSize, bool houseNumbersOnly) {
    static const TString HOUSE_NUMBER_OBJECT_TYPE = "house_number_sign";

    HouseNumberList objects;
    for (const TNode& objectNode : node.AsList()) {
        // Bind by reference: the type string is only compared, never stored.
        const TString& type = objectNode["type"].AsString();
        if (houseNumbersOnly && type != HOUSE_NUMBER_OBJECT_TYPE) {
            continue;
        }
        HouseNumber hn;
        hn.bbox = NodeToRect(objectNode["bbox"]);
        if (hn.bbox.width < minSize && hn.bbox.height < minSize) {
            continue;
        }
        hn.number = UTF8ToWide(objectNode["num"].AsString());
        if (hn.number.empty()) {
            continue;
        }
        // `hn` is not used afterwards, so move it instead of copying its text.
        objects.push_back(std::move(hn));
    }
    return objects;
}

/// Decodes a base64-encoded image stored in a string node into a 3-channel
/// cv::Mat (cv::IMREAD_COLOR).
///
/// Returns an empty cv::Mat if the decoded bytes are not a readable image;
/// that is how cv::imdecode reports failure, so callers must check `empty()`.
cv::Mat ReadImage(const TNode& node) {
    // Bind by const reference: a non-const TString copy followed by the
    // non-const begin() would detach (deep-copy) the whole base64 payload.
    const TString& encoded = node.AsString();
    std::vector<std::uint8_t> bytes(Base64DecodeBufSize(encoded.length()));
    const size_t decodedSize = Base64Decode(bytes.data(), encoded.begin(), encoded.end());
    bytes.resize(decodedSize);
    return cv::imdecode(bytes, cv::IMREAD_COLOR);
}

/// Grows `rc` by `expandBorder` times its width (height) on the left and right
/// (top and bottom) sides, then clamps the top-left corner to >= 0 and the
/// bottom-right corner to <= `imageSize`. Fractional coordinates are truncated
/// toward zero by the int casts.
///
/// The two corners are clamped independently. If `rc` lies outside the image,
/// the result is therefore not guaranteed to be inside the image, and callers
/// that crop with it should intersect it with the image bounds.
cv::Rect ExpandRect(const cv::Rect &rc, double expandBorder, const cv::Size &imageSize) {
    const double left = rc.x - expandBorder * rc.width;
    const double top = rc.y - expandBorder * rc.height;
    const double right = rc.x + (1. + expandBorder) * rc.width;
    const double bottom = rc.y + (1. + expandBorder) * rc.height;

    const cv::Point topLeft(std::max(static_cast<int>(left), 0),
                            std::max(static_cast<int>(top), 0));
    const cv::Point bottomRight(std::min(static_cast<int>(right), imageSize.width),
                                std::min(static_cast<int>(bottom), imageSize.height));
    return cv::Rect(topLeft, bottomRight);
}

std::list<int64_t> SymbolsToIndices(const TUtf16String& str, const std::map<wchar16, int64_t>& symbolToIndex) {
    std::list<int64_t> result;
    for (size_t chIdx = 0; chIdx < str.size(); chIdx++) {
        const auto cit = symbolToIndex.find(str[chIdx]);
        if (cit != symbolToIndex.cend())
            result.push_back(cit->second);
    }
    return result;
}

/// Converts every row of `reader` into cropped single-object tfrecord entries,
/// randomly split between `trainWriter` and `testWriter`.
///
/// Objects are read from the row's "objects" column (see ReadObjects) and the
/// image from its base64 "image" column (see ReadImage). Each object yields one
/// record: the image cropped to the object's bbox expanded by `expandBorder`
/// (see ExpandRect) and clipped to the image, labelled with the class indices
/// of its text symbols. The following are skipped rather than aborting the run:
///  - objects whose text contains no symbol from `symbolToIndex`;
///  - objects whose bbox does not intersect the image;
///  - rows whose image cannot be decoded.
///
/// A record goes to `trainWriter` with probability `trainPart`. The split
/// sequence is deterministic for a given `seed` and standard library.
/// NOTE(review): the split is per object, so crops of one image may land in
/// both train and test; confirm this leakage is acceptable.
///
/// @throws maps::Exception if `symbolToIndex` is empty.
void MakeTFRecord(TTableReaderPtr<TNode> &reader,
                  const std::map<wchar16, int64_t>& symbolToIndex,
                  tfw::TFRecordWriter<tfw::MultiLabelsObject> &trainWriter,
                  tfw::TFRecordWriter<tfw::MultiLabelsObject> &testWriter,
                  int minSize, double expandBorder, bool houseNumbersOnly,
                  double trainPart, int seed) {
    REQUIRE(!symbolToIndex.empty(), "Symbols set is empty");

    INFO() << "Make result tfrecord";
    INFO() << "Object min size: " << minSize;

    std::default_random_engine rndGen(seed);
    std::uniform_real_distribution<double> rndUniformDistr(0.0, 1.0);

    size_t trainItems = 0;
    size_t testItems = 0;
    size_t undecodedImages = 0;
    for (; reader->IsValid(); reader->Next()) {
        const TNode& inpRow = reader->GetRow();

        const HouseNumberList objects = ReadObjects(inpRow["objects"], minSize, houseNumbersOnly);
        if (objects.empty()) {
            continue;
        }
        const cv::Mat image = ReadImage(inpRow["image"]);
        if (image.empty()) {
            // cv::imdecode signals failure with an empty matrix; cropping it
            // would throw and abort the whole conversion.
            WARN() << "Failed to decode image, skipping row with "
                   << objects.size() << " objects";
            ++undecodedImages;
            continue;
        }
        const cv::Rect imageRect(cv::Point(0, 0), image.size());

        for (const HouseNumber& hn : objects) {
            tfw::MultiLabelsObject object;
            object.labels = SymbolsToIndices(hn.number, symbolToIndex);
            if (object.labels.empty()) {
                continue;
            }
            // ExpandRect clamps its corners independently, so a bbox outside
            // the image can give a crop outside it (image(rect) would throw).
            // Intersect explicitly, and skip objects that end up with nothing
            // visible instead of emitting a zero-size bbox.
            const cv::Rect rect = ExpandRect(hn.bbox, expandBorder, image.size()) & imageRect;
            const cv::Rect objectInCrop = hn.bbox & rect;
            if (objectInCrop.width <= 0 || objectInCrop.height <= 0) {
                continue;
            }
            object.text  = WideToUTF8(hn.number).c_str();
            object.bbox  = objectInCrop - rect.tl();
            if (rndUniformDistr(rndGen) < trainPart) {
                trainWriter.AddRecord(image(rect), {object});
                trainItems++;
            }
            else {
                testWriter.AddRecord(image(rect), {object});
                testItems++;
            }

            // Checked per written item so the log fires exactly every 1000 items.
            const size_t processed = trainItems + testItems;
            if (processed % 1000 == 0) {
                INFO() << "Processed " << processed << " items";
            }
        }
    }

    INFO() << "Done: " << trainItems << " train items, " << testItems
           << " test items, " << undecodedImages
           << " rows skipped because the image could not be decoded";
}

int main(int argc, const char** argv) try {
    Initialize(argc, argv);

    maps::cmdline::Parser parser("Download house number signs dataset from YT table to tfrecord");

    maps::cmdline::Option<std::string> intputTableName = parser.string("input")
        .required()
        .help("Input YT table name");

    maps::cmdline::Option<std::string> symbolsCodeParam = parser.string("symbols")
        .required()
        .help("Symbols unicodes divided by commas");

    maps::cmdline::Option<std::string> outputTrainPath = parser.string("outtrain")
        .required()
        .help("Path to output file for train data");

    maps::cmdline::Option<std::string> outputTestPath = parser.string("outtest")
        .required()
        .help("Path to output file for test data");

    maps::cmdline::Option<int> minSize = parser.num("minsize")
        .defaultValue(0)
        .help("Minimal value of size (width or height) of signs (default: 0)");

    maps::cmdline::Option<double> expandBorder = parser.real("expborder")
        .defaultValue(.5)
        .help("Add border around object in part of width and height (default: 0.5)");

    maps::cmdline::Option<bool> houseNumbersOnly = parser.flag("hns-only")
        .help("convert house numbers objects only");

    maps::cmdline::Option<double> trainPart = parser.real("trainpart")
        .defaultValue(.9)
        .help("Part of dataset for building train.tfrecord (from 0. to 1.0) other part will use for test.tfrecord  (default: 0.9)");

    maps::cmdline::Option<int> randomSeed = parser.num("seed")
        .defaultValue(42)
        .help("Seed of random generator to split train and test records (default: 42)");

    parser.parse(argc, const_cast<char**>(argv));

    INFO() << "Connecting to yt::hahn";
    IClientPtr client = CreateClient("hahn");

    TTableReaderPtr<TNode> reader = client->CreateTableReader<TNode>(intputTableName.c_str());

    TFileOutput trainFile(outputTrainPath.c_str());
    tfw::TFRecordWriter<tfw::MultiLabelsObject> tfrecordWriterTrain(&trainFile);
    TFileOutput testFile(outputTestPath.c_str());
    tfw::TFRecordWriter<tfw::MultiLabelsObject> tfrecordWriterTest(&testFile);

    MakeTFRecord(reader,
                 extractSymbolToIndex(symbolsCodeParam),
                 tfrecordWriterTrain,
                 tfrecordWriterTest,
                 minSize,
                 expandBorder,
                 houseNumbersOnly,
                 trainPart,
                 randomSeed);

    return 0;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
