#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/object_in_photo_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/privacy_detector/include/privacy_detector_faster_rcnn.h>

#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/io.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/operation.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/shellcmd/include/yandex/maps/shell_cmd.h>

#include <yandex/maps/wiki/common/string_utils.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/common/include/profiletimer.h>
#include <maps/libs/log8/include/log8.h>

#include <mapreduce/yt/interface/client.h>
#include <util/generic/string.h>
#include <util/generic/size_literals.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <string>

#include <unistd.h>
#include <boost/filesystem.hpp>

namespace fs = boost::filesystem;

namespace maps::mrc::test_yt_gpu {

/// Loads the fixed sample of published features used by this smoke test.
///
/// The sample is the hard-coded id window
/// [MIN_FEATURE_ID, MIN_FEATURE_ID + 400], restricted to published features.
/// NOTE(review): whether `between` includes both bounds is defined by the db
/// layer (presumably SQL BETWEEN, i.e. inclusive) — confirm.
///
/// @param pool  MRC database pool; only a slave transaction is opened.
/// @return      the matching features.
db::Features loadFeatures(pgpool3::Pool& pool)
{
    constexpr int64_t MIN_FEATURE_ID = 53863586;
    constexpr int64_t MAX_FEATURE_ID = MIN_FEATURE_ID + 400;

    // The query only reads, so a slave (presumably read-replica) transaction
    // is enough.
    auto txn = pool.slaveTransaction();

    return db::FeatureGateway(*txn).load(
        db::table::Feature::id.between(MIN_FEATURE_ID, MAX_FEATURE_ID)
            && db::table::Feature::isPublished
    );
}

class PrivacyDetectWorker: public yt::Mapper {
    void Do(yt::Reader* reader, yt::Writer* writer) override;
};

/// Processes every input row of this job.
///
/// Each row is deserialized as a feature together with its image. Rows whose
/// image is empty are skipped. For the rest, every object reported by the
/// detector is written out as a db::ObjectInPhoto bound to the feature id.
void PrivacyDetectWorker::Do(yt::Reader* reader, yt::Writer* writer) {
    using FeatureWithImage = yt::FeatureWithImage<db::Feature>;

    // Built once and reused for all rows handled by this job.
    privacy_detector::FasterRCNNDetector detector;

    while (reader->IsValid()) {
        const auto row = yt::deserialize<FeatureWithImage>(reader->GetRow());

        if (!row.image.empty()) {
            const auto detections = detector.detect(row.image);

            for (const auto& detection: detections) {
                const db::ObjectInPhoto photoObject(
                    row.feature.id(),
                    detection.type,
                    detection.box,
                    detection.confidence
                );
                writer->AddRow(yt::serialize(photoObject));
            }
        }

        reader->Next();
    }
}

// Registers the mapper so the YT runtime can instantiate it inside jobs.
REGISTER_MAPPER(PrivacyDetectWorker);

/// Runs PrivacyDetectWorker over `input` as a GPU map operation and writes
/// the detected objects to `output`.
///
/// The operation uses the AdHoc pool type and requests 5 jobs. Each job
/// asks for one GPU and 8 GB of memory on top of the common GPU worker spec.
///
/// @param client  YT client (or transaction) to run the operation in.
/// @param input   table of features with images.
/// @param output  table receiving serialized db::ObjectInPhoto rows.
void detectPrivacyObjects(
        NYT::IClientBase& client,
        const TString& input,
        const TString& output)
{
    // Per-job resource requests for the mapper.
    const auto mapperSpec = yt::baseGpuWorkerSpec()
        ("gpu_limit", 1)
        ("memory_limit", 8_GB);

    const auto operationSpec =
        yt::baseGpuOperationSpec("Test YT GPU", yt::PoolType::AdHoc)
            ("mapper", mapperSpec);

    // The mapper is handed to the client by raw pointer. NOTE(review):
    // presumably the client takes ownership via an intrusive pointer, as is
    // usual for YT mappers.
    client.Map(
        NYT::TMapOperationSpec()
            .AddInput<NYT::TNode>(input)
            .AddOutput<NYT::TNode>(output)
            .JobCount(5),
        new PrivacyDetectWorker(),
        NYT::TOperationOptions().Spec(operationSpec)
    );
}

/// Runs yt::uploadFeatureImage on the table at `path`, using the same table
/// as both input and output (i.e. the table is updated in place).
/// NOTE(review): presumably this fetches the feature images into the rows;
/// the exact behavior is defined by yt::uploadFeatureImage.
///
/// @param mrcConfig  MRC configuration forwarded to the upload helper.
/// @param client     YT client (or transaction) to run the upload in.
/// @param path       table of features to update.
void upload(
        const common::Config& mrcConfig,
        NYT::IClientBase& client,
        const TString& path)
{
    // One job per this many rows when computing the job-count hint.
    constexpr size_t ROWS_PER_JOB = 30;

    const auto spec = NYT::TNode::CreateMap()
        ("title", "Upload")
        ("mapper", yt::baseGpuWorkerSpec());

    // The explicit template argument keeps std::max well-formed whatever
    // unsigned type the row count has. A bare `1ul` only matches when that
    // type is exactly `unsigned long`. Always request at least one job.
    const size_t hintOnJobN =
        std::max<size_t>(1, yt::getRowCount(client, path) / ROWS_PER_JOB);

    yt::uploadFeatureImage(mrcConfig, client, path, path, hintOnJobN, spec);
}

/// End-to-end smoke test of privacy detection on YT GPU nodes:
///  1. loads a fixed sample of features from the MRC database;
///  2. saves them to INPUT_TABLE;
///  3. uploads their images into INPUT_TABLE;
///  4. runs the GPU detection mapper, writing detections to OUTPUT_TABLE.
///
/// NOTE(review): both table paths are hard-coded to a personal directory.
///
/// @param mrcConfig  MRC configuration (database pool and YT client).
void run(const common::Config& mrcConfig)
{
    INFO() << "Start job...";
    auto mrc = mrcConfig.makePoolHolder();

    INFO() << "Load test features from db...";
    const db::Features features = loadFeatures(mrc.pool());
    // At this point only feature rows are loaded; images are uploaded later.
    INFO() << "Loaded " << features.size() << " features!";

    auto client = mrcConfig.externals().yt().makeClient();

    const TString INPUT_TABLE = "//home/maps/core/mrc/dasimagin/input_table";
    const TString OUTPUT_TABLE = "//home/maps/core/mrc/dasimagin/output_table";

    INFO() << "Save features to table " << INPUT_TABLE;
    yt::saveToTable(*client, INPUT_TABLE, features);

    INFO() << "Upload images to table " << INPUT_TABLE;
    upload(mrcConfig, *client, INPUT_TABLE);

    INFO() << "Run detection...";
    detectPrivacyObjects(*client, INPUT_TABLE, OUTPUT_TABLE);
    INFO() << "Results are stored at " << OUTPUT_TABLE;

    INFO() << "Job finished!";
}

} // namespace maps::mrc::test_yt_gpu

