#include "yt.h"
#include "mapper.h"

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/object_in_photo_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/io.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/operation.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <maps/libs/log8/include/log8.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/library/operation_tracker/operation_tracker.h>
#include <util/generic/size_literals.h>

#include <algorithm>
#include <unordered_set>

namespace maps::mrc::image_analyzer {

namespace {

// Used to size YT map operations: the job count is the input row count
// divided by this value, clamped to at least one job.
constexpr std::size_t MIN_IMAGES_PER_JOB{500};

/// Runs image analysis for a batch of features as YT map operations.
///
/// Features are uploaded to a working table under `<yt path>/image-analyzer`.
/// FeatureImageMapper is then run, together with YavisionMapper in parallel
/// when no GPU is used. Finally the results are loaded and the working tables
/// are removed.
///
/// NOTE(review): the config is stored by reference, so it must outlive the
/// processor.
class YtProcessor
{
public:
    /// @param mrcConfig  Source of the YT client and base path. It is also
    ///                   serialized and passed to the mappers. Stored by reference.
    /// @param useGpu     Whether FeatureImageMapper runs on GPU. Without a GPU the
    ///                   forbidden probability is taken from YavisionMapper.
    YtProcessor(const common::Config& mrcConfig, UseGpu useGpu)
        : mrcConfig_(mrcConfig)
        , useGpu_(useGpu)
        , estimateForbidderWithYavision_(useGpu == UseGpu::No)
        , client_(mrcConfig.externals().yt().makeClient())
    {
        const TString basePath(mrcConfig_.externals().yt().path() + "/image-analyzer");

        if (!client_->Exists(basePath)) {
            INFO() << "Create YT node " << basePath;
            client_->Create(basePath, NYT::NT_MAP, NYT::TCreateOptions()
                .Recursive(true));
        }
        inputPath_ = basePath + "/features-input";
        featuresImageMapperOutputPath_ = basePath + "/features-image-mapper-output";
        featuresYavisionOutputPath_ = basePath + "/features-yavision-output";
        privacyObjectsOutputPath_ = basePath + "/image-objects-output";
    }

    /// Uploads @p features to YT, runs the analysis operations, and returns
    /// the processed features keyed by feature id.
    ///
    /// The working tables are removed only after the results are loaded. If
    /// waiting for the operations or loading the results throws, the removal
    /// below is skipped.
    ProcessedFeatureById process(const db::Features& features)
    {
        INFO() << "YT: Store " << features.size() << " features at " << inputPath_;
        yt::saveToTable(*client_, inputPath_, features);

        const std::size_t numRows = yt::getRowCount(*client_, inputPath_);
        // Always run at least one job, even for small inputs. `std::size_t{1}`
        // gives both std::max arguments the same type on every platform,
        // whereas `1UL` only matches when size_t is unsigned long.
        const std::size_t jobsCount =
            std::max(numRows / MIN_IMAGES_PER_JOB, std::size_t{1});

        INFO() << "YT: Run operations";

        NYT::TOperationTracker tracker;

        runFeatureImageMapper(tracker, jobsCount);
        if (estimateForbidderWithYavision_) {
            // With a GPU, forbidden-image classification happens directly in
            // FeatureImageMapper. Without a GPU, the yavision handle is called
            // instead, in a separate YT operation running in parallel so it is
            // not too slow. Some features may never get a yavision answer; such
            // features are then not processed at all.
            // The intent is to run on GPU and not take this branch.
            runYavisionMapper(tracker, jobsCount);
        }
        tracker.WaitAllCompleted();

        INFO() << "YT: Load results";
        auto results = loadResults();

        client_->Remove(inputPath_);
        client_->Remove(featuresImageMapperOutputPath_);
        // The yavision table only exists when the yavision operation ran.
        client_->Remove(featuresYavisionOutputPath_, NYT::TRemoveOptions().Force(true));
        client_->Remove(privacyObjectsOutputPath_);

        INFO() << "YT: Done";
        return results;
    }

private:
    /// Builds the operation spec for a mapper.
    ///
    /// GPU mode uses the GPU base spec with a 6 GB memory limit and
    /// @p gpuLimit GPUs. CPU mode uses the CPU base spec with a 6 GB memory
    /// limit, a 0.6 memory reserve factor and 4 CPUs. @p gpuLimit is ignored
    /// in CPU mode.
    NYT::TNode makeFeatureImageMapperOperationSpec(const TString& title,
                                                   int gpuLimit = 1) const
    {
        NYT::TNode result;

        if (useGpu_ == UseGpu::Yes) {
            result = yt::baseGpuOperationSpec(title, yt::PoolType::Processing)
                ("mapper", yt::baseGpuWorkerSpec()
                    ("memory_limit", 6_GB)
                    ("gpu_limit", gpuLimit)
                );
        } else {
            result = yt::baseCpuOperationSpec(title, yt::PoolType::Processing)
                ("mapper", NYT::TNode::CreateMap()
                    ("memory_limit", 6_GB)
                    ("memory_reserve_factor", 0.6)
                    ("cpu_limit", 4)
                );
        }

        return result;
    }

