# -*- coding: utf-8 -*-
from collections import namedtuple

import numpy as np
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.data.data_utils import array_fromstring

# Default column names read from each table row; they are the default values
# of the target_name, features_column and external_id_column arguments of
# load_table_from_yt.
DEFAULT_TARGET_COL = 'target'
DEFAULT_FEATURES_COL = 'features'
DEFAULT_IDS_COL = 'external_id'

# Result container of load_table_from_yt: the labeled part (X, y, ids)
# followed by the unlabeled part (X_test, ids_test).
LOAD_TABLE_RESULT = namedtuple(
    'LOAD_TABLE_RESULT',
    ['X', 'y', 'ids', 'X_test', 'ids_test']
)


def load_table_from_yt(
        yt_client,
        table_path,
        target_name=DEFAULT_TARGET_COL,
        features_column=DEFAULT_FEATURES_COL,
        external_id_column=DEFAULT_IDS_COL,
        verbose=None,
        logger=None,
        max_rows=None):
    """Read a table and split its rows into labeled and unlabeled samples.

    Rows are taken from ``yt_client.read_table(table_path)``. A row is
    labeled when ``target_name`` is given and the row's target value is not
    None and ``>= 0``; its label is 1 when the target is ``>= 1`` and 0
    otherwise. Every other row (including all rows when ``target_name`` is
    None) goes to the unlabeled ("test") part.

    Args:
        yt_client: object exposing ``read_table(table_path)`` that yields
            rows indexable by column name.
        table_path: table path passed through to ``read_table``.
        target_name: name of the target column, or None to treat every row
            as unlabeled.
        features_column: name of the column holding the features. The value
            is either a ``str`` (decoded with ``array_fromstring``) or a
            ``list`` (used as is); which one is decided from the first row
            and then applied to all rows.
        external_id_column: name of the column holding the row identifier.
        verbose: if truthy, log progress every ``verbose`` rows.
        logger: logger used for progress messages; a basic logger is created
            when omitted.
        max_rows: if not None, stop after this many rows.

    Returns:
        LOAD_TABLE_RESULT with numpy arrays ``X``, ``y``, ``ids`` (labeled
        rows) and ``X_test``, ``ids_test`` (unlabeled rows).

    Raises:
        ValueError: if the features value of the first row is neither a
            ``str`` nor a ``list``.
    """
    logger = logger or get_basic_logger(__name__)
    X, X_test = [], []
    ids, ids_test = [], []
    y = []

    # Detected once on the first row; None means "not decided yet".
    should_decode = None

    for i, row in enumerate(yt_client.read_table(table_path)):
        if max_rows is not None and i >= max_rows:
            break

        raw_features = row[features_column]

        if should_decode is None:
            # isinstance (not an exact type() match) so that str/list
            # subclasses are handled like their base types.
            if isinstance(raw_features, str):
                should_decode = True
            elif isinstance(raw_features, list):
                should_decode = False
            else:
                raise ValueError('Bad features column type: {}'.format(type(raw_features)))

        if should_decode:
            features = array_fromstring(raw_features)
        else:
            features = raw_features

        external_id = row[external_id_column]

        # The target column is only read when a target name was requested.
        target = row[target_name] if target_name is not None else None

        # A missing (None) target marks an unlabeled row. The explicit check
        # avoids the TypeError that `None >= 0` raises on Python 3 and keeps
        # the Python 2 outcome, where that comparison is simply False.
        if target is not None and target >= 0:
            X.append(features)
            ids.append(external_id)
            y.append(1 if target >= 1 else 0)
        else:
            X_test.append(features)
            ids_test.append(external_id)

        if verbose and (i + 1) % verbose == 0:
            logger.info('Loaded {} rows'.format(i + 1))

    return LOAD_TABLE_RESULT(
        X=np.array(X),
        y=np.array(y),
        ids=np.array(ids),
        X_test=np.array(X_test),
        ids_test=np.array(ids_test)
    )
