#include "pipeline.h"
#include "common.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <maps/wikimap/mapspro/libs/common/include/yandex/maps/wiki/common/extended_xml_doc.h>
#include <maps/wikimap/mapspro/libs/common/include/yandex/maps/wiki/common/pgpool3_helpers.h>
#include <maps/wikimap/mapspro/libs/revision/include/yandex/maps/wiki/revision/revisionsgateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/object/include/revision_loader.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <cfloat>
#include <cstdlib>
#include <fstream>
#include <stdexcept>

using namespace maps::mrc::house_number_pipeline;

namespace {

/// Reads the `mapCommitID` attribute of the YT table at @p inputYTPath.
///
/// The attribute is addressed with the Cypress syntax `<path>/@<attribute>`
/// and its value is returned as a 64-bit integer commit id.
maps::wiki::revision::DBID getCommitID(NYT::IClientPtr& client, const std::string& inputYTPath) {
    static const TString TABLE_ATTR_NAME_COMMIT_ID = "mapCommitID";

    TString attributePath(inputYTPath.c_str());
    attributePath += "/@";
    attributePath += TABLE_ATTR_NAME_COMMIT_ID;

    const auto commitIdNode = client->Get(attributePath);
    return commitIdNode.AsInt64();
}

/// Logs the confusion counts of @p statistic followed by the derived
/// precision (TP / (TP + FP)) and recall (TP / (TP + FN)), prefixed by @p title.
///
/// A ratio whose denominator is zero is reported as "undefined", so the
/// report always contains both a precision and a recall line.
void printStatistic(const std::string& title, const Statistic& statistic) {
    INFO() << title << " statistic: ";
    INFO() << " true positives:  " << statistic.truePositives;
    INFO() << " false positives: " << statistic.falsePositives;
    INFO() << " false negatives: " << statistic.falseNegatives;
    INFO() << "------------------";

    const auto predictedPositives = statistic.truePositives + statistic.falsePositives;
    if (0 < predictedPositives) {
        INFO() << " precision: " << static_cast<double>(statistic.truePositives) / predictedPositives;
    } else {
        INFO() << " precision: undefined";
    }

    const auto actualPositives = statistic.truePositives + statistic.falseNegatives;
    if (0 < actualPositives) {
        INFO() << " recall:    " << static_cast<double>(statistic.truePositives) / actualPositives;
    } else {
        // Mirror the precision branch instead of silently omitting the line.
        INFO() << " recall:    undefined";
    }
}

} //namespace


