#include <util/digest/fnv.h>
#include <util/generic/size_literals.h>
#include <util/memory/blob.h>
#include <util/string/reverse.h>

#include <library/cpp/containers/comptrie/prefix_iterator.h>
#include <library/cpp/tld/tld.h>

#include <robot/library/yt/static/command.h>
#include <robot/library/yt/static/tags.h>

#include <wmconsole/version3/wmcutil/url.h>
#include <wmconsole/version3/wmcutil/yt/yt_runner.h>

#include <wmconsole/version3/library/conf/yt.h>

#include <wmconsole/version3/searchqueries-mr/conf/yt.h>
#include <wmconsole/version3/searchqueries-mr/protos/host2vec.pb.h>
#include <wmconsole/version3/searchqueries-mr/protos/user_sessions.pb.h>
#include <wmconsole/version3/searchqueries-mr/tools/host2vec-model/protos/embedding.pb.h>
#include <wmconsole/version3/searchqueries-mr/tools/host2vec-model/utils/utils.h>
#include <wmconsole/version3/wmcutil/yt/triggers.h>

#include <cstring>

#include "host2vec.h"

namespace NWebmaster {
namespace NHost2Vec {

using namespace NJupiter;

// Weight passed via .OperationWeight(...) to the YT operations launched in this file.
#define OPERATION_WEIGHT 1.0f

namespace {

// Name of the "Cosine" column of the host2vec groups table (NProto::TGroup);
// used together with F_GROUP as the sort key of that table.
// Declared `const char *const` so that the pointer itself cannot be rebound at runtime.
const char *const F_COSINE = "Cosine";

} // namespace

// YT mapper: for every input row takes the URL from column `HostFieldName`,
// reduces it to its host2vec domain and writes the domain's nearest neighbours
// in the word2vec model as NProto::TGroup rows whose Group is that domain.
// Domains unknown to the model are written as a single-member group.
struct THost2VecMapper : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NProto::TGroup>> {
    Y_SAVELOAD_JOB(HostFieldName, Words, Vectors, MaxAnalogies)

public:
    THost2VecMapper() = default;
    THost2VecMapper(const TString &hostFieldName, const TString &words, const TString &vectors, size_t maxAnalogies)
        : HostFieldName(hostFieldName)
        , Words(words)
        , Vectors(vectors)
        , MaxAnalogies(maxAnalogies)
    {
    }

    // Job start-up: loads the model from the `Words` (vocabulary) and `Vectors`
    // files and attaches a brute-force searcher to it.
    void Start(TWriter *) override {
        VectorsBlob = TBlob::FromFileContent(Vectors);
        TFileInput vocabularyInput(Words);
        Model.Reset(new NWord2Vec::TModel());
        Model->LoadFromYandex(&vocabularyInput, VectorsBlob);
        Searcher.Reset(new TBruteforceSearcher("1"));
        Searcher->SetModels(Model.Get(), Model.Get(), false /*normalized*/);
    }

    void Do(TReader *input, TWriter *output) override {
        // Public suffixes and domains shorter than 3 characters are never reported.
        const auto isTooGeneric = [](const TString &domain) {
            return NTld::IsTld(domain) || domain.size() < 3;
        };

        NProto::TGroup groupRow;
        // Fills every field of the reused row and writes it; `cosine` keeps its original type.
        const auto emit = [output, &groupRow](const TString &group, const TString &member, auto cosine, bool isMain) {
            groupRow.SetGroup(group);
            groupRow.SetHost(member);
            groupRow.SetCosine(cosine);
            groupRow.SetMainInGroup(isMain);
            output->AddRow(groupRow);
        };

        for (; input->IsValid(); input->Next()) {
            const TString url = input->GetRow()[HostFieldName].AsString();

            TString host;
            TString path;
            if (!NUtils::SplitUrl(url, host, path)) {
                Cerr << "unable to parse url: " << url << Endl;
                continue;
            }

            host = TString{NUtils::GetHost2vecDomain(host)};
            if (isTooGeneric(host)) {
                continue;
            }

            TUtf16String wideHost = TUtf16String::FromAscii(host);
            if (!Model->Has(wideHost)) {
                // Not in the model: the host forms a group of its own.
                emit(host, host, 1.0, true);
                continue;
            }

            TVector<TUtf16String> query = { wideHost };
            TVector<TWordQuality> neighbours = Searcher->FindBestMatches(query, MaxAnalogies, false/* debug*/, 1);
            Sort(neighbours.begin(), neighbours.end());

            // NOTE(review): only the neighbour at sorted index 0 is flagged as main;
            // if that entry is filtered out below, no row of this group is flagged.
            size_t rank = 0;
            for (const TWordQuality &neighbour : neighbours) {
                const bool isMain = (rank++ == 0);
                const TString analogy = WideToUTF8(neighbour.Word);

                if (isTooGeneric(analogy)) {
                    continue;
                }

                // Drop analogies for which IsSubdomain(analogy, host) holds, except the host itself.
                if (NUtils::IsSubdomain(analogy, host) && analogy != host) {
                    continue;
                }

                emit(host, analogy, neighbour.Quality, isMain);
            }
        }
    }

public:
    TString HostFieldName;       // input column holding the URL
    TString Words;               // vocabulary file opened in Start()
    TString Vectors;             // vectors file loaded into VectorsBlob in Start()
    size_t MaxAnalogies = 300;   // max neighbours requested per host

