#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sideview_classifier/include/sideview.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/http/include/http.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/pgpool/include/pgpool3.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>


using namespace maps::mrc;

namespace {

// Column names of the YT tables written by uploadPairs and TClassifierMapper.
// (Entities in an anonymous namespace already have internal linkage, so no `static`.)
const TString COLUMN_NAME_URL1         = "url1";
const TString COLUMN_NAME_URL2         = "url2";
const TString COLUMN_NAME_ORIENTATION1 = "orientation1";
const TString COLUMN_NAME_ORIENTATION2 = "orientation2";
const TString COLUMN_NAME_SIDEVIEW_GT  = "sideview_gt";
const TString COLUMN_NAME_SIDEVIEW     = "sideview";
const TString COLUMN_NAME_CONFIDENCE   = "confidence";

// Binary size units used for job memory limits.
constexpr size_t MiB = size_t{1} << 20;
constexpr size_t GiB = MiB << 10;

// Builds the YT operation spec used for the classifier map operation.
//
// operationType - name of the spec section that receives the job resources
//                 (main() passes "mapper").
// useGpu        - when true, jobs are scheduled on the GPU pool tree with one GPU
//                 and the CUDA/driver porto layers; otherwise CPU/memory limits
//                 are requested.
// Returns a map node meant for NYT::TOperationOptions().Spec(...).
NYT::TNode createOperationSpec(const TString& operationType, bool useGpu) {
    // The title is what shows up in the YT UI; it must describe this tool.
    NYT::TNode operationSpec = NYT::TNode::CreateMap()
        ("title", "Sideview classifier (test pipeline)")
        ("scheduling_tag_filter", "!testing");

    const TString GPU_POOL_TREES = "gpu_geforce_1080ti";

    if (useGpu) {
        operationSpec
        ("pool_trees", NYT::TNode::CreateList().Add(GPU_POOL_TREES))
        ("scheduling_options_per_pool_tree", NYT::TNode::CreateMap()
            (GPU_POOL_TREES, NYT::TNode::CreateMap()("pool", "research_gpu"))
        )
        (operationType, NYT::TNode::CreateMap()
            ("memory_limit", 16 * GiB)
            ("gpu_limit", 1)
            ("layer_paths",  NYT::TNode::CreateList()
                            .Add("//porto_layers/delta/gpu/cuda/10.1")
                            .Add("//porto_layers/delta/gpu/driver/418.67")
                            .Add("//porto_layers/base/bionic/porto_layer_search_ubuntu_bionic_app_lastest.tar.gz")
            )
        );
    } else {
        operationSpec
        (operationType, NYT::TNode::CreateMap()
            ("cpu_limit", 4)
            ("memory_limit", 16 * GiB)
            ("memory_reserve_factor", 0.6)
        );
    }
    return operationSpec;
}

// Reads url pairs from a local text file and writes them as rows of a YT table.
//
// Every line that is not blank must contain five whitespace-separated fields:
//     <url1> <orientation1> <url2> <orientation2> <sideview_gt>
// Orientations are later interpreted as EXIF orientation codes (see loadImage).
// loadStatistic treats sideview_gt == 0 as forward view and anything else as
// side view.
//
// Returns the number of rows written. Throws (via REQUIRE) if the file cannot be
// opened or a line cannot be parsed, so malformed input is never uploaded.
int uploadPairs(const std::string& inputPath, NYT::IClientPtr& client, const TString& tableName)
{
    std::ifstream ifs(inputPath);
    REQUIRE(ifs.is_open(), "Unable to open file: " << inputPath);

    NYT::TTableWriterPtr<NYT::TNode> writer = client->CreateTableWriter<NYT::TNode>(tableName);

    int uploaded = 0;
    size_t lineNumber = 0;
    std::string line;
    // Loop on getline itself rather than eof(): eof() only becomes true after a read fails.
    while (std::getline(ifs, line)) {
        ++lineNumber;
        // Skip blank lines, including ones that only hold a CR from CRLF files.
        if (line.find_first_not_of(" \t\r") == std::string::npos) {
            continue;
        }

        std::istringstream ss(line);
        std::string url1;
        std::string url2;
        int orientation1 = 0;
        int orientation2 = 0;
        int sideviewGT = 0;
        const bool parsed = static_cast<bool>(
            ss >> url1 >> orientation1 >> url2 >> orientation2 >> sideviewGT);
        REQUIRE(parsed,
            "Malformed line " << lineNumber << " in " << inputPath << ": " << line);

        writer->AddRow(
            NYT::TNode()
                (COLUMN_NAME_URL1,         NYT::TNode(url1))
                (COLUMN_NAME_ORIENTATION1, NYT::TNode(orientation1))
                (COLUMN_NAME_URL2,         NYT::TNode(url2))
                (COLUMN_NAME_ORIENTATION2, NYT::TNode(orientation2))
                (COLUMN_NAME_SIDEVIEW_GT,  NYT::TNode(sideviewGT))
        );
        ++uploaded;
    }
    // Finish explicitly so write errors surface here instead of being lost in the destructor.
    writer->Finish();
    return uploaded;
}

// Downloads the body of `url` with an HTTP GET through `client`.
//
// The request is retried (up to 10 tries, initial cooldown of one second,
// cooldown back-off factor 2) while the attempt yields no valid response or
// the response is classified as ServerError. Throws (via REQUIRE) unless the
// final response is classified as Success.
std::vector<uint8_t> downloadImage(maps::http::Client& client, const std::string& url)
{
    auto performGet = [&client, &url]() {
        return maps::http::Request(client, maps::http::GET, maps::http::URL(url)).perform();
    };

    // Decides whether an attempt is final (true) or must be retried (false).
    auto isFinalAttempt = [](const auto& attempt) {
        if (!attempt.valid()) {
            return false;
        }
        return attempt.get().responseClass() != maps::http::ResponseClass::ServerError;
    };

    maps::common::RetryPolicy policy;
    policy.setTryNumber(10)
        .setInitialCooldown(std::chrono::seconds(1))
        .setCooldownBackoff(2);

    auto response = maps::common::retry(performGet, policy, isFinalAttempt);

    REQUIRE(response.responseClass() == maps::http::ResponseClass::Success,
        "Unexpected response status " << response.status() << " for url "
        << url);
    return response.readBodyToVector();
}

// Downloads an image and decodes it as a colour cv::Mat, then applies `orientation`
// (interpreted as an EXIF orientation code).
//
// Any orientation embedded in the file is ignored (IMREAD_IGNORE_ORIENTATION), so
// the caller-supplied value is the only orientation applied.
// Throws (via REQUIRE) if the download is empty or the bytes cannot be decoded.
cv::Mat loadImage(maps::http::Client& client, const std::string& url, int orientation)
{
    const std::vector<uint8_t> data = downloadImage(client, url);
    REQUIRE(!data.empty(), "Empty image body for url " << url);

    cv::Mat image = cv::imdecode(data, cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
    // cv::imdecode reports undecodable data by returning an empty matrix, not by throwing.
    REQUIRE(!image.empty(), "Unable to decode image for url " << url);

    return maps::mrc::common::transformByImageOrientation(image, maps::mrc::common::ImageOrientation::fromExif(orientation));
}

// YT Mappers

// Map job: for each input row, loads both images of the pair and runs
// SideViewClassifier::inference on them. It emits the input row extended with
// COLUMN_NAME_SIDEVIEW (0 when the result is ForwardView, 1 otherwise) and
// COLUMN_NAME_CONFIDENCE (the second element of the classifier result).
class TClassifierMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>>  {
public:
    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        INFO() << "Start classification ... ";
        maps::http::Client httpClient;

        maps::mrc::sideview::SideViewClassifier classifier;
        while (reader->IsValid()) {
            const NYT::TNode& row = reader->GetRow();

            const std::pair<maps::mrc::sideview::SideViewType, float> prediction = classifier.inference(
                loadImage(httpClient, row[COLUMN_NAME_URL1].AsString(), row[COLUMN_NAME_ORIENTATION1].AsInt64()),
                loadImage(httpClient, row[COLUMN_NAME_URL2].AsString(), row[COLUMN_NAME_ORIENTATION2].AsInt64())
            );

            const bool isForwardView = prediction.first == maps::mrc::sideview::SideViewType::ForwardView;
            const int predictedLabel = isForwardView ? 0 : 1;

            NYT::TNode result = row;
            result[COLUMN_NAME_SIDEVIEW]    = predictedLabel;
            result[COLUMN_NAME_SIDEVIEW_GT] = row[COLUMN_NAME_SIDEVIEW_GT].AsInt64();
            result[COLUMN_NAME_CONFIDENCE]  = prediction.second;
            writer->AddRow(result);

            reader->Next();
        }
    }
};
REGISTER_MAPPER(TClassifierMapper);

void loadStatistic(NYT::IClientPtr& client, const TString& tableName) {
    NYT::TTableReaderPtr<NYT::TNode> reader = client->CreateTableReader<NYT::TNode>(tableName);

    int frontViewValids = 0;
    int frontViewErrors = 0;
    int sideViewValids = 0;
    int sideViewErrors = 0;
    for (;reader->IsValid(); reader->Next()) {
        const NYT::TNode& inpRow = reader->GetRow();
        if (0 == inpRow[COLUMN_NAME_SIDEVIEW_GT].AsInt64()) {
            if (0 == inpRow[COLUMN_NAME_SIDEVIEW].AsInt64()) {
                frontViewValids++;
            } else {
                frontViewErrors++;
            }
        } else {
            if (1 == inpRow[COLUMN_NAME_SIDEVIEW].AsInt64()) {
                sideViewValids++;
            } else {
                sideViewErrors++;
            }
        }
    }

    INFO()  << "Front view valids: " << frontViewValids << " from " << (frontViewValids + frontViewErrors)
            << ". Accuracy: " << (double)frontViewValids / (double) (frontViewValids + frontViewErrors) * 100. << "%";
    INFO()  << "Side view valids:  " << sideViewValids << " from " << (sideViewValids + sideViewErrors)
            << ". Accuracy: " << (double)sideViewValids / (double) (sideViewValids + sideViewErrors) * 100. << "%";
}

} //namespace

