#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/signdetect_faster_rcnn.h>
#include <maps/wikimap/mapspro/services/mrc/libs/traffic_signs/include/yandex/maps/mrc/traffic_signs/signs.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <library/cpp/string_utils/base64/base64.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

namespace {
// YT cluster the tool connects to.
const TString YT_PROXY = "hahn";

// Input table: a base64-encoded image column and a list of ground-truth objects,
// each object carrying a bounding box, a recognized number and a sign type name.
const TString COLUMN_NAME_IMAGE      = "image";
const TString COLUMN_NAME_OBJECTS    = "objects";
const TString ITEM_NAME_BBOX         = "bbox";
const TString ITEM_NAME_NUM          = "num";
const TString ITEM_NAME_TYPE         = "type";

// Intermediate table written by the mapper: one row per (ground truth, detection) pair.
const TString COLUMN_NAME_GT_SIGN    = "gt_sign";
const TString COLUMN_NAME_TST_SIGN   = "tst_sign";

// Type name used on either side of a pair when that side has no sign.
const std::string NEGATIVE_SIGN_TYPE_NAME = "negative";

// Every speed-limit sign type: the generic class plus one type per posted limit.
const std::set<maps::mrc::traffic_signs::TrafficSign> SPEED_LIMITS_TRAFFIC_SIGNS = {
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed5,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed10,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed15,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed20,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed25,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed30,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed35,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed40,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed45,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed50,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed55,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed60,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed70,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed80,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed90,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed100,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed110,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed120,
    maps::mrc::traffic_signs::TrafficSign::ProhibitoryMaxSpeed130
};

// Byte-size units for YT job memory limits.
constexpr size_t MiB = 1024 * 1024;
constexpr size_t GiB = 1024 * MiB;

/// Builds the YT operation spec for the detector job.
/// @param operationType  spec key of the user-job section (called with "mapper").
/// @param useGpu         true: schedule on the GPU pool tree with CUDA porto layers;
///                       false: request 4 CPUs and 16 GiB of memory.
NYT::TNode createOperationSpec(const TString& operationType, bool useGpu) {
    NYT::TNode spec = NYT::TNode::CreateMap();
    spec("title", "Traffic signs detector testing...");
    spec("scheduling_tag_filter", "!testing");

    if (!useGpu) {
        spec(operationType, NYT::TNode::CreateMap()
            ("cpu_limit", 4)
            ("memory_limit", 16 * GiB)
            ("memory_reserve_factor", 0.6)
        );
        return spec;
    }

    const TString poolTree = "gpu_geforce_1080ti";

    // Porto layers providing the CUDA runtime, GPU driver and base OS image.
    const NYT::TNode layerPaths = NYT::TNode::CreateList()
        .Add("//porto_layers/delta/gpu/cuda/10.1")
        .Add("//porto_layers/delta/gpu/driver/418.67")
        .Add("//porto_layers/base/bionic/porto_layer_search_ubuntu_bionic_app_lastest.tar.gz");

    spec("pool_trees", NYT::TNode::CreateList().Add(poolTree));
    spec("scheduling_options_per_pool_tree", NYT::TNode::CreateMap()
        (poolTree, NYT::TNode::CreateMap()("pool", "research_gpu"))
    );
    spec(operationType, NYT::TNode::CreateMap()
        ("memory_limit", 16 * GiB)
        ("gpu_limit", 1)
        ("layer_paths", layerPaths)
    );
    return spec;
}


/// Decodes a base64-encoded image into a 3-channel (BGR) cv::Mat.
/// EXIF orientation is ignored, so pixels are returned as stored.
/// Throws (via REQUIRE) if the bytes cannot be decoded as an image. Without
/// this check an empty Mat would be handed to the detector.
cv::Mat decodeImage(const TString& encimageStr) {
    std::vector<std::uint8_t> encimage(Base64DecodeBufSize(encimageStr.length()));
    const size_t encimageSize = Base64Decode(encimage.data(), encimageStr.begin(), encimageStr.end());
    encimage.resize(encimageSize);

    // IMREAD_* values are bit flags: combine them with `|`.
    cv::Mat image = cv::imdecode(encimage, cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
    REQUIRE(!image.empty(), "Failed to decode image of " << encimageSize << " bytes");
    return image;
}

/// Converts a bbox node of the form [[x1, y1], [x2, y2]] into a cv::Rect.
/// The two points are opposite corners, given in any order.
/// The result always has a non-negative width and height.
/// Throws (via REQUIRE) if the node does not contain two points of two coordinates.
cv::Rect NodeToRect(const NYT::TNode& node) {
    const TVector<NYT::TNode>& corners = node.AsList();
    REQUIRE(corners.size() >= 2, "Bounding box must contain two corner points");
    const TVector<NYT::TNode>& first = corners[0].AsList();
    const TVector<NYT::TNode>& second = corners[1].AsList();
    REQUIRE(first.size() >= 2 && second.size() >= 2,
            "Bounding box corner must contain two coordinates");

    const int64_t x1 = first[0].AsInt64();
    const int64_t y1 = first[1].AsInt64();
    const int64_t x2 = second[0].AsInt64();
    const int64_t y2 = second[1].AsInt64();

    // Origin is the minimal corner, and the extent is taken as an absolute value so
    // the rect is already normalized. Pairing the min corner with a signed extent
    // would make normalizeRect() shift the box a second time when x2 < x1 or y2 < y1.
    const int64_t left = std::min(x1, x2);
    const int64_t top = std::min(y1, y2);
    const int64_t right = std::max(x1, x2);
    const int64_t bottom = std::max(y1, y2);

    return cv::Rect(static_cast<int>(left),
                    static_cast<int>(top),
                    static_cast<int>(right - left),
                    static_cast<int>(bottom - top));
}

/// Builds ground-truth signs from a list of object nodes.
/// Each object provides `type`, `bbox` and `num` items. Objects whose type is not
/// in `validSigns` are skipped before their bbox and number are read.
maps::mrc::signdetect::DetectedSigns loadTrafficSigns(const NYT::TNode& node, const std::set<maps::mrc::traffic_signs::TrafficSign>& validSigns) {
    maps::mrc::signdetect::DetectedSigns result;
    for (const NYT::TNode& object : node.AsList()) {
        const maps::mrc::traffic_signs::TrafficSign type =
            maps::mrc::traffic_signs::stringToTrafficSign(object[ITEM_NAME_TYPE].AsString());
        if (validSigns.find(type) == validSigns.end()) {
            continue;
        }

        maps::mrc::signdetect::DetectedSign gtSign;
        gtSign.sign = type;
        gtSign.box = NodeToRect(object[ITEM_NAME_BBOX]);
        gtSign.number = object[ITEM_NAME_NUM].AsString();
        result.emplace_back(std::move(gtSign));
    }
    return result;
}

/// Intersection-over-union of two rectangles, in [0, 1].
/// Expects rectangles with non-negative width/height (see normalizeRect).
/// Returns 0 when the union is empty, i.e. both rects are degenerate. The naive 0/0
/// would yield NaN, which passes the `threshold > value` filter in compareBBoxes and
/// breaks the strict weak ordering required by the sort in compareSigns.
double iou(const cv::Rect& rc1, const cv::Rect& rc2) {
    const double intersection = (rc1 & rc2).area();
    // Sum in double so two large areas cannot overflow int.
    const double unionArea = static_cast<double>(rc1.area()) + rc2.area() - intersection;
    if (unionArea <= 0.) {
        return 0.;
    }
    return intersection / unionArea;
}

/// True if `sign` is one of the types listed in SPEED_LIMITS_TRAFFIC_SIGNS.
bool isSpeedLimitSign(maps::mrc::traffic_signs::TrafficSign sign) {
    return SPEED_LIMITS_TRAFFIC_SIGNS.find(sign) != SPEED_LIMITS_TRAFFIC_SIGNS.end();
}

/// Re-labels speed-limit detections using the number recognized on the sign.
///
/// - A generic ProhibitoryMaxSpeed detection is always considered for re-labeling.
/// - A specific speed-limit detection is re-labeled only when its classifier
///   `confidence` does not exceed `numberConfidence`.
/// - Detections whose number does not parse as an int, or whose value has no
///   dedicated sign type, keep their original type.
/// - Non-speed-limit detections are never touched.
void changeSpeedLimitByNumber(maps::mrc::signdetect::DetectedSigns& signs) {
    using maps::mrc::traffic_signs::TrafficSign;

    // Posted limit -> dedicated sign type.
    static const std::pair<int, TrafficSign> NUMBER_TO_SIGN[] = {
        {5,   TrafficSign::ProhibitoryMaxSpeed5},
        {10,  TrafficSign::ProhibitoryMaxSpeed10},
        {15,  TrafficSign::ProhibitoryMaxSpeed15},
        {20,  TrafficSign::ProhibitoryMaxSpeed20},
        {25,  TrafficSign::ProhibitoryMaxSpeed25},
        {30,  TrafficSign::ProhibitoryMaxSpeed30},
        {35,  TrafficSign::ProhibitoryMaxSpeed35},
        {40,  TrafficSign::ProhibitoryMaxSpeed40},
        {45,  TrafficSign::ProhibitoryMaxSpeed45},
        {50,  TrafficSign::ProhibitoryMaxSpeed50},
        {55,  TrafficSign::ProhibitoryMaxSpeed55},
        {60,  TrafficSign::ProhibitoryMaxSpeed60},
        {70,  TrafficSign::ProhibitoryMaxSpeed70},
        {80,  TrafficSign::ProhibitoryMaxSpeed80},
        {90,  TrafficSign::ProhibitoryMaxSpeed90},
        {100, TrafficSign::ProhibitoryMaxSpeed100},
        {110, TrafficSign::ProhibitoryMaxSpeed110},
        {120, TrafficSign::ProhibitoryMaxSpeed120},
        {130, TrafficSign::ProhibitoryMaxSpeed130},
    };

    for (maps::mrc::signdetect::DetectedSign& detected : signs) {
        const bool isGenericLimit = (detected.sign == TrafficSign::ProhibitoryMaxSpeed);
        const bool trustClassifier = !isGenericLimit && detected.confidence > detected.numberConfidence;
        if (!isSpeedLimitSign(detected.sign) || trustClassifier) {
            continue;
        }

        int limit = 0;
        try {
            limit = std::stoi(detected.number);
        } catch (...) {
            continue;  // no usable number: keep the classifier's answer
        }

        for (const auto& entry : NUMBER_TO_SIGN) {
            if (entry.first == limit) {
                detected.sign = entry.second;
                break;
            }
        }
    }
}

/// A candidate match between ground-truth sign `gtIdx` and detected sign `tstIdx`.
/// `compareValue` is the score of the pair (compareBBoxes stores the IoU of the two boxes).
/// Members are zero-initialized so a default-constructed item is never read uninitialized.
struct CompareItem {
    size_t gtIdx = 0;
    size_t tstIdx = 0;
    double compareValue = 0.0;
};

/// Returns `rc` with a non-negative width and height.
/// A negative extent is flipped by moving the origin to the other end, so the
/// rect spans the same coordinates as before.
cv::Rect normalizeRect(const cv::Rect& rc) {
    int left = rc.x;
    int top = rc.y;
    int width = rc.width;
    int height = rc.height;

    if (width < 0) {
        left += width;
        width = -width;
    }
    if (height < 0) {
        top += height;
        height = -height;
    }
    return cv::Rect(left, top, width, height);
}

/// Scores every (ground truth, detection) pair by the IoU of their normalized boxes.
/// A pair is dropped only when its IoU is below `compareThreshold`.
/// Items are returned grouped by ground-truth index, then by detection index.
std::vector<CompareItem> compareBBoxes(
    const maps::mrc::signdetect::DetectedSigns& gtSigns,
    const maps::mrc::signdetect::DetectedSigns& tstSigns,
    double compareThreshold)
{
    // Normalize every detection once rather than once per ground-truth box.
    std::vector<cv::Rect> tstRects;
    tstRects.reserve(tstSigns.size());
    for (const auto& tstSign : tstSigns) {
        tstRects.push_back(normalizeRect(tstSign.box));
    }

    std::vector<CompareItem> candidates;
    for (size_t gtIdx = 0; gtIdx < gtSigns.size(); ++gtIdx) {
        const cv::Rect gtRect = normalizeRect(gtSigns[gtIdx].box);
        for (size_t tstIdx = 0; tstIdx < tstRects.size(); ++tstIdx) {
            const double overlap = iou(gtRect, tstRects[tstIdx]);
            if (compareThreshold > overlap) {
                continue;  // too little overlap to be a match candidate
            }
            CompareItem candidate;
            candidate.gtIdx = gtIdx;
            candidate.tstIdx = tstIdx;
            candidate.compareValue = overlap;
            candidates.push_back(candidate);
        }
    }
    return candidates;
}

/// Outcome of matching one sign: the sign type names on the ground-truth and detector sides.
/// A side with no counterpart holds NEGATIVE_SIGN_TYPE_NAME.
struct TrafficSignPair {
    std::string gtSign;   // ground-truth sign type name
    std::string tstSign;  // detected ("test") sign type name
};

/// Greedily matches ground-truth and detected signs by box IoU and reports the outcome as type-name pairs.
///
/// Candidates from compareBBoxes are visited in order of decreasing IoU. A pair is accepted
/// only when neither of its signs has been matched already.
///
/// The result lists, in order:
///   1. accepted pairs;
///   2. unmatched detections, paired with NEGATIVE_SIGN_TYPE_NAME as ground truth;
///   3. unmatched ground-truth signs, paired with NEGATIVE_SIGN_TYPE_NAME as detection.
///
/// An unmatched ground-truth sign is left out when `minGTSignSize` > 0 and both its
/// width and height are below it, so small signs do not count against recall.
std::vector<TrafficSignPair> compareSigns(
    const maps::mrc::signdetect::DetectedSigns& gtSigns,
    const maps::mrc::signdetect::DetectedSigns& tstSigns,
    double compareThreshold,
    int minGTSignSize)
{
    if (gtSigns.empty() && tstSigns.empty()) {
        return {};
    }

    std::vector<CompareItem> candidates = compareBBoxes(gtSigns, tstSigns, compareThreshold);
    std::sort(candidates.begin(), candidates.end(),
              [](const CompareItem& lhs, const CompareItem& rhs) {
                  return lhs.compareValue > rhs.compareValue;
              });

    std::vector<bool> gtMatched(gtSigns.size(), false);
    std::vector<bool> tstMatched(tstSigns.size(), false);
    std::vector<TrafficSignPair> pairs;

    // 1. Greedy one-to-one matching, best overlap first.
    for (const CompareItem& candidate : candidates) {
        if (!gtMatched[candidate.gtIdx] && !tstMatched[candidate.tstIdx]) {
            gtMatched[candidate.gtIdx] = true;
            tstMatched[candidate.tstIdx] = true;
            TrafficSignPair matched;
            matched.gtSign = maps::mrc::traffic_signs::toString(gtSigns[candidate.gtIdx].sign);
            matched.tstSign = maps::mrc::traffic_signs::toString(tstSigns[candidate.tstIdx].sign);
            pairs.emplace_back(matched);
        }
    }

    // 2. Detections without a ground-truth counterpart.
    for (size_t idx = 0; idx < tstSigns.size(); ++idx) {
        if (tstMatched[idx]) {
            continue;
        }
        TrafficSignPair spurious;
        spurious.gtSign = NEGATIVE_SIGN_TYPE_NAME;
        spurious.tstSign = maps::mrc::traffic_signs::toString(tstSigns[idx].sign);
        pairs.push_back(spurious);
    }

    // 3. Ground-truth signs the detector missed, ignoring ones that are too small.
    for (size_t idx = 0; idx < gtSigns.size(); ++idx) {
        if (gtMatched[idx]) {
            continue;
        }
        const auto& box = gtSigns[idx].box;
        const bool tooSmall = 0 < minGTSignSize &&
                              abs(box.width) < minGTSignSize &&
                              abs(box.height) < minGTSignSize;
        if (tooSmall) {
            continue;
        }
        TrafficSignPair missed;
        missed.tstSign = NEGATIVE_SIGN_TYPE_NAME;
        missed.gtSign = maps::mrc::traffic_signs::toString(gtSigns[idx].sign);
        pairs.push_back(missed);
    }
    return pairs;
}

// YT Mappers
/// YT mapper that runs the traffic sign detector on every input image and compares the result with ground truth.
///
/// For each input row it:
///   - decodes the image;
///   - loads the ground-truth objects, keeping only types the detector supports
///     (plus all speed-limit types when useNumForSpeedLimits_ is set);
///   - runs the detector;
///   - writes one (gt_sign, tst_sign) row per matched or unmatched sign (see compareSigns).
class TDetectorMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>>  {
public:
    Y_SAVELOAD_JOB(useNumForSpeedLimits_, iouThreshold_, minGTSignSize_);

    TDetectorMapper() = default;

    TDetectorMapper(bool useNumForSpeedLimits, double iouThreshold, int minGTSignSize)
        : useNumForSpeedLimits_(useNumForSpeedLimits)
        , iouThreshold_(iouThreshold)
        , minGTSignSize_(minGTSignSize)
    { }

    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        INFO() << "Start detection ... ";

        // Bind the supported-sign list once. If supportedSigns() returns by value, calling
        // it separately for begin() and end() would pair iterators from two distinct
        // temporaries (undefined behavior). Binding to const& is correct for either
        // return kind.
        const auto& supportedSigns = tsDetector_.supportedSigns();
        std::set<maps::mrc::traffic_signs::TrafficSign> validSigns(supportedSigns.begin(), supportedSigns.end());
        if (useNumForSpeedLimits_) {
            validSigns.insert(SPEED_LIMITS_TRAFFIC_SIGNS.begin(), SPEED_LIMITS_TRAFFIC_SIGNS.end());
        }

        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode& inpRow = reader->GetRow();
            cv::Mat image = decodeImage(inpRow[COLUMN_NAME_IMAGE].AsString());

            maps::mrc::signdetect::DetectedSigns gtSigns = loadTrafficSigns(inpRow[COLUMN_NAME_OBJECTS], validSigns);
            maps::mrc::signdetect::DetectedSigns tstSigns = tsDetector_.detect(image);
            if (useNumForSpeedLimits_) {
                changeSpeedLimitByNumber(tstSigns);
            }

            const std::vector<TrafficSignPair> tsPairs = compareSigns(gtSigns, tstSigns, iouThreshold_, minGTSignSize_);
            for (const TrafficSignPair& tsPair : tsPairs) {
                writer->AddRow(NYT::TNode()
                    (COLUMN_NAME_GT_SIGN, tsPair.gtSign.c_str())
                    (COLUMN_NAME_TST_SIGN, tsPair.tstSign.c_str())
                );
            }
        }
    }
private:
    // Serialized to the job via Y_SAVELOAD_JOB. The initializers keep a
    // default-constructed mapper from holding indeterminate values.
    bool useNumForSpeedLimits_ = false;
    double iouThreshold_ = 0.0;
    int minGTSignSize_ = 0;
    maps::mrc::signdetect::FasterRCNNDetector tsDetector_;
};
REGISTER_MAPPER(TDetectorMapper);