    TSimpleSharedPtr<NWord2Vec::TModel> Model;
    THolder<TSearcher> Searcher;
    TBlob VectorsBlob;
};

REGISTER_MAPPER(THost2VecMapper)

// Reducer keyed by host (see ReduceBy(F_HOST) in TaskUpdateHost2vecGroups):
// collects the 32-bit FNV hashes of every group the host appears in and writes
// them, de-duplicated and in ascending order, as a packed byte string of
// native-endian ui32 values in TGroupHashes::Groups.
// GetGroupsHashes() decodes that encoding.
struct TReduceHostGroups: public NYT::IReducer<NYT::TTableReader<NProto::TGroup>, NYT::TTableWriter<NProto::TGroupHashes>> {
    void Do(TReader *input, TWriter *output) override {
        TSet<ui32> groupHashesSet;
        const TString host = input->GetRow().GetHost();

        for (; input->IsValid(); input->Next()) {
            const auto &row = input->GetRow();
            const TString &group = row.GetGroup();
            groupHashesSet.insert(FnvHash<ui32>(group.data(), group.size()));
        }

        // Serialise the ordered hashes straight into the byte string. This avoids an
        // intermediate vector copy and never forms &vec[0] on a possibly empty vector;
        // the length is a multiple of sizeof(ui32) by construction.
        TString stream;
        stream.reserve(groupHashesSet.size() * sizeof(ui32));
        for (const ui32 hash : groupHashesSet) {
            stream.append(reinterpret_cast<const char *>(&hash), sizeof(hash));
        }

        NProto::TGroupHashes dstMsg;
        dstMsg.SetHost(host);
        dstMsg.SetGroups(stream);
        output->AddRow(dstMsg);
    }
};

REGISTER_REDUCER(TReduceHostGroups)

// TReduceGetColumnHash instantiated with ui32 hashes.
using TReduceGetColumnHash32 = TReduceGetColumnHash<ui32>;
REGISTER_REDUCER(TReduceGetColumnHash32)

//SortBy Host
struct TStatisticsMapper : public NYT::IMapper<NYT::TTableReader<NUserSessions::NProto::TStatistics>, NYT::TTableWriter<NHost2Vec::NProto::TStatistics>> {
    void Do(TReader *input, TWriter *output) override {
        struct TCounter {
            void Add(const NUserSessions::NProto::TStatistics &row) {
                Clicks += row.GetClicks();
                Shows += row.GetShows();
            }

        public:
            size_t Clicks = 0;
            size_t Shows = 0;
        };

        THashMap<TString, TCounter> counters;
        for (; input->IsValid(); input->Next()) {
            const auto &row = input->GetRow();
            const TString host = TString{NUtils::GetHost2vecDomain(row.GetHost())};
            counters[host].Add(row);
        }

        NHost2Vec::NProto::TStatistics stats;
        for (const auto &counter : counters) {
            stats.SetHost(counter.first);
            stats.SetClicks(counter.second.Clicks);
            stats.SetShows(counter.second.Shows);
            output->AddRow(stats);
        }
    }
};

REGISTER_MAPPER(TStatisticsMapper)

// Reducer (also used as the combiner): sums Clicks and Shows over all rows that
// share one Host key and writes a single totals row for that host.
struct TStatisticsReducer : public NYT::IReducer<NYT::TTableReader<NHost2Vec::NProto::TStatistics>, NYT::TTableWriter<NHost2Vec::NProto::TStatistics>> {
public:
    void Do(TReader *input, TWriter *output) override {
        NHost2Vec::NProto::TStatistics totals;
        totals.SetHost(input->GetRow().GetHost());

        size_t clicksSum = 0;
        size_t showsSum = 0;
        for (; input->IsValid(); input->Next()) {
            const auto &row = input->GetRow();
            clicksSum += row.GetClicks();
            showsSum += row.GetShows();
        }

        totals.SetClicks(clicksSum);
        totals.SetShows(showsSum);
        output->AddRow(totals);
    }
};

REGISTER_REDUCER(TStatisticsReducer)

// YT mapper: passes through only the embedding rows whose `host` is a key of
// HostsWithTraffic, overwriting the row's `no` field with the mapped value.
struct TModelFilterMapper : public NYT::IMapper<NYT::TTableReader<NHost2Vec::NProto::TEmbedding>, NYT::TTableWriter<NHost2Vec::NProto::TEmbedding>> {
    Y_SAVELOAD_JOB(HostsWithTraffic)

public:
    TModelFilterMapper() = default;
    TModelFilterMapper(const THashMap<TString, ui32> &hostsWithTraffic)
        : HostsWithTraffic(hostsWithTraffic)
    {
    }

