#include <maps/libs/json/include/value.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/traffic_signs/include/yandex/maps/mrc/traffic_signs/signs.h>
#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/sign_relations.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>

#include <cstdint>
#include <cstdlib>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>



using namespace maps::mrc;

namespace {

/// Everything loaded for a single dataset feature: the signs themselves and
/// the ground-truth relations between them.
struct FeatureData {
    /// Signs parsed from the feature's "objects" array, in file order.
    maps::mrc::signdetect::DetectedSigns objects;
    /// Ground-truth relations; each pair holds two indices into `objects`.
    std::vector<std::pair<size_t, size_t>> relationsGT;
};

/// Per-feature (or accumulated) counters for predicted relations.
struct Statistic {
    size_t truePositive{0};   ///< predictions that match some ground-truth relation
    size_t falsePositive{0};  ///< predictions that match no ground-truth relation
};

/// Returns true when two relations connect the same pair of objects,
/// ignoring order: (x, y) is considered equal to both (x, y) and (y, x).
bool isEqual(const std::pair<size_t, size_t>& a, const std::pair<size_t, size_t>& b) {
    // Named sub-conditions keep the &&/|| grouping explicit (and silence
    // -Wparentheses, which fires on unparenthesized `&&` inside `||`).
    const bool sameOrder = (a.first == b.first) && (a.second == b.second);
    const bool swappedOrder = (a.first == b.second) && (a.second == b.first);
    return sameOrder || swappedOrder;
}

/// Parses one dataset object into its id and a DetectedSign.
///
/// "bbox" is read as two corner points [[x1, y1], [x2, y2]] and converted to
/// an origin + width/height box; "type" is converted with
/// traffic_signs::stringToTrafficSign; "id" is read as a 64-bit integer.
std::pair<db::TId, maps::mrc::signdetect::DetectedSign> loadObject(const maps::json::Value& objectJson) {
    maps::mrc::signdetect::DetectedSign object;
    // Bind by const reference to avoid copying the JSON sub-value; if
    // operator[] returns a temporary, its lifetime is extended by the binding.
    const maps::json::Value& bboxJson = objectJson["bbox"];
    object.box.x = bboxJson[0][0].as<int64_t>();
    object.box.y = bboxJson[0][1].as<int64_t>();
    object.box.width = bboxJson[1][0].as<int64_t>() - object.box.x;
    object.box.height = bboxJson[1][1].as<int64_t>() - object.box.y;
    object.sign = maps::mrc::traffic_signs::stringToTrafficSign(objectJson["type"].as<std::string>());
    const db::TId id = objectJson["id"].as<int64_t>();
    return {id, std::move(object)};
}

/// Loads every object of a feature.
///
/// Returns a pair of:
///  - std::map<db::TId, size_t>: object id -> index of that object in the
///    DetectedSigns vector;
///  - the DetectedSigns vector itself, in the order objects appear in the JSON.
/// If an id occurs more than once, every object is still appended to the
/// vector, but the map keeps the index of the last occurrence.
std::pair<std::map<db::TId, size_t>, maps::mrc::signdetect::DetectedSigns> loadObjects(const maps::json::Value& objectsJson) {
    std::map<db::TId, size_t> idToIndex;
    maps::mrc::signdetect::DetectedSigns objects;
    for (const auto& item : objectsJson) {
        std::pair<db::TId, maps::mrc::signdetect::DetectedSign> data = loadObject(item);
        objects.emplace_back(std::move(data.second));
        idToIndex[data.first] = objects.size() - 1;
    }
    // A braced pair cannot benefit from NRVO, so move the locals explicitly.
    return {std::move(idToIndex), std::move(objects)};
}

/// Converts the feature's "relations" array (pairs of object ids) into pairs
/// of indices into the DetectedSigns vector, using the id -> index map built
/// by loadObjects.
///
/// Throws std::out_of_range (naming the id) if a relation references an id
/// that is not present in `idToIndex`.
std::vector<std::pair<size_t, size_t>> loadRelations(const maps::json::Value& relsJson, const std::map<db::TId, size_t>& idToIndex) {
    // Ids are read with the same integer type loadObject uses for "id", so a
    // relation id and an object id are compared as the same value.
    auto toIndex = [&idToIndex](const maps::json::Value& idJson) -> size_t {
        const db::TId id = idJson.as<int64_t>();
        const auto it = idToIndex.find(id);
        if (it == idToIndex.end()) {
            throw std::out_of_range(
                "Relation references unknown object id " + std::to_string(id));
        }
        return it->second;
    };

    std::vector<std::pair<size_t, size_t>> result;
    for (const auto& item : relsJson) {
        result.emplace_back(toIndex(item[0]), toIndex(item[1]));
    }
    return result;
}

/// Builds the FeatureData for one dataset feature from its "objects" and
/// "relations" arrays; relation ids are resolved to indices into data.objects.
FeatureData loadFeatureData(const maps::json::Value& value) {
    std::map<db::TId, size_t> idToIndex;
    FeatureData data;
    // Qualified std::tie: the unqualified call only resolved through ADL.
    std::tie(idToIndex, data.objects) = loadObjects(value["objects"]);
    data.relationsGT = loadRelations(value["relations"], idToIndex);
    return data;
}

/// Classifies every predicted relation as a true or a false positive.
///
/// A prediction is a true positive when it matches (via isEqual, i.e. in
/// either order) at least one ground-truth relation, otherwise it is a false
/// positive.
/// NOTE(review): ground-truth entries are not consumed by a match, so repeated
/// predictions of the same relation are each counted as true positives.
Statistic calculateStatistic(
    const std::vector<std::pair<size_t, size_t>>& relationsTst,
    const std::vector<std::pair<size_t, size_t>>& relationsGT)
{
    Statistic counters;
    for (const std::pair<size_t, size_t>& predicted : relationsTst) {
        bool matched = false;
        for (auto gt = relationsGT.cbegin(); !matched && gt != relationsGT.cend(); ++gt) {
            matched = isEqual(predicted, *gt);
        }

        if (matched) {
            ++counters.truePositive;
        } else {
            ++counters.falsePositive;
        }
    }
    return counters;
}

} // namespace


int main(int argc, const char** argv) try {
    maps::cmdline::Parser parser("Check traffic signs additional table to signs relations");

    maps::cmdline::Option<std::string> datasetPath = parser.string("dataset")
        .required()
        .help("Path to ground truth dataset, download from: https://proxy.sandbox.yandex-team.ru/2225243006");

    parser.parse(argc, const_cast<char**>(argv));

    maps::json::Value dataset = maps::json::Value::fromFile(datasetPath);

    size_t truePositive = 0;
    size_t falsePositive = 0;
    size_t amountGT = 0;
    for (auto item : dataset["features"]) {
        FeatureData data = loadFeatureData(item);
        std::vector<std::pair<size_t, size_t>> relationsTst = maps::mrc::signdetect::foundRelations(data.objects);
        Statistic stat = calculateStatistic(relationsTst, data.relationsGT);
        truePositive += stat.truePositive;
        falsePositive += stat.falsePositive;
        amountGT += data.relationsGT.size();
    }
    INFO() << "True positives:  " << truePositive;
    INFO() << "False positives: " << falsePositive;
    INFO() << "GT count:        " << amountGT;
    INFO() << "Precision:       " << 100.f * truePositive / (truePositive + falsePositive);
    INFO() << "Recall:          " << 100.f * truePositive / amountGT;

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
