#include "make_user_data_encoded_job.h"

#include "encode_user_data_map.h"
#include "extract_affinity_tokens_map.h"
#include "helpers.h"
#include "schemas.h"
#include "word_count_with_weights_reduce.h"

#include <crypta/lab/lib/native/encoded_user_data.h>
#include <crypta/lib/native/yt/utils/helpers.h>
#include <crypta/lib/proto/user_data/attribute_names.pb.h>
#include <crypta/siberia/bin/make_user_data_encoded/lib/proto_types/table_data_types.pb.h>

#include <mapreduce/yt/common/config.h>
#include <mapreduce/yt/util/temp_table.h>
#include <transfer_manager/cpp/api/transfer.h>

#include <util/datetime/base.h>
#include <util/generic/vector.h>

using namespace NCrypta;
using namespace NCrypta::NSiberia::NUserDataWithDicts;
using namespace NLab::NEncodedUserData;

using TTimestamp = i64;

namespace {
    // Attribute in which this job stores a table's last update time as Unix seconds
    // (written by Do() onto the destination tables, read back by GetUpdateTime()).
    const TString LAST_UPDATE_TIMESTAMP_ATTR = NLab::TAttributeNames().GetLastUpdateTimestamp();

    // Returns the time stored in the YT `modification_time` attribute of `table`.
    TInstant GetModificationTime(NYT::IClientBasePtr client, const TString& table) {
        return GetTimestampFromAttribute(client, table, "modification_time");
    }

    // Returns the logical update time of `table`: the explicit LAST_UPDATE_TIMESTAMP_ATTR
    // value if the attribute is present, otherwise the table's modification time.
    TInstant GetUpdateTime(NYT::IClientBasePtr client, const TString& table) {
        const auto storedSeconds = GetAttribute<TTimestamp>(client, table, LAST_UPDATE_TIMESTAMP_ATTR);
        if (!storedSeconds.Defined()) {
            return GetModificationTime(client, table);
        }
        return TInstant::Seconds(*storedSeconds);
    }

    // Decides whether `dstTable` must be (re)built from a source last updated at `srcUpdateTime`.
    // A missing destination always needs an update; an existing one only when the source is
    // strictly newer than the destination's recorded update time.
    bool NeedsUpdate(NYT::IClientBasePtr client, const TInstant srcUpdateTime, const TString& dstTable, NLog::TLogPtr log) {
        if (!client->Exists(dstTable)) {
            return true;
        }

        const auto dstUpdateTime = GetUpdateTime(client, dstTable);
        log->info("Dst Update Time: {}", dstUpdateTime.Seconds());

        const bool alreadyUpToDate = srcUpdateTime <= dstUpdateTime;
        if (alreadyUpToDate) {
            log->info("No need to update user data table with dicts as {} <= {}", srcUpdateTime.ToString(), dstUpdateTime.ToString());
        }
        return !alreadyUpToDate;
    }
}

// Creates the job.
//
// @param config Job configuration; taken by value and moved into the object.
// @param log    Logger for progress messages; taken by value and moved into the object.
//               Must be non-null — checked with Y_ENSURE.
TMakeUserDataWithDictJob::TMakeUserDataWithDictJob(TConfig config, NLog::TLogPtr log)
    : Config(std::move(config))
    , Log(std::move(log))  // sink parameter: move instead of copying the pointer again
{
    Y_ENSURE(Log != nullptr, "Invalid log ptr");
}