/// Test driver for the house number hypotheses pipeline.
///
/// Opens the ground-truth table given by `--input` on the `hahn` YT cluster,
/// reads its map commit id and then, for every mode flag passed, runs the
/// corresponding pipeline stage(s) into temporary tables under
/// YT_TEMPORARY_FOLDER and logs the resulting statistics. Several mode flags
/// may be combined; they are executed in the order they appear below.
///
/// Returns EXIT_SUCCESS on completion; any maps::Exception or std::exception
/// is logged as fatal and turned into EXIT_FAILURE.
int main(int argc, const char** argv) try {
    static const TString YT_PROXY = "hahn";

    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Test house number hypotheses generator pipeline");

    maps::cmdline::Option<std::string> inputYTPath = parser.string("input")
        .required()
        .help("Path to YT table with ground truth data");

    maps::cmdline::Option<std::string> wikiConfigPath = parser.string("wiki-config")
        .help("Path to services config for wikimap");

    maps::cmdline::Option<bool> groundTruthStatistic = parser.flag("gt-stat")
        .help("Calculate ground truth statistic");

    maps::cmdline::Option<bool> testDetector = parser.flag("detector")
        .help("Test house number detector");

    maps::cmdline::Option<bool> testDetectorRecognizer = parser.flag("detector-recognizer")
        .help("Test house number detector and recognizer simultaneously");

    maps::cmdline::Option<bool> testRecognizer = parser.flag("recognizer")
        .help("Test house number recognizer");

    maps::cmdline::Option<bool> addrPointsSearcher = parser.flag("addrpt-searcher")
        .help("Test address points searcher");

    maps::cmdline::Option<bool> fullStack = parser.flag("full-stack")
        .help("Test full pipeline from detection to address points searching");

    maps::cmdline::Option<bool> filterCars = parser.flag("filter-cars")
        .help("Filter results of detector-recognizer on the cars");

    maps::cmdline::Option<double> iouThreshold = parser.real("iou-threshold")
        .defaultValue(0.5)
        .help("IoU threshold for calculate statistic for detector and detector+recognizer (default: 0.5)");

    maps::cmdline::Option<double> minDetectorConfidence = parser.real("min-detector-confidence")
        .defaultValue(0.5)
        .help("Minimal confidence for detector (default: 0.5)");

    maps::cmdline::Option<double> minRecognizerConfidence = parser.real("min-recognizer-confidence")
        .defaultValue(0.85)
        .help("Minimal confidence for recognizer (default: 0.85)");

    maps::cmdline::Option<bool> skipSymbolsUnknownByRecognizer = parser.flag("skip-unknown-symbols")
        .help("Do not consider symbols not supported by recognizer for comparing gt with results of recognizer");

    maps::cmdline::Option<double> minAddrPointConfidence = parser.real("addr-point-confidence")
        .defaultValue(0.7)
        .help("Minimal confidence for address point searcher (default: 0.7)");

    maps::cmdline::Option<std::string> outputDetectorResultsPath = parser.string("output-detector-results")
        .help("Path to json file for save results of detector");

    maps::cmdline::Option<bool> removeObjectsAddressPointFromMap = parser.flag("remove-objects-address-point-from-map")
        .help("Remove (virtually) address points linked to signs in dataset from map for statistic calculation");

    maps::cmdline::Option<bool> calculateStatisticOnGrid = parser.flag("on-grid")
        .help("Calculate statistic on grid");

    maps::cmdline::Option<std::string> outputGridStatisticPath = parser.string("output-grid-stat")
        .help("Output path for save grid statistic");

    maps::cmdline::Option<bool> useGpu = parser.flag("use-gpu")
        .help("Use GPU for detector and recognizer tasks");

    parser.parse(argc, const_cast<char**>(argv));

    // Both address point modes build an ExtendedXmlDoc from --wiki-config.
    // Validate it up front so a missing option does not only surface after the
    // (potentially long-running) detector/recognizer stages have completed.
    if ((addrPointsSearcher || fullStack) && wikiConfigPath.empty()) {
        throw std::invalid_argument(
            "--wiki-config is required for --addrpt-searcher and --full-stack");
    }

    INFO() << "Connecting to yt::" << YT_PROXY;
    NYT::IClientPtr client = NYT::CreateClient(YT_PROXY);
    INFO() << "Opening table " << inputYTPath;
    const maps::wiki::revision::DBID commitId = getCommitID(client, inputYTPath);
    INFO() << "Commit ID: " << commitId;
    // Temporary tables below are created inside this folder, so make sure it exists.
    client->Create(YT_TEMPORARY_FOLDER, NYT::NT_MAP, NYT::TCreateOptions().Recursive(true).IgnoreExisting(true));

    if (groundTruthStatistic) {
        calculateGroundTruthStatistic(client, inputYTPath.c_str());
    }

    if (testDetector) {
        const NYT::TTempTable tmpTable(client, "", YT_TEMPORARY_FOLDER);
        const bool recognizeNumber = false;
        detectHouseNumbers(client, inputYTPath.c_str(), tmpTable.Name(), filterCars, recognizeNumber, useGpu);
        // Detector-only mode: -DBL_MAX disables the recognizer confidence cut-off.
        if (!outputDetectorResultsPath.empty()) {
            exportDetectorResults(client, tmpTable.Name(), minDetectorConfidence, /*minRecognizerConfidence*/ -DBL_MAX, outputDetectorResultsPath);
        }
        if (calculateStatisticOnGrid) {
            calculateDetectorRecognizerStatisticOnGrid(
                client,
                tmpTable.Name(),
                iouThreshold,
                recognizeNumber,
                skipSymbolsUnknownByRecognizer,
                outputGridStatisticPath);
        } else {
            Statistic statistic = calculateDetectorRecognizerStatistic(
                client,
                tmpTable.Name(),
                minDetectorConfidence,
                /*minRecognizerConfidence*/ -DBL_MAX,
                iouThreshold,
                recognizeNumber,
                false);
            printStatistic("Detector", statistic);
        }
    }

    if (testDetectorRecognizer) {
        const NYT::TTempTable tmpTable(client, "", YT_TEMPORARY_FOLDER);
        const bool recognizeNumber = true;
        detectHouseNumbers(client, inputYTPath.c_str(), tmpTable.Name(), filterCars, recognizeNumber, useGpu);
        if (!outputDetectorResultsPath.empty()) {
            exportDetectorResults(client, tmpTable.Name(), minDetectorConfidence, minRecognizerConfidence, outputDetectorResultsPath);
        }

        if (calculateStatisticOnGrid) {
            calculateDetectorRecognizerStatisticOnGrid(
                client,
                tmpTable.Name(),
                iouThreshold,
                recognizeNumber,
                skipSymbolsUnknownByRecognizer,
                outputGridStatisticPath);
        } else {
            Statistic statistic = calculateDetectorRecognizerStatistic(
                client,
                tmpTable.Name(),
                minDetectorConfidence,
                minRecognizerConfidence,
                iouThreshold,
                recognizeNumber,
                skipSymbolsUnknownByRecognizer);
            printStatistic("Detector+recognizer", statistic);
        }
    }

    if (testRecognizer) {
        const NYT::TTempTable tmpTable(client, "", YT_TEMPORARY_FOLDER);
        recognizeHouseNumbers(client, inputYTPath.c_str(), tmpTable.Name(), useGpu);
        Statistic statistic = calculateRecognizerStatistic(
            client,
            tmpTable.Name(),
            minRecognizerConfidence,
            skipSymbolsUnknownByRecognizer);
        printStatistic("Recognizer", statistic);
    }

    if (addrPointsSearcher) {
        maps::wiki::common::ExtendedXmlDoc wikiConfig(wikiConfigPath);

        // Searches address points for the ground-truth objects of the input table
        // directly (no detection step involved).
        const NYT::TTempTable addressPointsTable(client, "", YT_TEMPORARY_FOLDER);
        const int threadsCount = 10;
        searchAddressPoints(
            wikiConfig,
            commitId,
            client,
            inputYTPath.c_str(),
            addressPointsTable.Name(),
            skipSymbolsUnknownByRecognizer,
            COLUMN_NAME_OBJECTS,
            ITEM_NAME_VISIBLE_NUMBER,
            removeObjectsAddressPointFromMap,
            threadsCount);

        if (!removeObjectsAddressPointFromMap) {
            Statistic statistic = calculateAddressPointsStatistic(
                client,
                addressPointsTable.Name(),
                /*minDetectorConfidence*/ 0.0,
                /*minRecognizerConfidence*/ 0.0,
                minAddrPointConfidence,
                skipSymbolsUnknownByRecognizer);
            printStatistic("Address points searcher", statistic);
        }

        Statistic statistic = calculateHypothesesStatistic(
            client,
            addressPointsTable.Name(),
            /*minDetectorConfidence*/ 0.0,
            /*minRecognizerConfidence*/ 0.0,
            minAddrPointConfidence,
            skipSymbolsUnknownByRecognizer,
            removeObjectsAddressPointFromMap);
        printStatistic("Hypotheses", statistic);
    }

    if (fullStack) {
        maps::wiki::common::ExtendedXmlDoc wikiConfig(wikiConfigPath);

        // Detection + recognition first, then address point search on the detected numbers.
        const NYT::TTempTable detectedTable(client, "", YT_TEMPORARY_FOLDER);
        const NYT::TTempTable addressPointsTable(client, "", YT_TEMPORARY_FOLDER);
        const bool recognizeNumber = true;
        const int threadsCount = 10;
        detectHouseNumbers(client, inputYTPath.c_str(), detectedTable.Name(), filterCars, recognizeNumber, useGpu);
        searchAddressPoints(
            wikiConfig,
            commitId,
            client,
            detectedTable.Name(),
            addressPointsTable.Name(),
            skipSymbolsUnknownByRecognizer,
            COLUMN_NAME_DETECTED,
            ITEM_NAME_NUM,
            removeObjectsAddressPointFromMap,
            threadsCount);

        if (!removeObjectsAddressPointFromMap && !calculateStatisticOnGrid) {
            Statistic statistic = calculateAddressPointsStatistic(
                client,
                addressPointsTable.Name(),
                minDetectorConfidence,
                minRecognizerConfidence,
                minAddrPointConfidence,
                skipSymbolsUnknownByRecognizer);
            printStatistic("Address points searcher (full stack)", statistic);
        }

        if (calculateStatisticOnGrid) {
            calculateHypothesesStatisticOnGrid(
                client,
                addressPointsTable.Name(),
                skipSymbolsUnknownByRecognizer,
                removeObjectsAddressPointFromMap,
                outputGridStatisticPath);
        } else {
            Statistic statistic = calculateHypothesesStatistic(
                client,
                addressPointsTable.Name(),
                minDetectorConfidence,
                minRecognizerConfidence,
                minAddrPointConfidence,
                skipSymbolsUnknownByRecognizer,
                removeObjectsAddressPointFromMap);
            printStatistic("Hypotheses (full stack)", statistic);
        }
    }
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
