# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from datacloud.dev_utils.data.data_utils import array_fromstring
from datacloud.dev_utils.yt.yt_utils import get_yt_client
from tqdm import tqdm


def load_data_from_yt(table, target_column_name='target', yt_cluster=None, enable_read_parallel=False):
    """Read a YT table into numpy arrays of ids, features and targets.

    Every record must provide the ``'features'`` column (decoded with
    ``array_fromstring``) and the ``'external_id'`` column.

    Args:
        table: Table path passed to the YT client.
        target_column_name: Column to read targets from. When ``None`` no
            targets are read and the returned ``targets`` array is empty.
        yt_cluster: Optional cluster passed to ``get_yt_client``; when falsy
            the client is created with its defaults.
        enable_read_parallel: Forwarded to ``yt_client.read_table``.

    Returns:
        Tuple ``(external_ids, features, targets)`` of numpy arrays.
        ``features`` has one row per record; for an empty table it is an
        empty ``(0, 0)`` array.
    """
    yt_client = get_yt_client(yt_cluster) if yt_cluster else get_yt_client()

    features = []
    targets = []
    external_ids = []
    total_recs = yt_client.row_count(table)
    recs = yt_client.read_table(table, enable_read_parallel=enable_read_parallel)
    # Iterate the stream itself; the row count is only a progress-bar hint.
    # Driving the loop by the count would raise StopIteration or silently
    # drop rows if the table changed between the two calls.
    for rec in tqdm(recs, total=total_recs):
        features.append(array_fromstring(rec['features']))
        if target_column_name is not None:
            targets.append(rec[target_column_name])
        external_ids.append(rec['external_id'])

    print('Data fully loaded!')
    targets = np.array(targets)
    # np.vstack raises ValueError on an empty sequence, so handle the
    # empty-table case explicitly.
    features = np.vstack(features) if features else np.empty((0, 0))
    external_ids = np.array(external_ids)
    return external_ids, features, targets


def load_test_train(table, target_column_name='target', yt_cluster=None, enable_read_parallel=False):
    """Load a YT table and split it into train and test parts by target value.

    Rows whose target equals ``-1`` form the test part; all other rows form
    the train part. (Presumably ``-1`` marks rows without a known label --
    confirm against the table producer.)

    Args:
        table: Table path passed to ``load_data_from_yt``.
        target_column_name: Column holding the targets. It should not be
            ``None``: the split masks are computed from the targets.
        yt_cluster: Optional cluster forwarded to ``load_data_from_yt``.
        enable_read_parallel: Forwarded to ``load_data_from_yt``.

    Returns:
        Tuple ``(external_ids_train, X_train, y_train,
        external_ids_test, X_test, y_test)``.
    """
    external_ids, features, targets = load_data_from_yt(table, target_column_name, yt_cluster, enable_read_parallel)
    ix_test = targets == -1
    ix_train = ~ix_test

    external_ids_train = external_ids[ix_train]
    X_train = features[ix_train]
    y_train = targets[ix_train]

    external_ids_test = external_ids[ix_test]
    X_test = features[ix_test]
    # Test targets come from ``targets``; the feature matrix was returned
    # here before by mistake.
    y_test = targets[ix_test]
    return external_ids_train, X_train, y_train, external_ids_test, X_test, y_test


def load_features_to_pandas(table, target_column_name='target', yt_cluster=None, enable_read_parallel=False):
    """Load a YT table into a pandas DataFrame.

    The resulting columns are, in order: ``'external_id'``, the target column
    (named ``target_column_name``; omitted when it is ``None``) and one
    column per feature, named with the integers ``0 .. n_features - 1``.

    Args:
        table: Table path passed to ``load_data_from_yt``.
        target_column_name: Column holding the targets, or ``None`` to load
            ids and features only.
        yt_cluster: Optional cluster forwarded to ``load_data_from_yt``.
        enable_read_parallel: Forwarded to ``load_data_from_yt``.

    Returns:
        ``pandas.DataFrame`` with one row per table record.
    """
    external_ids, features, targets = load_data_from_yt(table, target_column_name, yt_cluster, enable_read_parallel)
    # Assemble the frame column by column. Stacking ids, targets and features
    # into one array would force a single common dtype on every column
    # (e.g. numeric features coerced to match textual ids).
    df = pd.DataFrame(features, columns=list(range(features.shape[1])))
    if target_column_name is not None:
        df.insert(0, target_column_name, targets, allow_duplicates=True)
    df.insert(0, 'external_id', external_ids, allow_duplicates=True)
    return df
