#include <maps/wikimap/mapspro/services/mrc/libs/house_number_sign_detector/include/house_number_sign_detector.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <mapreduce/yt/interface/client.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace NYT;
using namespace maps::mrc::house_number_sign_detector;

namespace {
    /// Builds a cv::Rect from a bbox node of the form [[x1, y1], [x2, y2]]
    /// holding two opposite corners in either order. Coordinates are narrowed
    /// from int64 to int, and both corners are counted as inside the box
    /// (hence the +1 on width and height).
    cv::Rect NodeToRect(const TNode& node) {
        const auto& corners = node.AsList();

        const auto& first = corners[0].AsList();
        const int ax = static_cast<int>(first[0].AsInt64());
        const int ay = static_cast<int>(first[1].AsInt64());

        const auto& second = corners[1].AsList();
        const int bx = static_cast<int>(second[0].AsInt64());
        const int by = static_cast<int>(second[1].AsInt64());

        return cv::Rect(std::min(ax, bx), std::min(ay, by),
                        abs(bx - ax) + 1, abs(by - ay) + 1);
    }

    void NodeToHNSs(const TNode &node, HouseNumberSigns &hnss) {
        static const TString OBJECT_TYPE_NAME = "house_number_sign";

        const TVector<TNode>& objectsNodes = node.AsList();
        for (size_t i = 0; i < objectsNodes.size(); i++) {
            const TNode &objectNode = objectsNodes[i];

            if (objectNode["type"].AsString() != OBJECT_TYPE_NAME)
                continue;

            HouseNumberSign hns;
            hns.box = NodeToRect(objectNode["bbox"]);
            hns.confidence = 1.0f;
            hns.number = objectNode["num"].AsString();
            hnss.emplace_back(hns);
        }
    }

    /// Intersection over Minimum of two rectangles:
    /// area(rc1 & rc2) / min(area(rc1), area(rc2)).
    /// Returns 0 if either rectangle has zero area. The result is 1.0 whenever
    /// one rectangle lies entirely inside the other.
    double IoM(const cv::Rect &rc1, const cv::Rect &rc2) {
        const int area1 = rc1.area();
        const int area2 = rc2.area();
        if (area1 == 0 || area2 == 0)
            return 0.;
        const int overlap = (rc1 & rc2).area();
        return static_cast<double>(overlap) / static_cast<double>(std::min(area1, area2));
    }

    /// Candidate association between ground-truth sign #gtIdx and tested
    /// sign #tstIdx, scored by the IoM of their bounding boxes.
    struct MatchPair {
        size_t gtIdx;
        size_t tstIdx;
        double iom;
    };

    /// Strict weak ordering for MatchPair: descending by iom, so the greedy
    /// matcher visits the best-overlapping candidates first.
    ///
    /// Ties on iom are broken by (gtIdx, tstIdx). Without the tie-break the
    /// std::set below treats every pair with an equal iom as the *same* key
    /// and silently keeps only the first one. IoM == 1.0 is very common
    /// (any nested boxes), so matches were being under-counted.
    struct MatchPairComparer {
        bool operator() (const MatchPair &lhs, const MatchPair &rhs) const {
            if (lhs.iom != rhs.iom)
                return rhs.iom < lhs.iom;
            if (lhs.gtIdx != rhs.gtIdx)
                return lhs.gtIdx < rhs.gtIdx;
            return lhs.tstIdx < rhs.tstIdx;
        }
    };

    /// Candidate pairs ordered from best to worst overlap; each (gt, tst)
    /// index pair appears at most once.
    using MatchPairsSet = std::set<MatchPair, MatchPairComparer>;

    /// Running totals that Compare() accumulates over all processed images.
    struct Statistic {
        int gtCount = 0;          // ground-truth signs seen
        int tstCount = 0;         // tested (detected) signs seen
        int pairWithSameNum = 0;  // matched pairs whose numbers agree
        int pairWithDiffNum = 0;  // matched pairs whose numbers differ
    };