int main(int argc, const char** argv) try {
    static const TString YT_PROXY = "hahn";

    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Test sideview classifier");

    maps::cmdline::Option<std::string> inputPathParam = parser.string("input")
        .required()
        .help("Path to input file with urls pairs");

    maps::cmdline::Option<bool> useGpu = parser.flag("use-gpu")
        .help("Use GPU for detector and recognizer tasks");

    parser.parse(argc, const_cast<char**>(argv));

    INFO() << "Connecting to yt::" << YT_PROXY;
    NYT::IClientPtr client = NYT::CreateClient(YT_PROXY);
    const NYT::TTempTable inputTable(client);
    int uploaded = uploadPairs(inputPathParam, client, inputTable.Name());
    INFO() << uploaded << " pairs uploaded for classifier to " << inputTable.Name();

    const NYT::TTempTable outputTable(client);
    INFO() << "Start classifier to " << outputTable.Name();
    client->Map(
        NYT::TMapOperationSpec()
            .AddInput<NYT::TNode>(inputTable.Name())
            .AddOutput<NYT::TNode>(outputTable.Name())
            .JobCount(std::max(1, std::min(uploaded / 10, 10))),
        new TClassifierMapper(),
        NYT::TOperationOptions().Spec(createOperationSpec("mapper", useGpu))
    );
    INFO() << "Classifier finished";

    loadStatistic(client, outputTable.Name());
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
