import sys
from library.python.nyt import client as nyt_client
nyt_client.initialize(sys.argv)
import os
import yt.wrapper as yt_wrapper
from datacloud.ml_utils.kmeans.lib import kmeans
from datacloud.dev_utils.yt import yt_utils


def main(n_clusters=1000, n_dims=600, n_iters=1000, predict_every=1,
         kmeans_folder='//projects/scoring/tmp/re9ulusv/clust-tcs-dssm',
         min_val=-1, max_val=1):
    """Run an iterative k-means clustering job over YT tables.

    Initializes ``n_clusters`` cluster centers in ``<kmeans_folder>/centers``
    via ``kmeans.init_centers``, then for ``n_iters`` iterations calls
    ``kmeans.fit`` and, every ``predict_every``-th iteration,
    ``kmeans.predict`` on the ``<kmeans_folder>/data`` table.

    Every parameter defaults to the value that used to be hard-coded in this
    script, so ``main()`` with no arguments behaves as before.

    Args:
        n_clusters: number of clusters (k).
        n_dims: dimensionality passed to ``kmeans.init_centers``; presumably
            must match the vector length stored in the data table — TODO
            confirm.
        n_iters: number of fit iterations to run.
        predict_every: call ``kmeans.predict`` after every
            ``predict_every``-th fit iteration; must be >= 1.
        kmeans_folder: YT folder containing the ``centers`` and ``data``
            tables.
        min_val, max_val: range passed to ``kmeans.init_centers``
            (presumably the bounds of the initial center coordinates —
            TODO confirm).

    Raises:
        ValueError: if ``predict_every`` is less than 1.
        RuntimeError: if no YT token is found in the ``yt.wrapper`` config
            or in the ``YT_TOKEN`` environment variable.
    """
    if predict_every < 1:
        # A zero interval would make the modulo below raise ZeroDivisionError.
        raise ValueError(
            'predict_every must be >= 1, got {}'.format(predict_every))

    # The token from the yt.wrapper config wins; the environment is a fallback.
    yt_token = yt_wrapper.config['token'] or os.environ.get('YT_TOKEN')
    # Explicit check rather than ``assert``: asserts are stripped under -O,
    # which would let the job start without credentials.
    if not yt_token:
        raise RuntimeError('No YT_TOKEN provided')

    yt_client = yt_utils.get_yt_client()

    centers_path = yt_wrapper.ypath_join(kmeans_folder, 'centers')
    data_path = yt_wrapper.ypath_join(kmeans_folder, 'data')

    # NOTE(review): this runs unconditionally, so each launch presumably
    # restarts from freshly initialized centers — confirm in kmeans.init_centers.
    print('Init centers!')
    kmeans.init_centers(yt_client, centers_path, n_clusters, n_dims,
                        min_val, max_val)
    print('Done')

    for i in range(n_iters):
        print('Iter {}'.format(i))
        kmeans.fit(yt_token, yt_client, n_clusters, data_path, centers_path)
        if (i + 1) % predict_every == 0:
            # data_path is passed as both the input and the output table, so
            # assignments are presumably rewritten in place — TODO confirm.
            changed = kmeans.predict(yt_token, yt_client, n_clusters,
                                     centers_path, data_path, data_path)
            print('Changed ', changed)


if __name__ == '__main__':
    main()