    /// Concatenates the decimal digits of @p str, in order, and returns them as
    /// one non-negative integer. Every non-digit character is ignored, e.g.
    /// "12/3" -> 123, "012" -> 12, "" -> 0, "abc" -> 0.
    ///
    /// Used when numbers are compared "as numbers" (--cmp_as_num), so letters,
    /// separators and leading zeros do not affect the comparison.
    /// Values that do not fit in an int saturate at INT_MAX rather than
    /// overflowing (signed overflow is undefined behaviour).
    int ToNumber(const std::string &str) {
        constexpr int kMax = std::numeric_limits<int>::max();
        int result = 0;
        for (const char ch : str) {
            // std::isdigit requires a value representable as unsigned char;
            // non-ASCII (e.g. UTF-8 Cyrillic) bytes are negative when char is
            // signed, which would be undefined behaviour.
            if (!std::isdigit(static_cast<unsigned char>(ch)))
                continue;
            const int digit = ch - '0';
            if (result > (kMax - digit) / 10)
                return kMax;
            result = result * 10 + digit;
        }
        return result;
    }

    void Compare(const HouseNumberSigns &gt, const HouseNumberSigns &tst, double iomThreshold, Statistic &stat, bool compareAsNumber) {
        MatchPairsSet pairs;
        MatchPair pair;
        for (pair.gtIdx = 0; pair.gtIdx < gt.size(); pair.gtIdx++) {
            REQUIRE(0 < gt[pair.gtIdx].box.area(), "Bounding box of GT signs has zero area");
            for (pair.tstIdx = 0; pair.tstIdx < tst.size(); pair.tstIdx++) {
                pair.iom = IoM(gt[pair.gtIdx].box, tst[pair.tstIdx].box);
                if (pair.iom < iomThreshold)
                    continue;
                pairs.insert(pair);
            }
        }

        std::vector<bool> gtFound(gt.size(), false);
        std::vector<bool> tstFound(tst.size(), false);
        for (MatchPairsSet::const_iterator it = pairs.cbegin(); it != pairs.cend(); it++) {
            if (gtFound[it->gtIdx] || tstFound[it->tstIdx])
                continue;
            gtFound[it->gtIdx] = true;
            tstFound[it->tstIdx] = true;
            bool sameNumber = compareAsNumber ? (ToNumber(gt[it->gtIdx].number) == ToNumber(tst[it->tstIdx].number))
                                              : (gt[it->gtIdx].number == tst[it->tstIdx].number);
            if (sameNumber)
                stat.pairWithSameNum++;
            else
                stat.pairWithDiffNum++;
        }
        stat.gtCount += gt.size();
        stat.tstCount += tst.size();
    }

    void LoadTestTable(const TTableReaderPtr<TNode> &reader, std::map<size_t, HouseNumberSigns> &mapHNSs) {
        const std::hash<std::string> strHash;
        for (int processedItems = 0; reader->IsValid(); reader->Next(), processedItems++) {
            HouseNumberSigns hnss;
            const TNode& inpRow = reader->GetRow();
            NodeToHNSs(inpRow["objects"], hnss);
            TString encimageStr = inpRow["image"].AsString();
            mapHNSs[strHash(encimageStr.c_str())] = hnss;
        }
    }

    /// Base64-decodes @p encimageBase64Str and decodes the resulting bytes with
    /// cv::imdecode(..., cv::IMREAD_COLOR). If the bytes are not a decodable
    /// image, imdecode yields an empty Mat; this function does not check for
    /// that.
    cv::Mat DecodeImage(const TString &encimageBase64Str) {
        const size_t maxDecodedSize = Base64DecodeBufSize(encimageBase64Str.length());
        std::vector<std::uint8_t> bytes(maxDecodedSize);
        const size_t decodedSize =
            Base64Decode(bytes.data(), encimageBase64Str.begin(), encimageBase64Str.end());
        bytes.resize(decodedSize);
        return cv::imdecode(bytes, cv::IMREAD_COLOR);
    }

    void dump(const HouseNumberSigns &hnss) {
        INFO() << "Count: " << hnss.size();
        for (size_t i = 0; i < hnss.size(); i++) {
            INFO() << "  number: " << hnss[i].number
                   << " bbox: " << hnss[i].box;
        }
    }

