#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/rotation_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/point.h>
#include <maps/libs/http/include/http.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/pgpool/include/pgpool3.h>
#include <maps/libs/sql_chemistry/include/batch_load.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <util/generic/size_literals.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <vector>


using namespace maps::mrc;

namespace {

// Column names of the YT tables used by this tool: the input table holds
// (feature_id, url) rows, the classifier output holds (feature_id, orientation).
// Namespace-scope `const` inside the anonymous namespace already has internal
// linkage, so no `static` is needed.
const TString COLUMN_NAME_FEATURE_ID  = "feature_id";
const TString COLUMN_NAME_URL         = "url";
const TString COLUMN_NAME_ORIENTATION = "orientation";

/// Builds the YT spec for the orientation-classification map operation.
///
/// @param ytConcurrency  value written to `resource_limits.user_slots`.
/// @param useGpu         when true, the mapper gets a GPU worker spec with
///                       `gpu_limit` 1; otherwise a plain map with `cpu_limit` 1.
///                       Both variants get a 16 GB `memory_limit`.
/// @return the operation spec titled "Orientation classifier" in the
///         Processing pool type.
NYT::TNode createOperationSpec(size_t ytConcurrency, bool useGpu) {
    static const std::string TITLE = "Orientation classifier";

    NYT::TNode spec = useGpu
        ? yt::baseGpuOperationSpec(TITLE, yt::PoolType::Processing)(
              "mapper",
              yt::baseGpuWorkerSpec()("gpu_limit", 1)("memory_limit", 16_GB))
        : yt::baseCpuOperationSpec(TITLE, yt::PoolType::Processing)(
              "mapper",
              NYT::TNode::CreateMap()("cpu_limit", 1)("memory_limit", 16_GB));

    spec("resource_limits", NYT::TNode()("user_slots", ytConcurrency));
    return spec;
}

/// Loads every feature of the Walks dataset whose orientation is still NULL.
///
/// Rows are read through a slave transaction with sql_chemistry::BatchLoad,
/// 50000 rows per batch, and accumulated into a single vector.
db::Features loadWalksFeatures(maps::pgpool3::Pool& pool)
{
    constexpr size_t BATCH_SIZE = 50000;

    auto readTxn = pool.slaveTransaction();
    maps::sql_chemistry::BatchLoad<db::table::Feature> batchLoad{
        BATCH_SIZE,
        db::table::Feature::dataset == db::Dataset::Walks
            && db::table::Feature::orientation.isNull()};

    db::Features result;
    while (batchLoad.next(*readTxn)) {
        for (const auto& feature : batchLoad) {
            result.push_back(feature);
        }
    }
    return result;
}

/// Returns the MDS storage URL of the feature's image, built from its
/// MDS group id and MDS path.
TString getFeatureUrl(const maps::mrc::db::Feature& f) {
    const auto url =
        "http://storage-int.mds.yandex.net/get-maps_mrc/" + f.mdsGroupId() + "/" + f.mdsPath();
    return url.c_str();
}

/// Writes one row per feature into the YT table `tableName`.
///
/// Each row contains the feature id (COLUMN_NAME_FEATURE_ID) and the MDS URL
/// of the feature image (COLUMN_NAME_URL).
///
/// The writer is finished explicitly so that errors raised while flushing the
/// table are propagated to the caller. Relying on the writer's destructor
/// would leave them unreportable, because a destructor cannot propagate them.
void uploadFeatures(NYT::IClientPtr& client, const TString& tableName, const db::Features& features)
{
    NYT::TTableWriterPtr<NYT::TNode> writer = client->CreateTableWriter<NYT::TNode>(tableName);
    for (const db::Feature& feature : features) {
        writer->AddRow(
            NYT::TNode()
                (COLUMN_NAME_FEATURE_ID, NYT::TNode(feature.id()))
                (COLUMN_NAME_URL,        NYT::TNode(getFeatureUrl(feature)))
            );
    }
    writer->Finish();
}

/// Downloads the body of `url` with an HTTP GET request.
///
/// Retries with up to 10 attempts, a 1 s initial cooldown and a cooldown
/// backoff factor of 2. An attempt is retried when it did not produce a valid
/// result or when the response is in the ServerError class. Any other response
/// ends the retry loop.
///
/// @return the raw response body.
/// @throws (via REQUIRE) when the final response is not in the Success class.
std::vector<uint8_t> downloadImage(maps::http::Client& client, const std::string& url)
{
    maps::common::RetryPolicy policy;
    policy.setTryNumber(10);
    policy.setInitialCooldown(std::chrono::seconds(1));
    policy.setCooldownBackoff(2);

    auto isFinalAttempt = [](const auto& attempt) {
        if (!attempt.valid()) {
            return false;
        }
        return attempt.get().responseClass() != maps::http::ResponseClass::ServerError;
    };

    auto response = maps::common::retry(
        [&] {
            return maps::http::Request(client, maps::http::GET, maps::http::URL(url)).perform();
        },
        policy,
        isFinalAttempt);

    REQUIRE(response.responseClass() == maps::http::ResponseClass::Success,
        "Unexpected response status " << response.status() << " for url "
        << url);
    return response.readBodyToVector();
}

/// Downloads the image at `url` and decodes it as a color image.
///
/// The EXIF orientation flag is ignored (IMREAD_IGNORE_ORIENTATION), so the
/// pixels are returned as stored rather than auto-rotated. Images with more
/// than 2000 rows or columns are downscaled by a factor of 0.5 with INTER_AREA.
///
/// @throws (via REQUIRE) when the downloaded bytes cannot be decoded.
cv::Mat loadImage(maps::http::Client& client, const std::string& url)
{
    const std::vector<uint8_t> data = downloadImage(client, url);
    cv::Mat img = cv::imdecode(data, cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
    // imdecode reports undecodable input by returning an empty matrix instead
    // of throwing, so fail loudly here rather than feeding an empty image on.
    REQUIRE(!img.empty(), "Failed to decode image from url " << url);
    if (img.rows > 2000 || img.cols > 2000) {
        cv::resize(img, img, cv::Size(), 0.5, 0.5, cv::INTER_AREA);
    }
    return img;
}

// YT Mappers
/// YT mapper that classifies the orientation of each feature image.
///
/// Input rows must contain COLUMN_NAME_FEATURE_ID and COLUMN_NAME_URL. For
/// every input row it writes one output row containing only
/// COLUMN_NAME_FEATURE_ID and COLUMN_NAME_ORIENTATION. Those are the columns
/// read back by loadClassificationResults.
class TClassifierMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>>  {
public:
    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        INFO() << "Start classification ... ";
        // Created once per job and reused for every row processed by the job.
        maps::http::Client client;
        classifiers::RotationClassifier classifier;

        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode& inpRow = reader->GetRow();

            common::ImageOrientation orientation = classifier.detectImageOrientation(
                loadImage(client, inpRow[COLUMN_NAME_URL].AsString())
            );
            // The orientation is stored as an int16 code. Presumably this is
            // the EXIF orientation value, since it is decoded with
            // ImageOrientation::fromExif on the reading side. Confirm this.
            writer->AddRow(
                NYT::TNode()
                    (COLUMN_NAME_FEATURE_ID,  inpRow[COLUMN_NAME_FEATURE_ID])
                    (COLUMN_NAME_ORIENTATION, NYT::TNode(static_cast<int16_t>(orientation)))
            );
        }
    }
};
REGISTER_MAPPER(TClassifierMapper);

