//
// Created by luckybug on 17.07.18.
//
#include <library/cpp/hnsw/index_builder/dense_vector_distance.h>
#include <library/cpp/kmeans_hnsw/clustering_data.h>
#include <library/cpp/kmeans_hnsw/kmeans_clustering.h>
#include <library/cpp/kmeans_hnsw/dense_vector_centroids_factory.h>
#include <mapreduce/yt/interface/init.h>
#include <mapreduce/yt/interface/client.h>

#include "distance_types.h"

// Scalar type of every vector coordinate read from the source table.
using TVectorComponent = float;


// Dot-product based distance over TVectorComponent vectors, wrapped together with
// the vector dimension (main constructs it as TDotDistance(TBaseDotDistance(), dimension)).
using TBaseDotDistance = NHnsw::TDotProduct<TVectorComponent>;
using TDotDistance = NHnsw::TDistanceWithDimension<TVectorComponent, TBaseDotDistance>;

// "Euclid" distance: presumably squared L2 (per the TL2SqrDistance name), wrapped
// with the vector dimension in the same way as TDotDistance.
using TBaseEuclidDistance = NHnsw::TL2SqrDistance<TVectorComponent>;
using TEuclidDistance = NHnsw::TDistanceWithDimension<TVectorComponent, TBaseEuclidDistance>;

// Dense storage for the input vectors, and the k-means result type built over it.
using TItemStorage = NHnsw::TDenseVectorStorage<TVectorComponent>;
using TClustering = NKmeansHnsw::TClusteringData<TItemStorage>;

// Builds the k-means options used by this tool: fixed iteration count, progress
// reporting, and the HNSW build settings nested in HnswBuildOptions.
//
// numClusters - requested number of clusters; must be positive.
// numThreads  - thread count used both for k-means and for the HNSW build.
//
// Throws yexception when numClusters is zero: the MaxNeighbors computation below
// uses numClusters - 1, which would otherwise wrap around to a huge value.
// NOTE(review): numClusters == 1 yields MaxNeighbors == 0; confirm the HNSW
// builder accepts that.
NKmeansHnsw::TBuildKmeansClustersOptions GetBuildKmeansClustersOptions(const size_t numClusters, const size_t numThreads) {
    if (numClusters == 0) {
        ythrow yexception() << "numClusters must be positive";
    }

    NKmeansHnsw::TBuildKmeansClustersOptions kmeansClustersOptions = {};
    kmeansClustersOptions.ReportProgress = true;
    kmeansClustersOptions.NumIterations = 16;
    kmeansClustersOptions.HnswNeighborhoodSize = 16;
    kmeansClustersOptions.NumClusters = numClusters;
    kmeansClustersOptions.NumThreads = numThreads;
    kmeansClustersOptions.NumVectorsForEmptyClusters = 3;

    NHnsw::THnswBuildOptions& hnswOpts = kmeansClustersOptions.HnswBuildOptions;

    hnswOpts.NumThreads = numThreads;
    // Capped at numClusters - 1, presumably because the HNSW index is built over the
    // cluster centroids, so a node has at most numClusters - 1 possible neighbors.
    // The clamp happens in size_t so the value is never truncated before the Min.
    hnswOpts.MaxNeighbors = static_cast<ui32>(Min<size_t>(32, numClusters - 1));
    hnswOpts.BatchSize = 1000;
    hnswOpts.UpperLevelBatchSize = 40000;
    hnswOpts.SearchNeighborhoodSize = 300;
    hnswOpts.NumExactCandidates = 100;
    hnswOpts.LevelSizeDecay = Max<ui32>(static_cast<ui32>(hnswOpts.MaxNeighbors / 2), 2);
    hnswOpts.ReportProgress = true;

    return kmeansClustersOptions;
}

// Loads every row of `sourceTable` into memory through `client`.
//
// A table reader is opened first. The table's @row_count attribute is then read,
// only to pre-size the result vector. Rows are returned in the order the reader
// yields them.
TVector<NYT::TNode> ReadTable(const TString & sourceTable, const NYT::IClientPtr& client) {
    const auto tableReader = client->CreateTableReader<NYT::TNode>(sourceTable);

    const NYT::TNode rowCountAttribute = client->Get(sourceTable + "/@row_count");
    const auto expectedRows = static_cast<size_t>(rowCountAttribute.AsInt64());

    TVector<NYT::TNode> rows(Reserve(expectedRows));
    while (tableReader->IsValid()) {
        rows.emplace_back(tableReader->MoveRow());
        tableReader->Next();
    }

    return rows;
}


