#include <util/digest/fnv.h>
#include <util/generic/hash_set.h>
#include <util/generic/size_literals.h>
#include <util/random/random.h>

#include <wmconsole/version3/library/jupiter/jupiter.h>
#include <wmconsole/version3/wmcutil/hostid.h>
#include <wmconsole/version3/wmcutil/url.h>
#include <wmconsole/version3/wmcutil/yt/triggers.h>
#include <robot/library/yt/static/command.h>
#include <wmconsole/version3/wmcutil/yt/yt_runner.h>
#include <wmconsole/version3/processors/tools/IKS/utils/canonizer.h>
#include <wmconsole/version3/processors/tools/host2vec/utils/utils.h>
#include <wmconsole/version3/processors/user_sessions/library/regions_limiter.h>
#include <wmconsole/version3/processors/user_sessions/niche2/conf/config.h>
#include <wmconsole/version3/processors/user_sessions/niche2/preparation/tables.pb.h>
#include <robot/library/yt/static/tags.h>
#include <library/cpp/text_processing/tokenizer/tokenizer.h>
#include <wmconsole/version3/library/dssm/dssm_utils.h>
#include <wmconsole/version3/processors/user_sessions/library/utils.h>
#include <library/cpp/compute_graph/compute_graph.h>
#include <wmconsole/version3/processors/user_sessions/protos/user_sessions.pb.h>

#include <wmconsole/version3/processors/user_sessions/niche/conf/config.h>

#include "task_get_cluster_names.h"

namespace NWebmaster {
namespace NNiche2 {
using namespace NJupiter;
using namespace NProto;

// Tags for the typed inputs of the centroid reduce (TCalculateCentroidReducer):
// tag 1 carries TQueryToCluster rows, tag 2 carries TClusterInfo rows.
static const TInputTag<TQueryToCluster> QueryToClusterInputTag(1);
static const TInputTag<TClusterInfo> ClusterInfoInputTag(2);

// Tag for the single TClusterInfo output written by the centroid reduce.
static const TOutputTag<TClusterInfo> ClusterInfoOutputTag(1);

// Converts untyped (TNode) rows into TClusterInfo messages.
// Each input row must provide two string columns:
//   "ClusterID"   - parsed into an int64 and stored as ClusterId;
//   "ClusterName" - copied as is into ClusterName.
struct TGetClusterInfoMapper: public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<TClusterInfo>> {
    void Do(TReader *input, TWriter *output) override {
        // One message object is reused for all rows; both fields are overwritten every time.
        TClusterInfo clusterInfo;
        while (input->IsValid()) {
            const auto &node = input->GetRow();
            const auto clusterId = FromString<int64_t>(node["ClusterID"].AsString());
            clusterInfo.SetClusterId(clusterId);
            clusterInfo.SetClusterName(node["ClusterName"].AsString());
            output->AddRow(clusterInfo);
            input->Next();
        }
    }
};
REGISTER_MAPPER(TGetClusterInfoMapper)

// Random down-sampling of TQueryToCluster rows: a row is forwarded unchanged when
// RandomNumber<float>() yields a value <= 0.001, otherwise it is dropped.
// NOTE(review): that is roughly 0.1% of the rows if RandomNumber<float>() is uniform on [0, 1) - confirm.
struct TSampleClusterInfoMapper: public NYT::IMapper<NYT::TTableReader<TQueryToCluster>, NYT::TTableWriter<TQueryToCluster>> {
    void Do(TReader *input, TWriter *output) override {
        while (input->IsValid()) {
            const auto &candidate = input->GetRow();
            const bool keepRow = RandomNumber<float>() <= 0.001;
            if (keepRow) {
                output->AddRow(candidate);
            }
            input->Next();
        }
    }
};
REGISTER_MAPPER(TSampleClusterInfoMapper)

// ClusterId
struct TCalculateCentroidReducer: public TTaggedReducer {
public:
    void StartTagged(TTagedWriter) override final {
        TDssmL3Model::Ptr dssmL3Model(new TDssmL3Model);
        DssmCtr.Reset(new TDssmCtr(dssmL3Model));
    }