/// Reads the classifier output table `tableName` into a map from feature id
/// to orientation, which is decoded with ImageOrientation::fromExif.
/// When a feature id occurs more than once, the last row read wins.
std::map<db::TId, common::ImageOrientation> loadClassificationResults(
    NYT::IClientPtr& client,
    const TString& tableName)
{
    auto reader = client->CreateTableReader<NYT::TNode>(tableName);

    std::map<db::TId, common::ImageOrientation> result;
    while (reader->IsValid()) {
        const NYT::TNode& row = reader->GetRow();
        const db::TId featureId = row[COLUMN_NAME_FEATURE_ID].AsInt64();
        const auto exifCode = row[COLUMN_NAME_ORIENTATION].AsInt64();
        result[featureId] = common::ImageOrientation::fromExif(exifCode);
        reader->Next();
    }
    return result;
}

/// Applies classification results to `features` in place.
///
/// Features whose id is missing from `featureIdToOrientation` are logged and
/// removed from the vector. Every remaining feature has its orientation set
/// from the map. The relative order of the kept features is preserved.
void updateFeaturesOrientation(const std::map<db::TId, common::ImageOrientation>& featureIdToOrientation, db::Features& features)
{
    const auto isUnclassified = [&](const db::Feature& feature) {
        if (featureIdToOrientation.count(feature.id()) != 0) {
            return false;
        }
        INFO() << "Orientation for feature: " << feature.id() << " not defined";
        return true;
    };
    // Erase-remove compacts the vector in one linear pass. Erasing the
    // unclassified features one at a time would be quadratic.
    features.erase(
        std::remove_if(features.begin(), features.end(), isUnclassified),
        features.end());

    for (db::Feature& feature : features) {
        // Present for every remaining feature, so at() cannot throw here.
        feature.setOrientation(featureIdToOrientation.at(feature.id()));
    }
}

