import numpy as np
import datacloud.dev_utils.data.data_utils as du
import yt.wrapper as yt_wrapper


class MaxReducer:
    """YT reducer: element-wise maximum of the ``dot_array`` vectors of one key.

    Each input row's ``dot_array`` is decoded with ``du.array_fromstring``.
    Decoded arrays are collapsed with ``np.amax(..., axis=0)`` in batches of
    up to 10000 rows, carrying the running maximum between batches, so at most
    one batch of decoded arrays is held in memory at a time. A single row is
    emitted per key, with the maximum re-encoded by ``du.array_tostring``
    under the ``'features'`` column.
    """

    def __init__(self, ext_id_key):
        # Name of the reduce-by column; its value is copied into the output row.
        self.ext_id_key = ext_id_key

    @staticmethod
    def _collapse(batch, previous):
        """Return the element-wise max of ``batch`` and ``previous`` (if any).

        Note: ``previous`` is appended to ``batch`` in place.
        """
        if previous is not None:
            batch.append(previous)
        return np.amax(batch, axis=0)

    def __call__(self, key, recs):
        batch_limit = 10000
        running_max = None
        pending = []
        for record in recs:
            pending.append(du.array_fromstring(record['dot_array']))
            if len(pending) < batch_limit:
                continue
            # Batch is full: fold it into the running maximum and start anew.
            running_max = self._collapse(pending, running_max)
            pending = []
        if pending:
            # Fold the final, partially filled batch.
            running_max = self._collapse(pending, running_max)
        # NOTE(review): if `recs` were empty, running_max stays None and is
        # passed to du.array_tostring as-is; YT should never call a reducer
        # with an empty group, but confirm.
        yield {
            self.ext_id_key: key[self.ext_id_key],
            'features': du.array_tostring(running_max),
        }


class MaxFeatureReducer:
    """YT reducer that merges the ``features`` vectors sharing one key.

    Every input row's ``features`` value is decoded with
    ``du.array_fromstring``. The element-wise maximum (``np.amax`` over
    axis 0) of all decoded arrays for the key is re-encoded with
    ``du.array_tostring`` and emitted as a single row.
    """

    def __init__(self, ext_id_key):
        # Reduce-by column; its value is echoed into the output row.
        self.ext_id_key = ext_id_key

    def __call__(self, key, recs):
        # All rows of the group are decoded and held in memory at once.
        decoded = [du.array_fromstring(row['features']) for row in recs]
        merged = np.amax(decoded, axis=0)
        output_row = {
            self.ext_id_key: key[self.ext_id_key],
            'features': du.array_tostring(merged),
        }
        yield output_row


def join_dssm_scores(processor, tables_to_take=25):  # 25 weeks - 175 days
    """Merge the most recent weekly feature tables into ``config.ready_table``.

    Lists ``config.weekly_dir`` and keeps the tables whose names compare
    ``<=`` ``config.date_str``. It then takes the ``tables_to_take`` greatest
    paths and reduces them by ``config.ext_id_key`` with
    ``MaxFeatureReducer`` (element-wise maximum of ``features``). The result
    is written to ``config.ready_table`` and sorted by ``config.ext_id_key``.
    Finally it calls ``processor.collect_garbage()``.

    :param processor: object exposing ``yt_client``, ``config`` and
        ``collect_garbage()``.
    :param tables_to_take: maximum number of weekly tables to merge; fewer
        are used if fewer eligible tables exist.
    """
    yt_client, config = processor.yt_client, processor.config
    tables = []
    for table in yt_client.list(config.weekly_dir):
        # Plain string comparison: this only selects/orders by date correctly
        # if table names are dates in the same format as `config.date_str`.
        # TODO: Add assertion that `table` contains a date in its name
        if table <= config.date_str:
            tables.append(yt_wrapper.ypath_join(config.weekly_dir, table))
    # All paths share the `weekly_dir` prefix, so this orders by table name,
    # newest (greatest) first.
    tables.sort(reverse=True)
    tables = tables[:tables_to_take]
    # NOTE(review): if no table qualifies, `tables` is empty and the behavior
    # of run_reduce on an empty input list is left to yt.wrapper — confirm.

    yt_client.run_reduce(
        MaxFeatureReducer(config.ext_id_key),
        tables,
        config.ready_table,
        reduce_by=config.ext_id_key,
        spec=dict(
            # Report how many tables are actually joined, which can be fewer
            # than `tables_to_take` when the directory has fewer eligible ones.
            title='[{}] Join scores for {} tables'.format(config.tag, len(tables)),
            **config.cloud_nodes_spec
        )
    )
    yt_client.run_sort(
        config.ready_table,
        sort_by=config.ext_id_key,
        spec=dict(
            title='[{}] Sort after join scores'.format(config.tag),
            **config.cloud_nodes_spec
        )
    )
    processor.collect_garbage()
