#include <util/generic/size_literals.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/common/config.h>
#include <datacloud/dev_utils/data/custom_data_utils.h>
#include <datacloud/ml_utils/kmeans/lib/data.pb.h>
#include <datacloud/dev_utils/data/custom_data_utils.h>
#include <library/cpp/l2_distance/l2_distance.h>

#include <limits>
#include <stdexcept>
#include <utility>


using namespace NYT;

namespace Datacloud {
namespace MLUtils {
namespace KMeans {

/// Mapper for the k-means assignment step that re-emits the objects themselves.
///
/// Start() loads the current centers from the `centers_table` file attached to
/// the job (read as a TNode table with `cluster` and `data` columns). Do() sets
/// each TObjectRecord's cluster to the id of the center with the smallest
/// squared L2 distance and writes the record back out. The number of records
/// whose cluster changed is reported in the `kmeans-stat/clusters-changed`
/// custom statistic.
class TKMeansPrediction : public IMapper <
    TTableReader<TObjectRecord>,
    TTableWriter<TObjectRecord> >
{
public:
    Y_SAVELOAD_JOB(NClusters_);

    TKMeansPrediction() = default;

    TKMeansPrediction(i64 nClusters) : NClusters_(nClusters) { }

    void Start(TWriter*) override {
        TFileInput tableDump("centers_table");
        auto reader = CreateTableReader<TNode>(&tableDump);

        for (; reader->IsValid(); reader->Next()) {
            const auto& row = reader->GetRow();
            const auto cluster = row["cluster"].AsInt64();
            Centers_[cluster] = DataCloud::FromString(row["data"].AsString());
        }
    }

    void Do(TReader* reader, TWriter* writer) override {
        i64 clustersChanged = 0;
        for (; reader->IsValid(); reader->Next()) {
            auto record = reader->GetRow();
            const auto data = DataCloud::FromString(record.GetData());
            const i64 closestCluster = FindClosestCluster(data);
            if (record.GetCluster() != closestCluster) {
                ++clustersChanged;
            }
            record.SetCluster(closestCluster);
            writer->AddRow(record);
        }
        // Report once per Do() call instead of once per changed row. The summed
        // value is the same. Skipping the write when nothing changed keeps the
        // statistic absent, exactly as the per-row version did.
        if (clustersChanged > 0) {
            WriteCustomStatistics("kmeans-stat/clusters-changed", clustersChanged);
        }
    }

private:
    // Returns the id of the center nearest to `point` by squared L2 distance.
    // Ties keep the smaller id, because TMap iterates keys in ascending order
    // and only a strictly smaller distance replaces the current best. Returns 0
    // when no centers were loaded. Throws when a center's dimension differs
    // from the point's: L2SqrDistance reads point.size() elements from the
    // center, so a shorter center would otherwise be read out of bounds.
    i64 FindClosestCluster(const TVector<float>& point) const {
        i64 closestCluster = 0;
        double minDist = std::numeric_limits<double>::infinity();
        for (const auto& clustRec : Centers_) {
            const auto& center = clustRec.second;
            if (center.size() != point.size()) {
                throw std::runtime_error(
                    "kmeans: object dimension differs from cluster center dimension");
            }
            const auto dist = L2SqrDistance(point.data(), center.data(), point.size());
            if (dist < minDist) {
                minDist = dist;
                closestCluster = clustRec.first;
            }
        }
        return closestCluster;
    }

    // Cluster id -> center coordinates, loaded once per job in Start().
    TMap<i64, TVector<float> > Centers_;
    // Serialized with the job but not read by this mapper.
    i64 NClusters_ = 0;
};
REGISTER_MAPPER(TKMeansPrediction);

/// Runs the k-means assignment step as a single YT Map operation.
///
/// Every TObjectRecord in `objectsTable` is relabelled by TKMeansPrediction
/// with the id of its nearest center and written to `outputExpectationTable`.
/// The centers come from `clusterCentersTable`, which is shipped to the jobs as
/// the `centers_table` file.
///
/// @return the summed `kmeans-stat/clusters-changed` job statistic, i.e. how
///         many objects moved to a different cluster. Returns -1 when the
///         statistic has no value; presumably this is the case when no object
///         changed cluster, since the mapper only reports changes.
i64 Expectation(  // TODO: Rename to Prediction
        const TString& ytToken,
        const TString& cluster,
        i64 nClusters,
        const TString& clusterCentersTable,
        const TString& objectsTable,
        const TString& outputExpectationTable) {
    auto client = NYT::CreateClient(
        cluster, NYT::TCreateClientOptions().Token(ytToken));

    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_INT64).Name("cluster"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("key"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("data"));

    // The centers table travels with every job as a YSON dump; the mapper
    // reads it back in Start().
    TRichYPath centersFile(clusterCentersTable);
    centersFile.Format("yson").FileName("centers_table");

    TUserJobSpec mapperSpec;
    mapperSpec.AddFile(centersFile);

    TRichYPath labelledObjects(outputExpectationTable);
    labelledObjects.Schema(outputSchema);

    TMapOperationSpec assignSpec;
    assignSpec
        .MapperSpec(mapperSpec)
        .AddInput<TObjectRecord>(objectsTable)
        .AddOutput<TObjectRecord>(labelledObjects);

    const auto operation = client->Map(
        assignSpec,
        new TKMeansPrediction(nClusters)  // TKMeansExpectation(nClusters)
    );

    auto jobStats = operation->GetJobStatistics();
    return jobStats.GetCustomStatistics("kmeans-stat/clusters-changed").Sum().GetOrElse(-1);
}


//=//=//
/// Mapper for the k-means assignment step that feeds the center update.
///
/// Start() loads the current centers from the `centers_table` file attached to
/// the job (read as a TNode table with `cluster` and `data` columns). For every
/// object, Do() finds the center with the smallest squared L2 distance. It then
/// emits a TClusterCenterRecord carrying that cluster id, the object's
/// coordinates and NPoints = 1, ready to be summed per cluster. The number of
/// objects whose cluster changed is reported in the
/// `kmeans-stat/clusters-changed` custom statistic.
class TKMeansExpectation : public IMapper <
    TTableReader<TObjectRecord>,
    TTableWriter<TClusterCenterRecord> >
{
public:
    Y_SAVELOAD_JOB(NClusters_);

    TKMeansExpectation() = default;

    TKMeansExpectation(i64 nClusters) : NClusters_(nClusters) { }

    void Start(TWriter*) override {
        TFileInput tableDump("centers_table");
        auto reader = CreateTableReader<TNode>(&tableDump);

        for (; reader->IsValid(); reader->Next()) {
            const auto& row = reader->GetRow();
            const auto cluster = row["cluster"].AsInt64();
            Centers_[cluster] = DataCloud::FromString(row["data"].AsString());
        }
    }

    void Do(TReader* reader, TWriter* writer) override {
        i64 clustersChanged = 0;
        for (; reader->IsValid(); reader->Next()) {
            const auto& record = reader->GetRow();
            const auto data = DataCloud::FromString(record.GetData());
            const i64 closestCluster = FindClosestCluster(data);
            if (record.GetCluster() != closestCluster) {
                ++clustersChanged;
            }
            TClusterCenterRecord outputRecord;
            outputRecord.SetCluster(closestCluster);
            outputRecord.SetData(record.GetData());
            outputRecord.SetNPoints(1);
            writer->AddRow(outputRecord);
        }
        // Report once per Do() call instead of once per changed row. The summed
        // value is the same. Skipping the write when nothing changed keeps the
        // statistic absent, exactly as the per-row version did.
        if (clustersChanged > 0) {
            WriteCustomStatistics("kmeans-stat/clusters-changed", clustersChanged);
        }
    }

private:
    // Returns the id of the center nearest to `point` by squared L2 distance.
    // Ties keep the smaller id, because TMap iterates keys in ascending order
    // and only a strictly smaller distance replaces the current best. Returns 0
    // when no centers were loaded. Throws when a center's dimension differs
    // from the point's: L2SqrDistance reads point.size() elements from the
    // center, so a shorter center would otherwise be read out of bounds.
    i64 FindClosestCluster(const TVector<float>& point) const {
        i64 closestCluster = 0;
        double minDist = std::numeric_limits<double>::infinity();
        for (const auto& clustRec : Centers_) {
            const auto& center = clustRec.second;
            if (center.size() != point.size()) {
                throw std::runtime_error(
                    "kmeans: object dimension differs from cluster center dimension");
            }
            const auto dist = L2SqrDistance(point.data(), center.data(), point.size());
            if (dist < minDist) {
                minDist = dist;
                closestCluster = clustRec.first;
            }
        }
        return closestCluster;
    }

    // Cluster id -> center coordinates, loaded once per job in Start().
    TMap<i64, TVector<float> > Centers_;
    // Serialized with the job but not read by this mapper.
    i64 NClusters_ = 0;
};
REGISTER_MAPPER(TKMeansExpectation);
//=//=//


/// Reducer (also used as the reduce combiner) that merges all rows sharing one
/// `cluster` value into a single TClusterCenterRecord.
///
/// The output `data` is the element-wise sum of the rows' coordinates, and the
/// output NPoints is the sum of their NPoints. Summing NPoints, rather than
/// counting rows, keeps the result correct when the input rows are themselves
/// combiner outputs. The coordinates are NOT divided by the count here;
/// ExpectationMinimization runs TKMeansNormalization afterwards for that.
class TKMeansMinimization : public IReducer<
    TTableReader<TClusterCenterRecord>,
    TTableWriter<TClusterCenterRecord>
    >
{
public:
    void Do(TReader* reader, TWriter* writer) override {
        TVector<float> coordSum;
        i64 nItems = 0;
        // Initialized so `cluster` is never read uninitialized. A reduce group
        // presumably always has at least one row, but that is not checked here.
        i64 cluster = 0;
        for (; reader->IsValid(); reader->Next()) {
            const auto& record = reader->GetRow();
            auto data = DataCloud::FromString(record.GetData());
            cluster = record.GetCluster();
            if (coordSum.empty()) {
                // Nothing accumulated yet: adopt this row's vector without copying.
                coordSum = std::move(data);
            } else {
                // Adding vectors of different lengths would index past the end
                // of the shorter one.
                if (data.size() != coordSum.size()) {
                    throw std::runtime_error(
                        "kmeans: rows of one cluster have different dimensions");
                }
                // TODO: Look for SIMD summation in Arcadia
                for (size_t idx = 0; idx < coordSum.size(); ++idx) {
                    coordSum[idx] += data[idx];
                }
            }
            nItems += record.GetNPoints();
        }
        TClusterCenterRecord outputRecord;
        outputRecord.SetCluster(cluster);
        outputRecord.SetData(DataCloud::ToString(coordSum));
        outputRecord.SetNPoints(nItems);
        writer->AddRow(outputRecord);
    }
};
REGISTER_REDUCER(TKMeansMinimization);


/// Mapper that turns per-cluster coordinate sums into cluster means.
///
/// Each input row holds the summed coordinates of NPoints objects (as written
/// by TKMeansMinimization). Every coordinate is divided by that count, and the
/// row is re-emitted with NPoints reset to 1.
/// NOTE(review): a row with NPoints == 0 produces inf/NaN coordinates; no guard.
class TKMeansNormalization : public IMapper<
    TTableReader<TClusterCenterRecord>,
    TTableWriter<TClusterCenterRecord>
    >
{
public:
    void Do(TReader* reader, TWriter* writer) override {
        while (reader->IsValid()) {
            auto center = reader->GetRow();
            auto coords = DataCloud::FromString(center.GetData());
            const auto pointCount = center.GetNPoints();
            for (auto& coord : coords) {
                coord /= pointCount;
            }
            center.SetNPoints(1);
            center.SetData(DataCloud::ToString(coords));
            writer->AddRow(center);
            reader->Next();
        }
    }
};
REGISTER_MAPPER(TKMeansNormalization);


void Minimization(
        const TString& ytToken,
        const TString& cluster,
        const TString& expectationTable,
        const TString& outputClusterCenterTable) {
    auto client = NYT::CreateClient(
        cluster, NYT::TCreateClientOptions().Token(ytToken));
    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_INT64).Name("cluster"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("data"));