int TMakeUserDataWithDictJob::Do() {
    Log->info("================= Start =================");

    auto ytClient = NYT::CreateClient(Config.GetYt().GetProxy());
    NYT::TConfig::Get()->Pool = Config.GetYt().GetPool();

    const auto& srcUserDataTable = Config.GetSrcUserDataTable();
    const auto& dstUserDataTable = Config.GetDstUserDataTable();
    const auto& dstWordDictTable = Config.GetDstWordDictTable();
    const auto& dstHostDictTable = Config.GetDstHostDictTable();
    const auto& dstAppDictTable = Config.GetDstAppDictTable();

    Log->info("Src User Data Table: {}", srcUserDataTable);
    Log->info("Dst User Data table: {}", dstUserDataTable);
    Log->info("Dst Word Dict table: {}", dstWordDictTable);
    Log->info("Dst Host Dict table: {}", dstHostDictTable);
    Log->info("Dst App Dict table: {}", dstAppDictTable);

    auto tx = ytClient->StartTransaction();

    const auto srcUpdateTime = GetUpdateTime(tx, srcUserDataTable);
    Log->info("Src Update Time: {}", srcUpdateTime.Seconds());

    if (NeedsUpdate(tx, srcUpdateTime, dstUserDataTable, Log)) {
        Log->info("Updating {} from {}", dstUserDataTable, srcUserDataTable);

        NYT::TTempTable words(tx, "words");
        NYT::TTempTable hosts(tx, "hosts");
        NYT::TTempTable apps(tx, "apps");
        ExtractAffinityTokens(tx, srcUserDataTable, words.Name(), hosts.Name(), apps.Name());

        NYT::TTempTable wordsCount(tx, "words_count");
        NYT::TTempTable hostsCount(tx, "hosts_count");
        NYT::TTempTable appsCount(tx, "apps_count");
        PerformWordCountWithWeights(tx, words.Name(), wordsCount.Name());
        PerformWordCountWithWeights(tx, hosts.Name(), hostsCount.Name());
        PerformWordCountWithWeights(tx, apps.Name(), appsCount.Name());

        MergeNewTokensIntoPriorDict(tx, wordsCount.Name(), dstWordDictTable);
        MergeNewTokensIntoPriorDict(tx, hostsCount.Name(), dstHostDictTable);
        MergeNewTokensIntoPriorDict(tx, appsCount.Name(), dstAppDictTable);

        EncodeUserDataWithDicts(tx, srcUserDataTable, dstWordDictTable, dstHostDictTable, dstAppDictTable, dstUserDataTable);
        SetAttribute(tx, dstUserDataTable, LAST_UPDATE_TIMESTAMP_ATTR, static_cast<TTimestamp>(srcUpdateTime.Seconds()));

        Log->info("Committing");
        tx->Commit();

        TFileOutput releaseInfoStream("encoded_user_data_release_info.json");
        releaseInfoStream << "{}";
    }

    auto replicaYtClient = NYT::CreateClient(Config.GetReplicaYt().GetProxy());

    if (NeedsUpdate(replicaYtClient, srcUpdateTime, dstUserDataTable, Log)) {
        Log->info("Transfering tables to replica cluster");

        NTM::TMultiTaskContainer taskContainer(ytClient, replicaYtClient);
        TVector<std::pair<NYT::TYPath, NYT::TYPath>> allTargets{
            {NYT::JoinYPaths(Config.GetTmpReplicaDir(), "host_dict"), dstHostDictTable},
            {NYT::JoinYPaths(Config.GetTmpReplicaDir(), "word_dict"), dstWordDictTable},
            {NYT::JoinYPaths(Config.GetTmpReplicaDir(), "app_dict"), dstAppDictTable},
            {NYT::JoinYPaths(Config.GetTmpReplicaDir(), "user_data"), dstUserDataTable},
        };
        for (const auto& [tmp, target]: allTargets) {
            taskContainer.AddTask({{target, tmp}});
        }

        replicaYtClient->Create(Config.GetTmpReplicaDir(), NYT::NT_MAP, NYT::TCreateOptions().Recursive(true).IgnoreExisting(true));

        NTM::TMultiTaskTransferSettings multiTaskTransferSettings;
        multiTaskTransferSettings.TMSettings.TransferManagerHost = Config.GetTransferManager().GetHost();
        multiTaskTransferSettings.TMSettings.TransferManagerPort = Config.GetTransferManager().GetPort();
        multiTaskTransferSettings.TMSettings.TransferManagerRetries = Config.GetTransferManager().GetRetries();
        multiTaskTransferSettings.RetryCount = 0;
        multiTaskTransferSettings.MaxSimultaneousTransfers = 3;

        NTM::Transfer(taskContainer, NTM::TTransferSettings(), multiTaskTransferSettings);

        {
            Log->info("Moving tables from tmp to target");
            auto tx = replicaYtClient->StartTransaction();
            for (const auto& [tmp, target]: allTargets) {
                tx->Move(tmp, target, NYT::TMoveOptions().Force(true).Recursive(true));
            }
            SetAttribute(tx, dstUserDataTable, LAST_UPDATE_TIMESTAMP_ATTR, static_cast<TTimestamp>(srcUpdateTime.Seconds()));

            Log->info("Committing");
            tx->Commit();
        }

        Log->info("Done");
    }

    Log->info("================= Finish =================");
    return 0;
}