    void DoTagged(TTagedReader reader, TTagedWriter writer) override {
        TClusterInfo dstMsg;
        TVector<float> sumDssm;
        uint64_t total = 0;
        while (reader.IsValid()) {
            for (auto &row: reader.GetRows(QueryToClusterInputTag)) {
                if (!IsUtf(row.GetQueryText())) {
                    continue;
                }
                TVector<float> dssm;
                DssmCtr->GetQueryEmbedding(row.GetQueryText(), dssm);
                if (sumDssm.empty()) {
                    sumDssm = dssm;
                } else {
                    for (size_t i = 0; i < dssm.size(); ++i) {
                        sumDssm[i] += dssm[i];
                    }
                }
                total += 1;
            }
            for (auto &row: reader.GetRows(ClusterInfoInputTag)) {
                dstMsg.SetClusterId(row.GetClusterId());
                dstMsg.SetClusterName(row.GetClusterName());
            }
        }
        if (total != 0) {
            for (auto &value: sumDssm) {
                dstMsg.AddCentroidDssm(value / total);
            }
        }
        writer.AddRow(dstMsg, ClusterInfoOutputTag);
    }
private:
    THolder<TDssmCtr> DssmCtr;
};
REGISTER_REDUCER(TCalculateCentroidReducer)

int TaskGetClusterNames(int, const char **) {
    const auto& cfg = TConfig::CInstance();

    auto client = NYT::CreateClient(cfg.MR_SERVER_HOST);

    auto tx = client->StartTransaction();

    NComputeGraph::TJobRunner runner;

    auto mappedClusterInfo = runner.AddJob([&]() {
        TMapCmd<TGetClusterInfoMapper>(tx)
            .Input(TTable<NYT::TNode>(tx, cfg.TABLE_SEARCHSHARE_CLUSTERS))
            .Output(TTable<TClusterInfo>(tx, cfg.TABLE_SEARCHSHARE_CLUSTERS_INFO))
            .Do();
    });

    auto sampledClusterInfo = runner.AddJob([&]() {
        TMapCmd<TSampleClusterInfoMapper>(tx)
            .Input(TTable<TQueryToCluster>(tx, cfg.TABLE_SEARCHSHARE_QUERY_TO_CLUSTER_HASHED))
            .Output(TTable<TQueryToCluster>(tx, cfg.TABLE_SEARCHSHARE_QUERY_TO_CLUSTER_SAMPLED))
            .Do();
    });

    auto orderedClsuterInfo = runner.AddJob([&]() {
        TSortCmd<TClusterInfo>(tx)
            .Input<TClusterInfo>(cfg.TABLE_SEARCHSHARE_CLUSTERS_INFO)
            .Output<TClusterInfo>(cfg.TABLE_SEARCHSHARE_CLUSTERS_INFO)
            .By({"ClusterId"})
            .Do();
    }, { mappedClusterInfo });

    auto orderedSampledClusterInfo = runner.AddJob([&]() {
        TSortCmd<TQueryToCluster>(tx)
            .Input<TQueryToCluster>(cfg.TABLE_SEARCHSHARE_QUERY_TO_CLUSTER_SAMPLED)
            .Output<TQueryToCluster>(cfg.TABLE_SEARCHSHARE_QUERY_TO_CLUSTER_SAMPLED)
            .By({"ClusterId"})
            .Do();
    }, { sampledClusterInfo });

    auto calculatedCentroid = runner.AddJob([&]() {
        TReduceCmd<TCalculateCentroidReducer>(tx)
            .Input(TTable<TQueryToCluster>(tx, cfg.TABLE_SEARCHSHARE_QUERY_TO_CLUSTER_SAMPLED), QueryToClusterInputTag)
            .Input(TTable<TClusterInfo>(tx, cfg.TABLE_SEARCHSHARE_CLUSTERS_INFO), ClusterInfoInputTag)
            .Output(TTable<TClusterInfo>(tx, cfg.TABLE_SEARCHSHARE_CLUSTERS_FULL_INFO), ClusterInfoOutputTag)
            .AddYtFile(cfg.DSSM_MODEL)
            .MemoryLimit(2_GBs)
            .ReduceBy({"ClusterId"})
            .Do();
    }, { orderedClsuterInfo, orderedSampledClusterInfo });

    runner.AddJob([&]() {
        TSortCmd<TClusterInfo>(tx)
            .Input<TClusterInfo>(cfg.TABLE_SEARCHSHARE_CLUSTERS_FULL_INFO)
            .Output<TClusterInfo>(cfg.TABLE_SEARCHSHARE_CLUSTERS_FULL_INFO)
            .By({"ClusterId"})
            .Do();
    }, { calculatedCentroid });

    runner.Run();

    tx->Commit();

    return 0;
}

} //namespace NNiche2
} //namespace NWebmaster