    client->Reduce(
        TReduceOperationSpec()
            .ReduceBy({"cluster"})
            .AddInput<TObjectRecord>(expectationTable)
            .AddOutput<TClusterCenterRecord>(
                TRichYPath(outputClusterCenterTable)
                    .Schema(outputSchema)),
        new TKMeansMinimization
    );
}


void ExpectationMinimization(
        const TString& ytToken,
        const TString& cluster,
        i64 nClusters,
        const TString& objectsTable,
        const TString& clusterCentersTable,
        const TString& outputClusterCenterTable) {
    auto client = NYT::CreateClient(
        cluster, NYT::TCreateClientOptions().Token(ytToken));
    const auto outputSchema = TTableSchema()
        .AddColumn(TColumnSchema().Type(VT_INT64).Name("cluster"))
        .AddColumn(TColumnSchema().Type(VT_STRING).Name("data"))
        .AddColumn(TColumnSchema().Type(VT_UINT64).Name("n_points"));
    
    // TODO Use transaction
    client->MapReduce(
        TMapReduceOperationSpec()
            .DataSizePerMapJob(200_MB)
            .MapperSpec(
                TUserJobSpec()
                .AddFile(
                    TRichYPath(clusterCentersTable)
                    .Format("yson")
                    .FileName("centers_table")
            ))
            .ReduceBy({"cluster"})
            .AddInput<TObjectRecord>(objectsTable)
            .AddOutput<TClusterCenterRecord>(
                TRichYPath(outputClusterCenterTable)
                .Schema(outputSchema)),
        new TKMeansExpectation(nClusters),
        new TKMeansMinimization,  // ReduceCombiner
        new TKMeansMinimization
    );

     client->Map(
        TMapOperationSpec()
            .AddInput<TClusterCenterRecord>(outputClusterCenterTable)
            .AddOutput<TClusterCenterRecord>(
                TRichYPath(outputClusterCenterTable)
                      .Schema(outputSchema)),
        new TKMeansNormalization);
}

}  // namespace KMeans
}  // namespace MLUtils
}  // namespace Datacloud
