#include "yt.h"
#include "classifier.h"

#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/io.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/operation.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/retry.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sideview_classifier/include/sideview.h>

#include <maps/libs/log8/include/log8.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/interface/node.h>

#include <util/generic/size_literals.h>

#include <algorithm>

namespace maps::mrc::sideview_classifier {

namespace {

// Upper bound on the number of map jobs requested for one classification run.
constexpr std::size_t MAX_JOBS_COUNT{200};
// Number of input rows per job used when deriving the requested job count.
constexpr std::size_t MIN_IMAGES_PER_JOB{30};

/// Invokes @p callable, re-invoking it when it throws a std::exception.
///
/// Policy: at most 4 attempts, 1 second initial timeout, backoff factor 1
/// (presumably a constant delay between attempts — confirm against
/// common::RetryPolicy). Returns whatever @p callable returns.
template <typename Callable>
auto retry(Callable&& callable) -> decltype(callable())
{
    auto policy = common::RetryPolicy()
        .setInitialTimeout(std::chrono::seconds(1))
        .setMaxAttempts(4)
        .setTimeoutBackoff(1);

    return common::retryOnException<std::exception>(
        std::move(policy), std::forward<Callable>(callable));
}


/// Runs sideview classification of feature pairs as a YT map operation.
///
/// A run writes the inputs to `<yt path>/sideview-classifier/features-input`,
/// maps them with SideviewClassifier into `.../features-output` and reads the
/// results back. The processor keeps a reference to the config, so the
/// config must outlive it.
class YtProcessor
{
public:
    /// Creates a YT client from @p mrcConfig and makes sure the working map
    /// node `<yt path>/sideview-classifier` exists (created recursively).
    explicit YtProcessor(const common::Config& mrcConfig)
        : mrcConfig_(mrcConfig)
        , client_(mrcConfig.externals().yt().makeClient())
    {
        const TString basePath(mrcConfig_.externals().yt().path() + "/sideview-classifier");

        if (!client_->Exists(basePath)) {
            INFO() << "Create YT node " << basePath;
            client_->Create(basePath, NYT::NT_MAP, NYT::TCreateOptions()
                .Recursive(true));
        }
        inputPath_ = basePath + "/features-input";
        outputPath_ = basePath + "/features-output";
    }

    /// Uploads @p inputs, classifies them on YT and returns one Output per
    /// row of the mapper's output table.
    ///
    /// Any exception from the YT client or the map operation propagates to
    /// the caller.
    Outputs process(const Inputs& inputs)
    {
        INFO() << "YT: Store " << inputs.size() << " inputs at " << inputPath_;

        saveToTable(*client_, inputPath_, inputs);

        // Request one job per MIN_IMAGES_PER_JOB rows, at least one job and at
        // most MAX_JOBS_COUNT. All clamp arguments are std::size_t so a single
        // type is deduced on every platform (`1UL` is not size_t everywhere).
        const std::size_t numRows = yt::getRowCount(*client_, inputPath_);
        const std::size_t jobsCount = std::clamp(
            numRows / MIN_IMAGES_PER_JOB, std::size_t{1}, MAX_JOBS_COUNT);

        INFO() << "YT: Classify sideview";
        runMapper(jobsCount);

        INFO() << "YT: Load results";
        auto result = loadFromTable(*client_, outputPath_);

        // NOTE(review): intermediate tables are currently left in place.
        //client_->Remove(inputPath_);
        //client_->Remove(outputPath_);

        INFO() << "YT: Done";
        return result;
    }

private:
    /// Writes one row per input to the table at @p path (no append option is
    /// passed): passage key, both feature ids, MDS read URLs and orientations.
    void saveToTable(
            NYT::IIOClient& client,
            const TString& path,
            const Inputs& inputs)
    {
        const auto writer = client.CreateTableWriter<NYT::TNode>(path);
        auto mds = mrcConfig_.makeMdsClient();

        for (const auto& input : inputs) {
            const auto& feature1 = input.feature1;
            const auto& feature2 = input.feature2;

            writer->AddRow(
                NYT::TNode::CreateMap()
                    (COL_KEY_SOURCE_ID, TString(input.key.sourceId))
                    (COL_KEY_MIN_FEATURE_ID, input.key.minFeatureId)
                    (COL_FEATURE_ID_1, feature1.id())
                    (COL_FEATURE_ID_2, feature2.id())
                    (COL_URL_1, TString(mds.makeReadUrl(feature1.mdsKey())))
                    (COL_URL_2, TString(mds.makeReadUrl(feature2.mdsKey())))
                    (COL_ORIENTATION_1, static_cast<int>(feature1.orientation()))
                    (COL_ORIENTATION_2, static_cast<int>(feature2.orientation()))
            );
        }

        // Finish explicitly so flush/commit errors surface here as exceptions
        // (and can be retried) rather than inside the writer's destructor,
        // which cannot throw them to the caller.
        writer->Finish();
    }

    /// Reads every row of the table at @p path into an Output record.
    Outputs loadFromTable(
            NYT::IIOClient& client,
            const TString& path)
    {
        auto reader = client.CreateTableReader<NYT::TNode>(path);

        Outputs outputs;
        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode& row = reader->GetRow();
            outputs.push_back(Output{
                PassageKey{row[COL_KEY_SOURCE_ID].AsString(),
                           row[COL_KEY_MIN_FEATURE_ID].AsInt64()},
                row[COL_FEATURE_ID_1].AsInt64(),
                static_cast<sideview::SideViewType>(row[COL_TYPE].AsInt64()),
                static_cast<float>(row[COL_CONFIDENCE].AsDouble())
            });
        }
        return outputs;
    }

    /// Runs SideviewClassifier over the input table into the output table,
    /// requesting @p jobsCount jobs.
    void runMapper(size_t jobsCount)
    {
        // The mapper memory limit is set both in the job spec and in the raw
        // operation spec; keep a single source for the value.
        constexpr auto MAPPER_MEMORY_LIMIT = 6_GB;

        NYT::TUserJobSpec jobSpec;
        jobSpec.MemoryLimit(MAPPER_MEMORY_LIMIT);

        const auto operationOptions = NYT::TNode::CreateMap()
            ("title", "sideview classification")
            ("mapper", NYT::TNode::CreateMap()
                           ("memory_limit", MAPPER_MEMORY_LIMIT)
                           ("memory_reserve_factor", 0.6));

        client_->Map(
            NYT::TMapOperationSpec()
                .AddInput<NYT::TNode>(inputPath_)
                .AddOutput<NYT::TNode>(outputPath_)
                .JobCount(jobsCount)
                .MapperSpec(jobSpec),
            new SideviewClassifier(),
            NYT::TOperationOptions().Spec(operationOptions)
        );
    }

private:
    const common::Config& mrcConfig_;  // not owned; must outlive this object
    NYT::IClientPtr client_;
    TString inputPath_;   // table read by the mapper
    TString outputPath_;  // table written by the mapper
};

} // namespace


/// Classifies the sideview type of every input pair by running a YT map
/// operation and returns the classifier's outputs.
///
/// The whole pipeline is retried on std::exception. That includes YT client
/// creation and the existence check/creation of the working node, which also
/// talk to YT and can fail transiently. Each attempt uses a fresh processor.
Outputs classifyOnYT(
    const Inputs& inputs,
    const common::Config& mrcConfig)
{
    return retry([&] {
        YtProcessor ytProcessor(mrcConfig);
        return ytProcessor.process(inputs);
    });
}

} // maps::mrc::sideview_classifier