    /// Draws every sign of @p hnss onto @p image in @p color. The bounding box
    /// is drawn as a 2-px rectangle and, when the sign has a number, that
    /// number is drawn as text. With @p topText the text sits above the box,
    /// otherwise below it. The text is shifted horizontally so that it stays
    /// inside the image width.
    void drawHouseNumberSigns(cv::Mat &image, const HouseNumberSigns &hnss, const cv::Scalar &color, bool topText) {
        constexpr int kFontFace = cv::FONT_HERSHEY_SIMPLEX;
        constexpr double kFontScale = 0.5;
        constexpr int kThickness = 2;
        constexpr int kMargin = 2;

        // Bind by reference: the original copied each sign (including its string).
        for (const HouseNumberSign& sign : hnss) {
            cv::rectangle(image, sign.box, color, kThickness);
            if (sign.number.empty())
                continue;
            const cv::Size textSize = cv::getTextSize(sign.number, kFontFace, kFontScale, kThickness, nullptr);
            // putText's origin is the bottom-left corner of the text.
            cv::Point textPoint(sign.box.tl().x,
                                topText ? sign.box.tl().y : sign.box.br().y + textSize.height);
            if (textPoint.x + textSize.width + kMargin >= image.cols)
                textPoint.x = image.cols - kMargin - textSize.width;
            // A label wider than the image would otherwise start at a negative x.
            textPoint.x = std::max(textPoint.x, 0);
            cv::putText(image, sign.number, textPoint, kFontFace, kFontScale, color, kThickness);
        }
    }
} //namespace