/// Writes the detector evaluation report to `outputFilepath` as ';'-separated text.
///
/// `statistic` is an (N+1)x(N+1) int matrix, N = signTypes.size(), indexed as
/// (detected type, ground-truth type). Index N stands for "negative" (no sign on that side).
///
/// The report contains:
///   - the confusion matrix: one row per detected type plus a "Negatives" row;
///   - per-type precision and recall in percent;
///   - overall precision and recall.
/// A percentage whose denominator is zero is reported as 100, both per type and overall.
///
/// Throws (via REQUIRE) if the output file cannot be opened or written.
void dumpStatistic(
    const cv::Mat& statistic,
    const std::vector<maps::mrc::traffic_signs::TrafficSign>& signTypes,
    const std::string& outputFilepath)
{
    // Percentage with the report's convention for an empty denominator.
    const auto percent = [](int numerator, int denominator) {
        return (0 < denominator) ? 100. * numerator / denominator : 100.;
    };

    const size_t signTypesCount = signTypes.size();

    std::vector<int> gtCounts(signTypesCount, 0);
    std::vector<int> tstCounts(signTypesCount, 0);
    std::ofstream ofs(outputFilepath);
    REQUIRE(ofs.is_open(), "Failed to open output file " << outputFilepath);

    // Confusion matrix rows: detected type vs. each ground-truth type, then detections
    // that had no ground-truth counterpart (negative column).
    for (size_t tstIdx = 0; tstIdx < signTypesCount; tstIdx++) {
        ofs << maps::mrc::traffic_signs::toString(signTypes[tstIdx]) << ";";
        for (size_t gtIdx = 0; gtIdx < signTypesCount; gtIdx++) {
            const int value = statistic.at<int>(tstIdx, gtIdx);
            ofs << value << ";";
            gtCounts[gtIdx] += value;
            tstCounts[tstIdx] += value;
        }
        const int unmatchedDetections = statistic.at<int>(tstIdx, signTypesCount);
        ofs << unmatchedDetections << ";";
        tstCounts[tstIdx] += unmatchedDetections;
        ofs << '\n';
    }
    // Last row: ground-truth signs the detector missed.
    ofs << "Negatives;";
    for (size_t gtIdx = 0; gtIdx < signTypesCount; gtIdx++) {
        const int missed = statistic.at<int>(signTypesCount, gtIdx);
        ofs << missed << ";";
        gtCounts[gtIdx] += missed;
    }
    ofs << "0;" << '\n';
    ofs << '\n';

    ofs << "Type;Precision;Recall" << '\n';
    int validCounts = 0;
    for (size_t idx = 0; idx < signTypesCount; idx++) {
        const int truePositives = statistic.at<int>(idx, idx);
        ofs << maps::mrc::traffic_signs::toString(signTypes[idx]) << ";";
        ofs << percent(truePositives, tstCounts[idx]) << ";";
        ofs << percent(truePositives, gtCounts[idx]) << ";";
        ofs << '\n';
        validCounts += truePositives;
    }
    ofs << '\n';

    // Guarded like the per-type rows: with no detections / no ground truth this
    // previously printed NaN or inf.
    const int totalDetected = std::accumulate(tstCounts.begin(), tstCounts.end(), 0);
    const int totalGroundTruth = std::accumulate(gtCounts.begin(), gtCounts.end(), 0);
    ofs << "Full Precision;" << percent(validCounts, totalDetected) << '\n';
    ofs << "Full Recall;" << percent(validCounts, totalGroundTruth) << '\n';

    ofs.flush();
    REQUIRE(ofs.good(), "Failed to write statistic to " << outputFilepath);
}