    /// Starts FeatureImageMapper without waiting for it and registers it in
    /// @p tracker.
    ///
    /// The operation reads inputPath_ and writes processed features to
    /// featuresImageMapperOutputPath_ and privacy objects to
    /// privacyObjectsOutputPath_.
    void runFeatureImageMapper(NYT::TOperationTracker& tracker, std::size_t jobsCount)
    {
        NYT::TUserJobSpec jobSpec;
        jobSpec.MemoryLimit(6_GB);

        auto operationOptions = makeFeatureImageMapperOperationSpec(
            "image classification and privacy objects detection");
        tracker.AddOperation(
            client_->Map(
                NYT::TMapOperationSpec()
                    .AddInput<NYT::TNode>(inputPath_)
                    .AddOutput<NYT::TNode>(featuresImageMapperOutputPath_)
                    .AddOutput<NYT::TNode>(privacyObjectsOutputPath_)
                    .JobCount(jobsCount)
                    .MaxFailedJobCount(100)
                    .MapperSpec(jobSpec),
                // The second argument tells the mapper whether to estimate the
                // forbidden probability itself (i.e. yavision is not used).
                new FeatureImageMapper(mrcConfig_.toString().c_str(), !estimateForbidderWithYavision_),
                NYT::TOperationOptions()
                        .Spec(operationOptions)
                        .Wait(false)
            )
        );
    }

    /// Starts YavisionMapper without waiting for it and registers it in
    /// @p tracker. The operation reads inputPath_ and writes to
    /// featuresYavisionOutputPath_.
    ///
    /// Only called when useGpu_ == UseGpu::No, so the CPU spec branch is used.
    /// NOTE(review): that spec sets the mapper memory_limit to 6 GB while
    /// jobSpec requests 1 GB. Which one YT applies depends on how
    /// TOperationOptions::Spec is merged; confirm.
    void runYavisionMapper(NYT::TOperationTracker& tracker, std::size_t jobsCount)
    {
        NYT::TUserJobSpec jobSpec;
        jobSpec.MemoryLimit(1_GB);

        auto operationOptions = makeFeatureImageMapperOperationSpec(
            "yavision classification", 0 /*gpuLimit*/);

        tracker.AddOperation(
            client_->Map(
                NYT::TMapOperationSpec()
                    .AddInput<NYT::TNode>(inputPath_)
                    .AddOutput<NYT::TNode>(featuresYavisionOutputPath_)
                    .JobCount(jobsCount)
                    .MaxFailedJobCount(100)
                    .MapperSpec(jobSpec),
                new YavisionMapper(mrcConfig_.toString().c_str()),
                NYT::TOperationOptions()
                    .Spec(operationOptions)
                    .Wait(false)
            )
        );
    }

    /// Reads the operation outputs and assembles the result map.
    ///
    /// Privacy objects are attached to their features. When yavision is used,
    /// only features present in both outputs are returned, and their
    /// forbidden probability comes from the yavision output.
    ProcessedFeatureById loadResults()
    {
        auto featuresImageMapper = yt::loadFromTable<db::Features>(
            *client_, featuresImageMapperOutputPath_);

        INFO() << "Loaded " << featuresImageMapper.size()
            << " features from " << featuresImageMapperOutputPath_;

        auto privacyObjects = yt::loadFromTable<db::ObjectsInPhoto>(
            *client_, privacyObjectsOutputPath_);

        ProcessedFeatureById imageMapperFeatureById;
        for (auto& feature : featuresImageMapper) {
            const auto id = feature.id();
            imageMapperFeatureById.insert({id, ProcessedFeature{std::move(feature), {}}});
        }

        // Every privacy object is expected to reference a feature from the
        // image-mapper output; `at` throws std::out_of_range otherwise.
        for (auto& object : privacyObjects) {
            const auto id = object.featureId();
            imageMapperFeatureById.at(id).privacyObjects.push_back(std::move(object));
        }

        if (!estimateForbidderWithYavision_) {
            return imageMapperFeatureById;
        }

        auto featuresYavision = yt::loadFromTable<db::Features>(
            *client_, featuresYavisionOutputPath_);

        INFO() << "Loaded " << featuresYavision.size()
            << " features from " << featuresYavisionOutputPath_;

        ProcessedFeatureById resultFeatureById;
        for (const auto& feature : featuresYavision) {
            const auto id = feature.id();
            const auto it = imageMapperFeatureById.find(id);
            if (it == imageMapperFeatureById.end()) {
                continue;
            }
            auto& processed = it->second;
            processed.feature.setForbiddenProbability(
                feature.forbiddenProbability()
            );
            // Move rather than copy: the entry (including its privacy objects)
            // is not needed in the intermediate map any more.
            resultFeatureById.emplace(id, std::move(processed));
            // Drop the moved-from entry so a repeated id cannot reach it; the
            // first occurrence wins, as with emplace on the result map.
            imageMapperFeatureById.erase(it);
        }

        return resultFeatureById;
    }

private:
    // Stored by reference; must outlive this object.
    const common::Config& mrcConfig_;
    UseGpu useGpu_;
    // True when no GPU is used and YavisionMapper provides forbidden probabilities.
    bool estimateForbidderWithYavision_;
    NYT::IClientPtr client_;
    // Working tables under `<yt path>/image-analyzer`.
    TString inputPath_;
    TString featuresImageMapperOutputPath_;
    TString featuresYavisionOutputPath_;
    TString privacyObjectsOutputPath_;
};

} // namespace


/// Runs the YT image-analysis pipeline over @p features.
///
/// Builds a one-shot YtProcessor from @p mrcConfig and @p useGpu and returns
/// its result: the processed features keyed by feature id.
ProcessedFeatureById processOnYT(
    const db::Features& features,
    const common::Config& mrcConfig,
    UseGpu useGpu)
{
    return YtProcessor{mrcConfig, useGpu}.process(features);
}

} // maps::mrc::image_analyzer