/// Entry point of the YT GPU privacy-detection smoke test.
///
/// Command-line options:
///   mrc-config  path to the MRC config
///   secret      version of the secrets stored in yav.yandex-team.ru
///
/// Every exception escaping the test is logged and turned into EXIT_FAILURE.
/// That includes exceptions of unknown type.
int main(int argc, const char** argv) try {
    // Per the YT C++ API docs this must precede any other YT API use.
    // When the binary is launched as a YT job, this call runs the job.
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser;

    auto mrcConfigPath = parser.string("mrc-config")
        .help("path to mrc config");

    auto secret = parser.string("secret")
        .help("version for secrets from yav.yandex-team.ru");

    // The parser takes a non-const argv, hence the cast.
    parser.parse(argc, const_cast<char**>(argv));

    const auto mrcConfig =
        maps::mrc::common::templateConfigFromCmdPath(secret, mrcConfigPath);

    maps::mrc::test_yt_gpu::run(mrcConfig);

    return EXIT_SUCCESS;
} catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
} catch (const yexception& e) {
    FATAL() << "Worker failed: " << e.what();
    if (e.BackTrace()) {
        FATAL() << e.BackTrace()->PrintToString();
    }
    return EXIT_FAILURE;
} catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
} catch (...) {
    // Last resort: report anything not matched above instead of letting it
    // reach std::terminate without a log line or a proper exit code.
    FATAL() << "Worker failed: unknown exception";
    return EXIT_FAILURE;
}