/// Evaluates the detector on `inputYTTable` and writes the statistic report to `outputFilepath`.
///
/// A YT map operation (TDetectorMapper) writes one (gt_sign, tst_sign) row per sign
/// into a temporary table. The rows are read back locally and accumulated into a
/// confusion matrix.
///
/// The matrix is indexed by the detector's supported sign types, then the extra
/// speed-limit types when `useNumForSpeedLimits` is set, then a trailing "negative" index.
/// Rows whose type names are not in that index are skipped and reported in the log.
void calculateDetectorStatistic(
    NYT::IClientPtr& client,
    const TString& inputYTTable,
    const std::string& outputFilepath,
    bool useNumForSpeedLimits,
    double iouThreshold,
    int minGTSignSize,
    bool useGpu)
{
    constexpr int JOB_COUNT = 10;

    const NYT::TTempTable tempYTTable(client);
    INFO() << "Detect traffic signs...";
    client->Map(
        NYT::TMapOperationSpec()
            .AddInput<NYT::TNode>(inputYTTable)
            .AddOutput<NYT::TNode>(tempYTTable.Name())
            .JobCount(JOB_COUNT),
        new TDetectorMapper(useNumForSpeedLimits, iouThreshold, minGTSignSize),
        NYT::TOperationOptions().Spec(createOperationSpec("mapper", useGpu))
    );
    INFO() << "Traffic signs detected";

    // Build the type-name -> matrix index mapping.
    maps::mrc::signdetect::FasterRCNNDetector tsDetector;
    std::vector<maps::mrc::traffic_signs::TrafficSign> signTypes = tsDetector.supportedSigns();
    std::map<TString, int> typeNameToIndex;
    for (size_t i = 0; i < signTypes.size(); i++) {
        typeNameToIndex[maps::mrc::traffic_signs::toString(signTypes[i]).c_str()] = static_cast<int>(i);
    }
    if (useNumForSpeedLimits) {
        for (maps::mrc::traffic_signs::TrafficSign ts : SPEED_LIMITS_TRAFFIC_SIGNS) {
            const TString typeName = maps::mrc::traffic_signs::toString(ts).c_str();
            if (typeNameToIndex.end() != typeNameToIndex.find(typeName)) {
                continue;  // already supported by the detector itself
            }
            signTypes.push_back(ts);
            typeNameToIndex[typeName] = static_cast<int>(signTypes.size() - 1);
        }
    }
    const size_t signTypesCount = signTypes.size();
    // The last row/column collects pairs where one side has no sign.
    typeNameToIndex[NEGATIVE_SIGN_TYPE_NAME.c_str()] = static_cast<int>(signTypesCount);

    INFO() << "Calculate statistic...";
    const int matrixSize = static_cast<int>(signTypesCount) + 1;
    cv::Mat statistic = cv::Mat::zeros(matrixSize, matrixSize, CV_32SC1);
    NYT::TTableReaderPtr<NYT::TNode> reader = client->CreateTableReader<NYT::TNode>(tempYTTable.Name());
    int processedItems = 0;
    int skippedItems = 0;
    for (; reader->IsValid(); reader->Next(), processedItems++) {
        // Log progress before any `continue` so skipped rows do not suppress it.
        if (processedItems % 1000 == 0) {
            INFO() << "Processed " << processedItems << " items";
        }
        const NYT::TNode& inpRow = reader->GetRow();
        const TString gtSign = inpRow[COLUMN_NAME_GT_SIGN].AsString();
        const auto gtIt = typeNameToIndex.find(gtSign);
        if (gtIt == typeNameToIndex.end()) {
            skippedItems++;
            continue;
        }
        const TString tstSign = inpRow[COLUMN_NAME_TST_SIGN].AsString();
        const auto tstIt = typeNameToIndex.find(tstSign);
        if (tstIt == typeNameToIndex.end()) {
            skippedItems++;
            continue;
        }
        statistic.at<int>(tstIt->second, gtIt->second)++;
    }
    INFO() << "Processed " << processedItems << " items";
    if (0 < skippedItems) {
        WARN() << "Skipped " << skippedItems << " items with sign types not present in the statistic";
    }

    REQUIRE(0 == statistic.at<int>(signTypesCount, signTypesCount),
            "Something went wrong, because pairs negative - negative more than zero");
    dumpStatistic(statistic, signTypes, outputFilepath);
}

} //namespace