/// Evaluates the house number sign detector against ground truth.
///
/// Ground-truth rows (columns "objects" and "image") are read from the YT
/// table --input on the "hahn" cluster. For every row, the tested signs come
/// either from FasterRCNNDetector run on the decoded image (default) or, when
/// --input_test is given, from that table, looked up by a hash of the same
/// "image" value. GT and tested signs are matched with Compare(). Aggregate
/// counts and precision/recall (in percent) are logged at the end.
/// Returns EXIT_SUCCESS; on an exception it logs via FATAL() and returns
/// EXIT_FAILURE.
int main(int argc, const char** argv) try {
    Initialize(argc, argv);

    maps::cmdline::Parser parser("Calculate statistic of house number sign detector");

    maps::cmdline::Option<std::string> inputTable = parser.string("input")
        .required()
        .help("Path to YT table with ground truth data");

    maps::cmdline::Option<double> iomThreshold = parser.real("iom_thr")
        .defaultValue(0.5)
        .help("Threshold for bbox iom (default: 0.5)");

    maps::cmdline::Option<std::string> inputTestTable = parser.string("input_test")
        .defaultValue("")
        .help("Path to YT table with test data. If empty use detector from current version of library (default: '')");

    maps::cmdline::Option<bool> skipBlankGT = parser.flag("skip_blank_gt")
        .help("Skip ground truth signs with undefined number");

    maps::cmdline::Option<int> minSize = parser.num("min_sz")
        .defaultValue(0)
        .help("If width and height of GT items bbox smaller than this value, then remove item from GT set (default: 0)");

    maps::cmdline::Option<bool> compareAsNumber = parser.flag("cmp_as_num")
        .help("Compare numbers from GT and Test as numbers not as string");

    maps::cmdline::Option<bool> woRecognizeNum = parser.flag("wo_recognize_num")
        .help("Launch detector only, number recognizer disabled");

    maps::cmdline::Option<bool> dumpTrace = parser.flag("dump_trace")
        .help("Dump ground truth and detected objects");

    // Directory for per-row "<row index>.jpg" visualizations; empty disables them.
    maps::cmdline::Option<std::string> dumpImageResults = parser.string("dump_img_result_path")
        .defaultValue("")
        .help("Dump images with detected objects drawed on");

    parser.parse(argc, const_cast<char**>(argv));
    IClientPtr client = CreateClient("hahn");
    TTableReaderPtr<TNode> reader = client->CreateTableReader<TNode>(inputTable.c_str());

    // Pre-computed test results, keyed by std::hash of each row's "image" value.
    std::map<size_t, HouseNumberSigns> mapTestHNSs;
    if (!inputTestTable.empty()) {
        INFO() << "Load test data from YT table: " << inputTestTable;
        TTableReaderPtr<TNode> readerTest = client->CreateTableReader<TNode>(inputTestTable.c_str());
        LoadTestTable(readerTest, mapTestHNSs);
    }

    // NOTE(review): constructed unconditionally, even when --input_test is
    // given and the detector is never used; consider constructing it lazily.
    FasterRCNNDetector detector;
    Statistic statistic;
    // Must hash images exactly as LoadTestTable() does, or lookups will miss.
    const std::hash<std::string> strHash;
    for (int processedItems = 0; reader->IsValid(); reader->Next(), processedItems++) {
        HouseNumberSigns hnssGT;
        const TNode& inpRow = reader->GetRow();
        NodeToHNSs(inpRow["objects"], hnssGT);

        // Drop GT signs without a number. Iterate backwards so that erase()
        // only shifts elements that have already been visited.
        if (skipBlankGT) {
            for (int i = (int)hnssGT.size() - 1; 0 <= i; i--) {
                if (hnssGT[i].number.empty())
                    hnssGT.erase(hnssGT.begin() + i);
            }
        }

        // Drop GT signs whose box is smaller than min_sz in BOTH dimensions.
        if (0 < minSize) {
            for (int i = (int)hnssGT.size() - 1; 0 <= i; i--) {
                if (hnssGT[i].box.width < minSize &&
                    hnssGT[i].box.height < minSize)
                    hnssGT.erase(hnssGT.begin() + i);
            }
        }

        TString encimageStr = inpRow["image"].AsString();
        HouseNumberSigns hnssTest;
        // Decoded only when needed: for running the detector or for dumping images.
        cv::Mat image;
        if (inputTestTable.empty()) {
            // NOTE(review): an undecodable image gives an empty Mat that is
            // passed to the detector without a check.
            image = DecodeImage(encimageStr);
            hnssTest = detector.detect(image, woRecognizeNum ? RecognizeNumber::No : RecognizeNumber::Yes);
        } else {
            // Rows missing from the test table are evaluated with no tested signs.
            size_t hash = strHash(encimageStr.c_str());
            std::map<size_t, HouseNumberSigns>::const_iterator cit = mapTestHNSs.find(hash);
            if (cit != mapTestHNSs.end())
                hnssTest = cit->second;
            else
                INFO() << "Unable to found test data with hash: " << hash;
        }

        // GT is drawn with Scalar(0, 0, 255) and numbers above the boxes; tested
        // signs with Scalar(255, 0, 0) and numbers below.
        if (!dumpImageResults.empty()) {
            if (image.empty())
                image = DecodeImage(encimageStr);
            drawHouseNumberSigns(image, hnssGT, cv::Scalar(0, 0, 255), true);
            drawHouseNumberSigns(image, hnssTest, cv::Scalar(255, 0, 0), false);
            cv::imwrite(cv::format("%s/%d.jpg", dumpImageResults.c_str(), processedItems), image);
        }

        if (dumpTrace) {
            INFO() << "Ground truth:";
            dump(hnssGT);
            INFO() << "Detected truth:";
            dump(hnssTest);

            // Snapshot the totals so this image's match counts can be logged as a delta.
            Statistic old = statistic;
            Compare(hnssGT, hnssTest, iomThreshold, statistic, compareAsNumber);
            INFO() << "  with same numbers:      " << statistic.pairWithSameNum - old.pairWithSameNum;
            INFO() << "  with different numbers: " << statistic.pairWithDiffNum - old.pairWithDiffNum;
            INFO() << "_______________";
        } else {
            Compare(hnssGT, hnssTest, iomThreshold, statistic, compareAsNumber);
        }
    }
    INFO() << "Ground Truth objects:         " << statistic.gtCount;
    INFO() << "Detected objects:             " << statistic.tstCount;
    if (!woRecognizeNum) {
        INFO() << "Match with same numbers:      " << statistic.pairWithSameNum;
        INFO() << "Match with different numbers: " << statistic.pairWithDiffNum;
    } else {
        INFO() << "Match objects:                " << statistic.pairWithSameNum + statistic.pairWithDiffNum;
    }
    // "Detector" metrics count every matched pair regardless of its numbers:
    // precision = matches / tested signs, recall = matches / GT signs (percent).
    INFO() << "Detector: ";
    if (0 < statistic.tstCount) {
        INFO() << "    precision: "
               << (float)(statistic.pairWithDiffNum + statistic.pairWithSameNum) / (float) statistic.tstCount * 100 << "%";
    }
    if (0 < statistic.gtCount) {
        INFO() << "    recall:    "
               << (float)(statistic.pairWithDiffNum + statistic.pairWithSameNum) / (float) statistic.gtCount * 100 << "%";
    }
    // "Detector + recognizer" metrics count only matches whose numbers agree.
    if (!woRecognizeNum) {
        INFO() << "Detector + recognizer: ";
        if (0 < statistic.tstCount) {
            INFO() << "    precision: "
                   << (float)statistic.pairWithSameNum / (float) statistic.tstCount * 100 << "%";
        }
        if (0 < statistic.gtCount) {
            INFO() << "    recall:    "
                   << (float)statistic.pairWithSameNum / (float) statistic.gtCount * 100 << "%";
        }
    }
    return EXIT_SUCCESS;
}
// Any escaping exception is logged with FATAL() and turned into EXIT_FAILURE.
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
