import uuid
import numpy as np
import yt.wrapper as yt_wrapper
from datacloud.dev_utils.data import data_utils
from datacloud.ml_utils.kmeans.lib import fast_kmeans


def upload_mnist(yt_client, table_path, path_to_mnist):
    """Load a comma-separated file and write it to a YT table as key/data rows.

    The file is parsed with ``np.genfromtxt`` (``,`` delimiter, first line
    skipped). Every parsed row produces one table row:

    * ``key``  -- the first column's value converted with ``str``, a dash and
      a freshly generated ``uuid4`` (presumably so that rows sharing the same
      first-column value still get distinct keys).
    * ``data`` -- the remaining columns passed through
      ``data_utils.array_tostring``.

    Args:
        yt_client: Client object whose ``write_table`` method receives the rows.
        table_path: Destination table path; it is wrapped in a ``TablePath``
            carrying a schema of two string columns (``key``, ``data``).
        path_to_mnist: Path of the CSV file to read.
    """
    parsed = np.genfromtxt(path_to_mnist, skip_header=1, delimiter=',')
    schema = [
        {'name': 'key', 'type': 'string'},
        {'name': 'data', 'type': 'string'}
    ]
    destination = yt_wrapper.TablePath(table_path, schema=schema)
    # NOTE: genfromtxt parses values as floats by default, so the first-column
    # part of the key is rendered the way numpy prints a float.
    records = [
        {
            'key': '-'.join((str(parsed[idx, 0]), str(uuid.uuid4()))),
            'data': data_utils.array_tostring(parsed[idx, 1:])
        }
        for idx in range(parsed.shape[0])
    ]
    yt_client.write_table(destination, records)


def init_centers(yt_client, table_path, n_centers, n_dims, min_val, max_val):
    """Write a table of randomly initialised cluster centers.

    One row is produced for each index in ``range(n_centers)``:

    * ``cluster``  -- the index itself.
    * ``data``     -- ``n_dims`` values drawn with
      ``np.random.uniform(min_val, max_val, n_dims)`` and passed through
      ``data_utils.array_tostring``.
    * ``n_points`` -- always ``1``.

    Args:
        yt_client: Client object whose ``write_table`` method receives the rows.
        table_path: Destination table path; it is wrapped in a ``TablePath``
            with the schema ``cluster: int64, data: string, n_points: uint64``.
        n_centers: Number of center rows to generate.
        n_dims: Number of coordinates drawn for every center.
        min_val: Lower bound handed to ``np.random.uniform``.
        max_val: Upper bound handed to ``np.random.uniform``.
    """
    schema = [
        {'name': 'cluster', 'type': 'int64'},
        {'name': 'data', 'type': 'string'},
        {'name': 'n_points', 'type': 'uint64'}
    ]
    destination = yt_wrapper.TablePath(table_path, schema=schema)
    # One uniform draw per center, in index order, using the global numpy RNG.
    centers = [
        {
            'cluster': cluster_id,
            'data': data_utils.array_tostring(
                np.random.uniform(min_val, max_val, n_dims)),
            'n_points': 1
        }
        for cluster_id in range(n_centers)
    ]
    yt_client.write_table(destination, centers)


def fit(yt_token, yt_client, n_clusters, data_path, centers_path):
    """Run ``fast_kmeans.expectation_minimization`` for the given tables.

    The proxy URL is taken from ``yt_client.config['proxy']['url']``.
    ``centers_path`` is passed in the last two argument positions --
    presumably as both the source of the current centers and the destination
    for the updated ones; confirm against ``fast_kmeans``.

    Args:
        yt_token: Token forwarded as the first argument.
        yt_client: Client object; only its proxy URL config entry is read.
        n_clusters: Number of clusters, forwarded unchanged.
        data_path: Path of the data table, forwarded unchanged.
        centers_path: Path of the centers table, forwarded twice.

    Returns:
        None. The result of the underlying call is not propagated.
    """
    proxy_url = yt_client.config['proxy']['url']
    fast_kmeans.expectation_minimization(
        yt_token, proxy_url, n_clusters, data_path, centers_path, centers_path)


def predict(yt_token, yt_client, n_clusters, centers_path, data_path,
            pred_path):
    """Call ``fast_kmeans.expectation`` and return whatever it returns.

    The proxy URL is taken from ``yt_client.config['proxy']['url']``; all
    other arguments are forwarded unchanged and in the order shown below.

    Args:
        yt_token: Token forwarded as the first argument.
        yt_client: Client object; only its proxy URL config entry is read.
        n_clusters: Number of clusters.
        centers_path: Path of the centers table.
        data_path: Path of the data table.
        pred_path: Path forwarded as the last argument (presumably the table
            receiving the predictions; confirm against ``fast_kmeans``).

    Returns:
        The value returned by ``fast_kmeans.expectation``.
    """
    proxy_url = yt_client.config['proxy']['url']
    result = fast_kmeans.expectation(
        yt_token, proxy_url, n_clusters, centers_path, data_path, pred_path)
    return result