// Runs TExtractAffinityTokensMap over `srcUserDataTable` inside `tx`. Weighted word, host
// and app tokens go to `dstWords`, `dstHosts` and `dstApps` respectively. All three
// outputs use the NLab::TWeightedToken schema with optimize_for=scan.
void TMakeUserDataWithDictJob::ExtractAffinityTokens(NYT::ITransactionPtr tx, const TString& srcUserDataTable, const TString& dstWords, const TString& dstHosts, const TString& dstApps) {
    // All outputs share one schema; build it once instead of three times.
    const auto tokenSchema = NYT::CreateTableSchema<NLab::TWeightedToken>();

    TExtractAffinityTokensMap::TOutputIndexes::TBuilder outputBuilder;
    outputBuilder.Add(
        NYT::TRichYPath(dstWords).Schema(tokenSchema).OptimizeFor(NYT::OF_SCAN_ATTR),
        TExtractAffinityTokensMap::EOutputTables::Words);
    outputBuilder.Add(
        NYT::TRichYPath(dstHosts).Schema(tokenSchema).OptimizeFor(NYT::OF_SCAN_ATTR),
        TExtractAffinityTokensMap::EOutputTables::Hosts);
    outputBuilder.Add(
        NYT::TRichYPath(dstApps).Schema(tokenSchema).OptimizeFor(NYT::OF_SCAN_ATTR),
        TExtractAffinityTokensMap::EOutputTables::Apps);

    auto spec = NYT::TMapOperationSpec();
    spec.AddInput<NLab::TUserData>(srcUserDataTable);
    AddOutputs<NLab::TWeightedToken>(spec, outputBuilder.GetTables());

    // The message previously said only "words and hosts" although apps are extracted too.
    Log->info("Extracting words, hosts and apps from {}", srcUserDataTable);
    tx->Map(spec, new TExtractAffinityTokensMap(outputBuilder.GetIndexes()));
}

// Counts weighted tokens. First sorts `weightedTokensTable` in place by its Token column.
// Then reduces it by that column with TWordCountWithWeightsReduce, writing
// TWeightedWordCount rows to `countedTokensTable`. Both operations run inside `tx`.
void TMakeUserDataWithDictJob::PerformWordCountWithWeights(NYT::ITransactionPtr tx, const TString& weightedTokensTable, const TString& countedTokensTable) {
    const auto& keyColumn = YT_FIELD(NLab::TWeightedToken, Token);

    // The reduce below declares SortBy(keyColumn), so the input is sorted by it first.
    Log->info("Sorting {} by {}", weightedTokensTable, keyColumn);
    tx->Sort(weightedTokensTable, weightedTokensTable, {keyColumn});

    Log->info("Performing word count on {} by {}", weightedTokensTable, keyColumn);

    NYT::TReduceOperationSpec reduceSpec;
    reduceSpec.SortBy(keyColumn);
    reduceSpec.ReduceBy({keyColumn});
    reduceSpec.AddInput<NLab::TWeightedToken>(weightedTokensTable);
    reduceSpec.AddOutput<TWeightedWordCount>(countedTokensTable);

    tx->Reduce(reduceSpec, new TWordCountWithWeightsReduce());
}