int main(int argc, const char * argv[]) {
    NYT::Initialize(argc, argv);

    if(argc < 7) {
        Cerr << "Usage: " << argv[0] << " srcTable dstClusteredTable dstClustersTable coordinateFieldName dimension distanceType(Dot|Euclid clustersCount)" << Endl;
        return -1;
    }

    argc--;
    argv++;

    const TString srcTable = argv[0];   //"//home/so_fml/luckybug/mails_coorinates"
    const TString dstClusteredTable = argv[1];   //"//home/so_fml/luckybug/clustered_mails"
    const TString dstClustersTable = argv[2];   //"//home/so_fml/luckybug/mails_clusters"
    const TString coordinateFieldName = argv[3];   //"coordinate"
    const auto dimension = FromString<size_t>(argv[4]);   //100
    const auto distanceType = FromString<TDistanceType>(argv[5]);   //100
    const auto clustersCount = FromString<size_t>(argv[6]);   //100


    auto client = NYT::CreateClient("hahn");
    auto nodes = ReadTable(srcTable, client);

    TVector<TVectorComponent> coordinates(Reserve(dimension*nodes.size()));

    for(const auto & node : nodes) {
        for(const auto & v : node[coordinateFieldName].AsList()) {
            switch(v.GetType()) {

                case NYT::TNode::Int64:
                    coordinates.emplace_back(v.AsInt64());
                    break;
                case NYT::TNode::Uint64:
                    coordinates.emplace_back(v.AsUint64());
                    break;
                case NYT::TNode::Double:
                    coordinates.emplace_back(v.AsDouble());
                    break;
                case NYT::TNode::Undefined:
                case NYT::TNode::String:
                case NYT::TNode::Bool:
                case NYT::TNode::List:
                case NYT::TNode::Map:
                case NYT::TNode::Null:
                    ythrow yexception() << v.GetType();
            }
        }
    }

    const TItemStorage vectors(TBlob::NoCopy(coordinates.data(), coordinates.size() * sizeof(TVectorComponent)), dimension);

    NKmeansHnsw::TBuildKmeansClustersOptions kmeansClustersOptions = GetBuildKmeansClustersOptions(clustersCount, 16);

    const TClustering clustering = [distanceType, &vectors, &kmeansClustersOptions, dimension] {
        switch (distanceType) {

            case TDistanceType::Dot:
                return NKmeansHnsw::BuildKmeansClusters<TDotDistance>(
                        vectors, NKmeansHnsw::TDenseVectorCentroidsFactory<TVectorComponent>(dimension),
                        kmeansClustersOptions, TDotDistance(TBaseDotDistance(), dimension));
            case TDistanceType::Euclid:
                return NKmeansHnsw::BuildKmeansClusters<TEuclidDistance>(
                        vectors, NKmeansHnsw::TDenseVectorCentroidsFactory<TVectorComponent>(dimension),
                        kmeansClustersOptions, TEuclidDistance(TBaseEuclidDistance(), dimension));
        }

    }();



    {
        auto writer = client->CreateTableWriter<NYT::TNode>(dstClustersTable);

        for (size_t clusterId = 0; clusterId < clustering.ClusterIds.size(); ++clusterId) {
            const TVector<ui32>& clusterIds = clustering.ClusterIds[clusterId];

            NYT::TNode newNode;
            newNode("cluster", clusterId);

            auto c = NYT::TNode::CreateList();
            for (size_t i = 0; i < dimension; ++i)
                c.Add(clustering.Clusters.GetItem(clusterId)[i]);

            newNode("coordinate", c);

            for (auto id : clusterIds)
                nodes[id]("cluster", clusterId);

            writer->AddRow(newNode);
        }
    }
    {
        auto writer = client->CreateTableWriter<NYT::TNode>(dstClusteredTable);

        for(const auto & node : nodes) {
            writer->AddRow(node);
        }
    }
}