/// Writes `features` back through db::FeatureGateway inside one master
/// writeable transaction and commits it.
/// UpdateFeatureTxn::Yes is passed to the gateway. Presumably it also records
/// the change in the feature transaction history; confirm in feature_gateway.h.
void updateFeatures(maps::pgpool3::Pool& pool, db::Features& features)
{
    auto writeTxn = pool.masterWriteableTransaction();
    db::FeatureGateway{*writeTxn}.update(features, db::UpdateFeatureTxn::Yes);
    writeTxn->commit();
}

}

/// Classifies the orientation of every Walks feature whose orientation is not
/// set. Unless --dry-run is given, the results are stored in the database.
///
/// The steps are:
/// 1. Load the features from the DB.
/// 2. Upload id+URL rows to a temporary YT table.
/// 3. Run TClassifierMapper over that table.
/// 4. Read the results back.
/// 5. Re-load the still-unoriented features and update those that were
///    classified.
int main(int argc, const char** argv) try {
    static const TString YT_PROXY = "hahn";
    static const size_t ytConcurrency = 100;

    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Update features orientation");

    maps::cmdline::Option<std::string> mrcConfigPath = parser.string("mrc-config")
        .required()
        .help("Path to mrc config");

    maps::cmdline::Option<std::string> secretVersion = parser.string("secret-version")
        .help("version for secrets from yav.yandex-team.ru");

    maps::cmdline::Option<bool> dryRun = parser.flag("dry-run")
        .help("Don't apply changes to DB");

    maps::cmdline::Option<bool> useGpu = parser.flag("use-gpu")
        .help("Use GPU for orientation classification");

    parser.parse(argc, const_cast<char**>(argv));

    const maps::mrc::common::Config mrcConfig =
        maps::mrc::common::templateConfigFromCmdPath(secretVersion, mrcConfigPath);
    maps::wiki::common::PoolHolder mrc(mrcConfig.makePoolHolder());

    INFO() << "Start load features";
    db::Features features = loadWalksFeatures(mrc.pool());
    INFO() << "Features loaded: " << features.size();

    // Nothing to classify. Skip uploading an empty table and launching a YT
    // operation that would do no work.
    if (features.empty()) {
        INFO() << "No features without orientation, nothing to do";
        return EXIT_SUCCESS;
    }

    NYT::IClientPtr client = NYT::CreateClient(YT_PROXY);
    const NYT::TTempTable inputTable(client);
    uploadFeatures(client, inputTable.Name(), features);
    const NYT::TTempTable outputTable(client);
    INFO() << "Start classifier to " << outputTable.Name();
    // About 100 features per job, but always at least one job.
    const int jobCount = std::max(1, static_cast<int>(features.size() / 100));
    client->Map(
        NYT::TMapOperationSpec()
            .AddInput<NYT::TNode>(inputTable.Name())
            .AddOutput<NYT::TNode>(outputTable.Name())
            .JobCount(jobCount),
        new TClassifierMapper(),
        NYT::TOperationOptions().Spec(createOperationSpec(ytConcurrency, useGpu))
    );
    INFO() << "Classifier finished";

    std::map<db::TId, common::ImageOrientation> featureIdToOrientation =
        loadClassificationResults(client, outputTable.Name());

    if (!dryRun) {
        // Re-load so that only features that still have no orientation are
        // updated. Features oriented meanwhile are no longer selected.
        INFO() << "Start load features";
        features = loadWalksFeatures(mrc.pool());
        INFO() << "Features loaded: " << features.size();
        updateFeaturesOrientation(featureIdToOrientation, features);
        INFO() << "Start update features";
        updateFeatures(mrc.pool(), features);
        INFO() << "Features updates";
    }

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