    void Do(TReader *input, TWriter *output) override {
        for (; input->IsValid(); input->Next()) {
            const auto &srcRow = input->GetRow();
            // Single hash lookup (instead of contains() followed by at()); the
            // row, which carries the embedding, is copied only when it is kept.
            const auto found = HostsWithTraffic.find(srcRow.Gethost());
            if (found == HostsWithTraffic.end()) {
                continue;
            }

            auto row = srcRow;
            row.Setno(found->second);
            output->AddRow(row);
        }
    }

public:
    // host -> value written into the `no` field of every kept row.
    THashMap<TString, ui32> HostsWithTraffic;
};

REGISTER_MAPPER(TModelFilterMapper)

// Builds `fltTraffic`: per-host2vec-domain totals of clicks/shows over at most
// the DAYS (93) lexicographically largest daily tables found under
// TABLE_PARSED_USER_SESSIONS_STATS_DAILY_ROOT, then sorts it by Host.
// Runs in its own transaction, committed at the end.
void UpdateTrafficFilterTable(NYT::IClientBasePtr client, const TString &fltTraffic) {
    const size_t DAYS = 93;

    NYT::ITransactionPtr tx = client->StartTransaction();

    TDeque<NYTUtils::TTableInfo> dailyTables;
    NYTUtils::GetTableList(tx, TCommonYTConfigSQ::CInstance().TABLE_PARSED_USER_SESSIONS_STATS_DAILY_ROOT, dailyTables);
    // Sorting the reversed range with "name less" leaves the list in descending
    // name order, so truncation keeps the largest names (presumably the most
    // recent days if names are dates — confirm).
    std::sort(dailyTables.rbegin(), dailyTables.rend(), NYTUtils::TTableInfo::TNameLess());
    dailyTables.resize(std::min<size_t>(dailyTables.size(), DAYS));

    TDeque<TTable<NUserSessions::NProto::TStatistics>> dailyInputs;
    for (const auto &tableInfo : dailyTables) {
        dailyInputs.emplace_back(tx, tableInfo.Name);
    }

    TMapCombineReduceCmd<TStatisticsMapper, TStatisticsReducer, TStatisticsReducer>(tx)
        .OperationWeight(OPERATION_WEIGHT)
        .MapperMemoryLimit(1_GBs)
        .Inputs(dailyInputs)
        .Output(TTable<NHost2Vec::NProto::TStatistics>(tx, fltTraffic))
        .ReduceBy({"Host"})
        .Do();

    TSortCmd<NHost2Vec::NProto::TStatistics>(tx, TTable<NHost2Vec::NProto::TStatistics>(tx, fltTraffic))
        .OperationWeight(OPERATION_WEIGHT)
        .By({"Host"})
        .Do();

    tx->Commit();
}

// Builds the traffic-filtered copy of the host2vec embedding table.
//
// 1. Gives every distinct host of `srcTable` an index equal to the number of
//    distinct hosts read before it (a repeated host is re-assigned the count
//    current at that point).
// 2. Keeps only hosts that also occur in `fltTraffic`, plus the word2vec token
//    "</s>" pinned to index 0.
// 3. Copies the kept embedding rows into `fltTable` via TModelFilterMapper,
//    storing the index in `no`, and sorts `fltTable` by `no`.
// Everything runs in one transaction, committed at the end.
void FilterModelTable(NYT::IClientBasePtr client, const TString &fltTraffic, const TString &srcTable, const TString &fltTable) {
    NYT::ITransactionPtr tx = client->StartTransaction();

    THashMap<TString, ui32> srcHosts;
    auto srcReader = TTable<NHost2Vec::NProto::TEmbedding>(tx, srcTable).SelectFields({"host"}).GetReader();
    for (; srcReader->IsValid(); srcReader->Next()) {
        // Read the size *before* operator[] may insert. The former one-liner
        // `srcHosts[key] = srcHosts.size()` only had this meaning because C++17
        // sequences the right operand of `=` first; earlier standards left the
        // order unspecified, which could shift every index by one.
        const ui32 index = static_cast<ui32>(srcHosts.size());
        srcHosts[srcReader->GetRow().Gethost()] = index;
    }

    THashMap<TString, ui32> hostsWithTraffic = {
        {"</s>", 0}  //default word2vec token
    };
    // NOTE(review): index 0 is also given to the first host read from srcTable;
    // the two only agree if "</s>" is that first row — confirm.
    auto statReader = TTable<NHost2Vec::NProto::TStatistics>(tx, fltTraffic).GetReader();
    for (; statReader->IsValid(); statReader->Next()) {
        const TString &host = statReader->GetRow().GetHost();
        const auto found = srcHosts.find(host);   // single lookup instead of contains() + at()
        if (found != srcHosts.end()) {
            hostsWithTraffic[host] = found->second;
        }
    }

    TMapCmd<TModelFilterMapper>(tx, new TModelFilterMapper(hostsWithTraffic))
        .MemoryLimit(2_GBs)
        .OperationWeight(OPERATION_WEIGHT)
        .Input(TTable<NHost2Vec::NProto::TEmbedding>(tx, srcTable))
        .Output(TTable<NHost2Vec::NProto::TEmbedding>(tx, fltTable))
        .Do()
    ;

    TSortCmd<NHost2Vec::NProto::TEmbedding>(tx, TTable<NHost2Vec::NProto::TEmbedding>(tx, fltTable))
        .OperationWeight(OPERATION_WEIGHT)
        .By({"no"})
        .Do()
    ;

    tx->Commit();
}

// Refreshes the traffic-filtered host2vec model kept under
// FILE_SOURCE_MODELS_ROOT/host2vec-filtered: traffic table, filtered embedding
// table, then the filtered "words"/"vectors" files. The work is done only when
// the time trigger on that directory reports it is needed; the transaction is
// committed in either case.
void FilterModelByTraffic(NYT::IClientBasePtr client, const TString &srcHost2vecRoot) {
    const char *const NAME_TABLE = "table";
    const char *const NAME_WORDS = "words";
    const char *const NAME_VECTORS = "vectors";

    const TString srcTable = NYTUtils::JoinPath(srcHost2vecRoot, NAME_TABLE);

    const TString fltHost2vecRoot = NYTUtils::JoinPath(TCommonYTConfig::CInstance().FILE_SOURCE_MODELS_ROOT, "host2vec-filtered");
    const TString fltTable = NYTUtils::JoinPath(fltHost2vecRoot, NAME_TABLE);
    const TString fltTraffic = NYTUtils::JoinPath(fltHost2vecRoot, "traffic");
    const TString fltFileWords = NYTUtils::JoinPath(fltHost2vecRoot, NAME_WORDS);
    const TString fltFileVectors = NYTUtils::JoinPath(fltHost2vecRoot, NAME_VECTORS);

    NYT::ITransactionPtr tx = client->StartTransaction();
    TYtTimeTrigger modelTrigger(tx, fltHost2vecRoot);
    if (!modelTrigger.NeedUpdate()) {
        tx->Commit();
        return;
    }

    LOG_INFO("host2vec, filtering model");
    NYTUtils::CreatePath(tx, fltHost2vecRoot);
    UpdateTrafficFilterTable(tx, fltTraffic);
    FilterModelTable(tx, fltTraffic, srcTable, fltTable);
    UpdateFilteredModel(tx, fltTable, fltFileWords, fltFileVectors);
    modelTrigger.Update();
    LOG_INFO("host2vec, filtering model - done");
    tx->Commit();
}

// Recomputes the host2vec group tables described by `host2VecConfig`, guarded by
// a TYtTimeTrigger on host2VecConfig.Hosts2Vec.
//
// Steps:
//   0. FilterModelByTraffic() refreshes the traffic-filtered model (in its own
//      transaction, before this function's transaction is started).
//   1. THost2VecMapper maps SourceHosts into Hosts2Vec (TGroup rows). The
//      "words"/"vectors" files of FILE_SOURCE_MODELS_ROOT/host2vec are attached to the jobs.
//   2. Hosts2Vec is sorted by (F_GROUP, F_COSINE).
//   3. In parallel: per-host packed group hashes -> HostGroupsHash, and
//      host / group hash tables -> HostsHash / GroupsHash.
//   4. In parallel: those three tables are sorted by their key column.
//   5. The trigger is updated and the transaction committed.
//
// @param client          YT client used to start the transactions.
// @param hostFieldName   column of SourceHosts that holds the URL read by the mapper.
// @param host2VecConfig  input/output table paths.
//
// NOTE(review): step 1 reads the *unfiltered* host2vec model; the filtered model
// refreshed in step 0 is not referenced here — confirm this is intended.
void TaskUpdateHost2vecGroups(NYT::IClientBasePtr client, const TString &hostFieldName, const TTableConfig &host2VecConfig) {
    //source model "browser sessions" https://wiki.yandex-team.ru/jandekspoisk/kachestvopoiska/relevance/word2vec/
    const char *NAME_WORDS          = "words";
    const char *NAME_VECTORS        = "vectors";
    const TString srcHost2vecRoot   = NYTUtils::JoinPath(TCommonYTConfig::CInstance().FILE_SOURCE_MODELS_ROOT, "host2vec");
    const TString srcFileWords      = NYTUtils::JoinPath(srcHost2vecRoot, NAME_WORDS);
    const TString srcFileVectors    = NYTUtils::JoinPath(srcHost2vecRoot, NAME_VECTORS);

    FilterModelByTraffic(client, srcHost2vecRoot);

    NYT::ITransactionPtr tx = client->StartTransaction();
    TYtTimeTrigger updateAttrGroups(tx, host2VecConfig.Hosts2Vec);
    if (!updateAttrGroups.NeedUpdate()) {
        // NOTE(review): returns without Commit()/Abort() on `tx`. No operation has
        // been started yet, but the transaction's fate is left to its destructor.
        return;
    }

    LOG_INFO("host2vec, updating groups");

    // Step 1: the mapper opens the attached files by their base names
    // (NAME_WORDS / NAME_VECTORS); 300 is MaxAnalogies per host.
    TMapCmd<THost2VecMapper>(tx, new THost2VecMapper(hostFieldName, NAME_WORDS, NAME_VECTORS, 300))
        .OperationWeight(OPERATION_WEIGHT)
        .Input<NYT::TNode>(host2VecConfig.SourceHosts)
        .Output(TTable<NProto::TGroup>(tx, host2VecConfig.Hosts2Vec))
        .JobCount(100000)
        .AddYtFile(srcFileWords)
        .AddYtFile(srcFileVectors)
        .MemoryLimit(5_GBs)
        .Do()
    ;

    // Step 2.
    TSortCmd<NProto::TGroup>(tx, TTable<NProto::TGroup>(tx, host2VecConfig.Hosts2Vec))
        .OperationWeight(OPERATION_WEIGHT)
        .By({F_GROUP, F_COSINE})
        .Do()
    ;

    // Step 3: all three read Hosts2Vec. Unlike the other operations here, these
    // do not set OperationWeight.
    DoParallel(
        TCombineReduceCmd<TReduceHostGroups, TReduceHostGroups>(tx, nullptr, new TReduceHostGroups)
            .Input(TTable<NProto::TGroup>(tx, host2VecConfig.Hosts2Vec))
            .Output(TTable<NProto::TGroupHashes>(tx, host2VecConfig.HostGroupsHash))
            .ReduceBy(F_HOST),

        TCombineReduceCmd<TReduceGetColumnHash32, TReduceGetColumnHash32>(tx, nullptr, new TReduceGetColumnHash32(F_HOST, F_HASH))
            .Input<NYT::TNode>(host2VecConfig.Hosts2Vec)
            .Output<NYT::TNode>(NYT::TRichYPath(host2VecConfig.HostsHash))
            .ReduceBy(F_HOST),

        TCombineReduceCmd<TReduceGetColumnHash32, TReduceGetColumnHash32>(tx, nullptr, new TReduceGetColumnHash32(F_GROUP, F_HASH))
            .Input<NYT::TNode>(host2VecConfig.Hosts2Vec)
            .Output<NYT::TNode>(NYT::TRichYPath(host2VecConfig.GroupsHash))
            .ReduceBy(F_GROUP)
    );

    // Step 4: in-place sorts of the three outputs of step 3.
    DoParallel(
        TSortCmd<NYT::TNode>(tx)
            .Input<NYT::TNode>(host2VecConfig.HostGroupsHash)
            .Output<NYT::TNode>(host2VecConfig.HostGroupsHash)
            .OperationWeight(OPERATION_WEIGHT)
            .By({F_HOST}),

        TSortCmd<NYT::TNode>(tx)
            .Input<NYT::TNode>(host2VecConfig.HostsHash)
            .Output<NYT::TNode>(host2VecConfig.HostsHash)
            .OperationWeight(OPERATION_WEIGHT)
            .By({F_HOST}),

        TSortCmd<NYT::TNode>(tx)
            .Input<NYT::TNode>(host2VecConfig.GroupsHash)
            .Output<NYT::TNode>(host2VecConfig.GroupsHash)
            .OperationWeight(OPERATION_WEIGHT)
            .By({F_GROUP})
    );

    // Step 5.
    updateAttrGroups.Update();
    tx->Commit();

    LOG_INFO("host2vec, updating groups - done");
}

// Reads `hostGroupsHashTable` (columns F_HOST, F_GROUPS) and, for every host
// that is neither a public suffix nor shorter than 3 characters:
//   - decodes its packed group hashes into ownerToGroups[host] (GetGroupsHashes);
//   - adds the reversed host name to a compact trie, with value
//     ownerToGroups.size() taken right after that host was inserted.
// The serialised trie is stored into `trieStream`.
void BuildHostToGroupTrie(NYT::IClientBasePtr client, const TString &hostGroupsHashTable, THashMap<TString, TVector<ui32>> &ownerToGroups, TVector<char> &trieStream) {
    TCompactTrie<char>::TBuilder builder;
    for (auto reader = client->CreateTableReader<NYT::TNode>(hostGroupsHashTable); reader->IsValid(); reader->Next()) {
        const NYT::TNode &row = reader->GetRow();
        const TString &host = row[F_HOST].AsString();
        const TString &packedGroups = row[F_GROUPS].AsString();
        //const bool isMain = row[F_MAIN_IN_GROUP].AsString();
        const bool tooGeneric = NTld::IsTld(host) || host.size() < 3;
        if (tooGeneric) {
            continue;
        }

        TVector<ui32> &groupHashes = ownerToGroups[host];
        GetGroupsHashes(packedGroups, groupHashes);

        TString reversedHost(host);
        ReverseInPlace(reversedHost);
        builder.Add(reversedHost, ownerToGroups.size());
    }

    LOG_INFO("direct, loaded %lu owner groups", ownerToGroups.size());

    TBufferStream serialized;
    builder.SaveAndDestroy(serialized);
    const auto &bytes = serialized.Buffer();
    trieStream.assign(bytes.Data(), bytes.Data() + bytes.Size());
    LOG_INFO("direct, built trie with size %lu bytes", trieStream.size());
}

void GetGroupsHashes(const TString &groupsStr, TVector<ui32> &groups) {
    const ui32* groupsData = reinterpret_cast<const ui32*>(groupsStr.data());
    groups.assign(groupsData, groupsData + groupsStr.size() / sizeof(ui32));
}

// Loads `groupsHashTable` (columns F_GROUP, F_HASH) into `hashToGroup`, mapping
// each 32-bit hash back to its group name. A later row with the same hash
// overwrites an earlier one.
void LoadGroupsHashes(NYT::IClientBasePtr client, const TString &groupsHashTable, THashMap<ui32, TString> &hashToGroup) {
    for (auto reader = client->CreateTableReader<NYT::TNode>(groupsHashTable); reader->IsValid(); reader->Next()) {
        const NYT::TNode &row = reader->GetRow();
        const TString &group = row[F_GROUP].AsString();   // bind, copy only once into the map
        // The column is read as uint64 but holds a 32-bit hash (presumably written by
        // TReduceGetColumnHash32); make the narrowing explicit instead of implicit.
        const ui32 hash = static_cast<ui32>(row[F_HASH].AsUint64());
        hashToGroup[hash] = group;
    }

    LOG_INFO("direct, loaded %lu group hashes", hashToGroup.size());
}

} // namespace NHost2Vec
} // namespace NWebmaster
