#include <util/generic/size_literals.h>
#include <util/string/split.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/common/config.h>
#include <library/cpp/dot_product/dot_product.h>
#include <datacloud/dev_utils/data/custom_data_utils.h>
#include <datacloud/features/dssm/data/data.pb.h>

#include <algorithm>
#include <utility>

using namespace NYT;

namespace Datacloud {
namespace Features {
namespace DSSM {


// Reducer that folds every row it is handed into a single output row whose
// feature vector is the element-wise maximum of the input feature vectors.
// Input and output use the same record type (TDSSMRecord).
class TFastDSSMMaxReducer : public IReducer<
    TTableReader<TDSSMRecord>, TTableWriter<TDSSMRecord>>
{
public:
    // Consumes all rows available from `reader` and writes exactly one row to
    // `writer`:
    //   * key      - the key of the last row read (empty if no rows were read);
    //   * features - DataCloud::ToString of the accumulated maximum vector.
    //
    // The first non-empty parsed vector seeds the accumulator and fixes its
    // length. Later vectors are merged over the common prefix only: extra
    // trailing elements of a longer vector are ignored, and positions missing
    // from a shorter vector keep the accumulated value.
    // NOTE(review): DataCloud::FromString is assumed to return a float vector
    // parsed from the serialized features - confirm against its declaration.
    void Do(TReader* reader, TWriter* writer) override {
        TVector<float> resultVector;
        TString key;
        for (; reader->IsValid(); reader->Next()) {
            const auto& row = reader->GetRow();
            auto data = DataCloud::FromString(row.GetFeatures());
            key = row.GetKey();
            if (resultVector.empty()) {
                // Nothing accumulated yet: adopt this vector as the seed.
                resultVector = std::move(data);
                continue;
            }
            // Clamp to the shorter of the two vectors so that a row with fewer
            // features than the accumulator never causes an out-of-bounds read.
            const size_t commonSize = std::min(resultVector.size(), data.size());
            for (size_t idx = 0; idx < commonSize; ++idx) {
                resultVector[idx] = std::max(resultVector[idx], data[idx]);
            }
        }
        TDSSMRecord outRow;
        outRow.SetKey(key);
        outRow.SetFeatures(DataCloud::ToString(resultVector));
        writer->AddRow(outRow);
    }
};
REGISTER_REDUCER(TFastDSSMMaxReducer);


// Runs a MapReduce operation that groups `inputTable` by the "key" column and
// applies TFastDSSMMaxReducer as both the reduce combiner and the reducer.
//
//   ytToken     - token passed to the YT client options.
//   cluster     - cluster name the client is created for.
//   inputTable  - path of the table read as TDSSMRecord rows.
//   outputTable - path of the table written as TDSSMRecord rows; it is given
//                 a schema with two string columns, "key" and "features".
void FastDSSMStep5(
        const TString& ytToken,
        const TString& cluster,
        const TString& inputTable,
        const TString& outputTable) {
    const auto clientOptions = NYT::TCreateClientOptions().Token(ytToken);
    auto client = NYT::CreateClient(cluster, clientOptions);

    const auto keyColumn = TColumnSchema().Type(VT_STRING).Name("key");
    const auto featuresColumn = TColumnSchema().Type(VT_STRING).Name("features");
    const auto outputSchema = TTableSchema()
        .AddColumn(keyColumn)
        .AddColumn(featuresColumn);

    const auto destination = TRichYPath(outputTable).Schema(outputSchema);
    const auto spec = TMapReduceOperationSpec()
        .ReduceBy({"key"})
        .AddInput<TDSSMRecord>(inputTable)
        .AddOutput<TDSSMRecord>(destination);

    client->MapReduce(
        spec,
        nullptr,                   // no mapper stage
        new TFastDSSMMaxReducer,   // reduce combiner
        new TFastDSSMMaxReducer);  // reducer
}

// Reducer that keeps only the first row it is handed and drops the rest,
// clearing the `yuids` field of the row it emits.
class TUniqueHashReducer : public IReducer<
    TTableReader<TDSSMYuid2TitleUrlRecord>,
    TTableWriter<TDSSMYuid2TitleUrlRecord> >
{
public:
    // Writes at most one row: a copy of the first row available from `reader`
    // with `yuids` set to the empty string. Writes nothing if the reader has
    // no rows. The reader is not advanced past the first row.
    void Do(TReader* reader, TWriter* writer) override {
        if (!reader->IsValid()) {
            return;
        }
        auto firstRow = reader->GetRow();
        // DO NOT REMOVE: setting yuids is required for the DSSM builder script.
        firstRow.SetYuids("");
        writer->AddRow(firstRow);
    }
};
REGISTER_REDUCER(TUniqueHashReducer);

// Runs a MapReduce operation that groups `inputTable` by the "hash" column
// and applies TUniqueHashReducer as both the reduce combiner and the reducer.
//
//   ytToken     - token passed to the YT client options.
//   cluster     - cluster name the client is created for.
//   inputTable  - path of the table read as TDSSMYuid2TitleUrlRecord rows.
//   outputTable - path of the table written as TDSSMYuid2TitleUrlRecord rows;
//                 it is given a schema with four string columns: "title",
//                 "url", "yuids" and "hash".
void FastUniqueHashReduce(
        const TString& ytToken,
        const TString& cluster,
        const TString& inputTable,
        const TString& outputTable) {
    const auto clientOptions = NYT::TCreateClientOptions().Token(ytToken);
    auto client = NYT::CreateClient(cluster, clientOptions);

    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("title"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("url"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("yuids"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("hash"));

    const auto destination = TRichYPath(outputTable).Schema(outputSchema);
    const auto spec = TMapReduceOperationSpec()
        .ReduceBy({"hash"})
        .AddInput<TDSSMYuid2TitleUrlRecord>(inputTable)
        .AddOutput<TDSSMYuid2TitleUrlRecord>(destination);

    client->MapReduce(
        spec,
        nullptr,                  // no mapper stage
        new TUniqueHashReducer,   // reduce combiner
        new TUniqueHashReducer);  // reducer
}

// Join reducer that attaches the feature vector read from input table 0
// (THashVectorRecord rows) to every key read from the other input table
// (THash2KeyRecord rows), emitting TDSSMRecord rows.
class TJoinCidsReducer : public IReducer<
    TTableReader<::google::protobuf::Message>,
    TTableWriter<TDSSMRecord> >
{
public:
    // Processes the rows available from `reader` in order:
    //   * a row from table index 0 stores its vector as the current features
    //     (a later such row overwrites an earlier one);
    //   * a row from any other table produces one output row carrying that
    //     row's key together with the current features.
    // Processing stops at the first non-zero-table row that arrives either
    // before any features were seen or after 10,000,000 rows were emitted.
    void Do(TReader* reader, TWriter* writer) override
    {
        const i64 emitLimit = 10000000;
        i64 emitted = 0;
        bool featuresSeen = false;
        TDSSMRecord joined;
        for (; reader->IsValid(); reader->Next()) {
            if (reader->GetTableIndex() == 0) {
                const auto& vectorRow = reader->GetRow<THashVectorRecord>();
                joined.SetFeatures(vectorRow.GetVector());
                featuresSeen = true;
                continue;
            }
            // Nothing to join with, or the per-call output cap was reached.
            if (!featuresSeen || emitted >= emitLimit) {
                break;
            }
            ++emitted;
            joined.SetKey(reader->GetRow<THash2KeyRecord>().GetKey());
            writer->AddRow(joined);
        }
    }
};
REGISTER_REDUCER(TJoinCidsReducer);

// Runs a JoinReduce operation that joins two tables by the "hash" column
// using TJoinCidsReducer.
//
//   ytToken            - token passed to the YT client options.
//   cluster            - cluster name the client is created for.
//   inputDotTable      - path of the first input, read as THashVectorRecord
//                        rows and marked as a foreign table.
//   inputHash2KeyTable - path of the second input, read as THash2KeyRecord rows.
//   outputTable        - path of the table written as TDSSMRecord rows; it is
//                        given a schema with two string columns, "key" and
//                        "features".
void FastJoinCidsReduce(
        const TString& ytToken,
        const TString& cluster,
        const TString& inputDotTable,
        const TString& inputHash2KeyTable,
        const TString& outputTable) {
    const auto clientOptions = NYT::TCreateClientOptions().Token(ytToken);
    auto client = NYT::CreateClient(cluster, clientOptions);

    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("key"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("features"));

    // Added first, so the reducer sees these rows under table index 0.
    const auto foreignInput = TRichYPath(inputDotTable).Foreign(true);
    const auto destination = TRichYPath(outputTable).Schema(outputSchema);

    const auto spec = TJoinReduceOperationSpec()
        .JoinBy({"hash"})
        .AddInput<THashVectorRecord>(foreignInput)
        .AddInput<THash2KeyRecord>(inputHash2KeyTable)
        .AddOutput<TDSSMRecord>(destination);

    client->JoinReduce(spec, new TJoinCidsReducer);
}

// Mapper that replaces each row's vector with the list of its dot products
// against a fixed set of vectors supplied at construction time.
// The vector set is part of the job state saved/loaded via Y_SAVELOAD_JOB.
class TDotProductMapper : public IMapper<
     TTableReader<THashVectorRecord>,
     TTableWriter<THashVectorRecord> >
{
public:
    Y_SAVELOAD_JOB(query2Vectors_);
    TDotProductMapper() = default;
    // Takes ownership of `query2Vectors`, the vectors each row is multiplied with.
    TDotProductMapper(TVector<TVector<float> > query2Vectors) :
        query2Vectors_(std::move(query2Vectors)) {}

    // For every input row: parses the row's vector field as space-separated
    // floats, computes one dot product per stored vector, and writes the row
    // back with the vector field replaced by DataCloud::ToString of those
    // products. All other fields of the row are passed through unchanged.
    //
    // Each dot product runs over the common prefix of the two operands, so a
    // stored vector shorter than the row's vector is never read past its end.
    void Do(TReader* reader, TWriter* writer) override
    {
        for (; reader->IsValid(); reader->Next()) {
            auto record = reader->GetRow();
            TVector<float> vector;
            StringSplitter(record.GetVector()).Split(' ').ParseInto(&vector);
            TVector<float> dotProduct(query2Vectors_.size(), 0);
            for (size_t ii = 0; ii < query2Vectors_.size(); ++ii) {
                const auto& center = query2Vectors_[ii];
                // Clamp to the shorter operand to avoid an out-of-bounds read
                // when dimensions disagree.
                const size_t length = std::min(vector.size(), center.size());
                dotProduct[ii] = DotProduct(
                    vector.cbegin(),
                    center.cbegin(),
                    length
                );
            }
            record.SetVector(DataCloud::ToString(dotProduct));
            writer->AddRow(record);
        }
    }
private:
    // Vectors every input row is multiplied with; one output value per entry.
    TVector<TVector<float> > query2Vectors_;
};
REGISTER_MAPPER(TDotProductMapper);

// Runs a Map operation that applies TDotProductMapper, constructed from
// `clusterCenters`, to every row of `inputTable`.
//
//   ytToken        - token passed to the YT client options.
//   cluster        - cluster name the client is created for.
//   clusterCenters - vectors handed to the mapper; copied into the job object.
//   inputTable     - path of the table read as THashVectorRecord rows.
//   outputTable    - path of the table written as THashVectorRecord rows; it
//                    is given a schema with two string columns, "hash" and
//                    "vector".
void FastDotProduct(
        const TString& ytToken,
        const TString& cluster,
        const TVector<TVector<float> >& clusterCenters,
        const TString& inputTable,
        const TString& outputTable) {
    const auto clientOptions = NYT::TCreateClientOptions().Token(ytToken);
    auto client = NYT::CreateClient(cluster, clientOptions);

    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("hash"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("vector"));

    const auto destination = TRichYPath(outputTable).Schema(outputSchema);
    const auto spec = TMapOperationSpec()
        .AddInput<THashVectorRecord>(inputTable)
        .AddOutput<THashVectorRecord>(destination);

    client->Map(spec, new TDotProductMapper(clusterCenters));
}

}  // namespace DSSM
}  // namespace Features
}  // namespace Datacloud