/// Command-line entry point.
/// Evaluates the traffic sign detector against a ground-truth YT table and writes the
/// statistic report to a local file. Any exception is logged and turned into EXIT_FAILURE.
int main(int argc, const char** argv) try {
    // The YT C++ API requires this to run first in main(), before option parsing.
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Calculate traffic signs detector-classificator statistic");

    maps::cmdline::Option<std::string> inputYTTable = parser.string("input")
        .required()
        .help("Path to YT table with ground truth data");

    maps::cmdline::Option<std::string> outputFilepath = parser.string("output")
        .required()
        .help("Output file path for statistic in csv format");

    maps::cmdline::Option<bool> useNumForSpeedLimits = parser.flag("use-num-speedlimits")
        .help("Use recognized number for speed limit signs instead of classifier answer");

    maps::cmdline::Option<double> iouThreshold = parser.real("iou-threshold")
        .defaultValue(0.5)
        .help("IoU threshold for calculate statistic for detector (default: 0.5)");

    maps::cmdline::Option<int> minGTSignSize = parser.num("min-gt-size")
        .defaultValue(30)
        .help("Minimal size of ground truth signs, if less we don't penalize detector recall (default: 30)");

    maps::cmdline::Option<bool> useGpu = parser.flag("use-gpu")
        .help("Use GPU for detector and recognizer tasks");

    parser.parse(argc, const_cast<char**>(argv));

    INFO() << "Connecting to yt::" << YT_PROXY;
    NYT::IClientPtr client = NYT::CreateClient(YT_PROXY);
    calculateDetectorStatistic(client, inputYTTable.c_str(), outputFilepath, useNumForSpeedLimits, iouThreshold, minGTSignSize, useGpu);

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    // Without this, a non-std exception would escape main and call std::terminate.
    FATAL() << "Worker failed: unknown exception";
    return EXIT_FAILURE;
}