// Merges freshly counted tokens into an existing dictionary, all inside `tx`.
// Sorts `newTokensTable` in place by its Weight column (ascending — the sort order
// MergeDictWithNewTokens expects; TODO confirm) and loads the current dictionary from
// `priorDictTable`. It then calls MergeDictWithNewTokens with `priorDictTable` as the
// destination, presumably writing the merged dictionary back there (verify in the helper).
void TMakeUserDataWithDictJob::MergeNewTokensIntoPriorDict(NYT::ITransactionPtr tx, const TString& newTokensTable, const TString& priorDictTable) {
    const auto& orderColumn = YT_FIELD(TWeightedWordCount, Weight);

    Log->info("Sorting {} by {}", newTokensTable, orderColumn);
    tx->Sort(newTokensTable, newTokensTable, {orderColumn});

    Log->info("Reading prior token dict from {}", priorDictTable);
    const auto& existingDict = ReadStringToWeightedIdDict(tx, priorDictTable);

    Log->info("Merging new tokens into the prior dict");
    MergeDictWithNewTokens(Log, tx, existingDict, newTokensTable, priorDictTable);
}

// Encodes `srcUserDataTable` into `dstUserDataTable` using the token dictionaries stored in
// `wordDictTable`, `hostDictTable` and `appDictTable`.
//
// The three dictionaries are read into memory in this process and moved into the
// TEncodeUserDataMap job. An ordered map then writes NLab::TUserData rows with the
// GetUserDataWithDictsSchema() schema. Finally, SetYqlProtoFields is applied to the output.
// The mapper memory limit comes from Config.GetConverterMemoryLimitMb().
void TMakeUserDataWithDictJob::EncodeUserDataWithDicts(
        NYT::ITransactionPtr tx,
        const TString& srcUserDataTable,
        const TString& wordDictTable,
        const TString& hostDictTable,
        const TString& appDictTable,
        const TString& dstUserDataTable)
{
    Log->info("Reading new token dicts from {}, {}, and {}", wordDictTable, hostDictTable, appDictTable);
    auto wordDict = ReadStringToIdDict(tx, wordDictTable);
    auto hostDict = ReadStringToIdDict(tx, hostDictTable);
    auto appDict = ReadStringToIdDict(tx, appDictTable);

    Log->info("Encoding user data with the new dicts");

    // Widen to i64 *before* multiplying. If the config getter returns a 32-bit integer
    // (e.g. a proto uint32/int32 field), `mb * 1024 * 1024` would be evaluated in 32 bits
    // and wrap (or overflow) for limits of 4 GiB (unsigned) / 2 GiB (signed) and above.
    const i64 mapperMemoryLimitBytes = static_cast<i64>(Config.GetConverterMemoryLimitMb()) * 1024 * 1024;

    // NOTE(review): schema/optimize_for are also set on the *input* path. optimize_for
    // presumably has no effect when reading. The schema on an input rich path acts as a
    // read-time schema override. Confirm this is intended rather than copied from the output.
    auto spec = NYT::TMapOperationSpec()
        .AddInput<NLab::TUserData>(NYT::TRichYPath(srcUserDataTable).Schema(GetUserDataWithDictsSchema()).OptimizeFor(NYT::OF_SCAN_ATTR))
        .AddOutput<NLab::TUserData>(NYT::TRichYPath(dstUserDataTable).Schema(GetUserDataWithDictsSchema()).OptimizeFor(NYT::OF_SCAN_ATTR))
        .Ordered(true)
        .MapperSpec(NYT::TUserJobSpec().MemoryLimit(mapperMemoryLimitBytes));

    tx->Map(spec, new TEncodeUserDataMap(std::move(wordDict), std::move(hostDict), std::move(appDict)));

    SetYqlProtoFields<NLab::TUserData>(tx, dstUserDataTable);
}
