import yt.wrapper as yt_wrapper
from datacloud.config.yt import MODELS_FOLDER, UNRELIABLE_TMP_FOLDER as TMP
from datacloud.dev_utils.yt import take_part, yt_utils
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.crypta.crypta_snapshot import get_snapshot


logger = get_basic_logger(__name__)


def echo_map(rec):
    """Identity mapper: emit the incoming record exactly as received."""
    yield rec


def split_table(yt_client, table, n_parts):
    """Split ``table`` into ``n_parts`` tables of consecutive row ranges.

    Each row range is copied by an identity map operation into a new table
    named ``<table>-<k>``, where ``k`` is counted from 1.

    Every source row lands in exactly one part: when the row count is not
    divisible by ``n_parts``, the first ``row_count % n_parts`` parts receive
    one extra row. (Previously the remainder rows were silently dropped.)

    Args:
        yt_client: client object providing ``row_count`` and ``run_map``.
        table: path of the table to split, as a string.
        n_parts: number of output tables; must be a positive integer.

    Returns:
        List of the output table paths, ordered by row range.

    Raises:
        ValueError: if ``n_parts`` is less than 1.
    """
    # TODO: Test
    if n_parts < 1:
        raise ValueError('n_parts must be positive, got {}'.format(n_parts))

    n_rows = yt_client.row_count(table)
    # divmod keeps the arithmetic in integers (no float rounding on huge
    # tables) and exposes the remainder so it can be distributed.
    block_size, remainder = divmod(n_rows, n_parts)
    new_table_pattern = table + '-{part_idx}'
    result_tables = []
    start_idx = 0
    for part in range(n_parts):
        # The first `remainder` parts take one extra row each.
        end_idx = start_idx + block_size + (1 if part < remainder else 0)
        output_table = new_table_pattern.format(part_idx=part + 1)
        yt_client.run_map(
            echo_map,
            table + '[#{}:#{}]'.format(start_idx, end_idx),
            output_table,
        )
        result_tables.append(output_table)
        start_idx = end_idx
    return result_tables


def _extract_interval(yt_client, input_table, output_table, interval):
    """Run ``take_part.take_part`` into a typed (cid, score) output table.

    ``output_table`` is wrapped in a ``TablePath`` carrying an explicit
    schema (``cid``: string, ``score``: double) before being handed to
    ``take_part.take_part`` together with ``interval`` and the ``'score'``
    column name. The row selection itself is implemented by ``take_part``.
    """
    cid_score_schema = [
        {'name': 'cid', 'type': 'string'},
        {'name': 'score', 'type': 'double'},
    ]
    typed_output = yt_wrapper.TablePath(output_table, schema=cid_score_schema)
    take_part.take_part(yt_client, input_table, typed_output, interval, 'score')


def extract_interval(yt_client, score, date_str, interval):
    """Extract the ``interval`` part of a score table into a cid-sorted table.

    Steps:
      1. Locate the score table at
         ``MODELS_FOLDER/<partner_id>/<name>/<date_str>``.
      2. Create (if needed) a temporary folder under ``TMP/audience`` named
         after partner id, score name and date.
      3. Sort the ``cid``/``score`` columns by ``score`` into
         ``sorted_scores-<date_str>`` unless that table already exists.
      4. Pass the sorted table and ``interval`` to ``_extract_interval``,
         writing to ``cids-<ratio_from>-<ratio_to>``.
      5. Sort the resulting table in place by ``cid``.

    Args:
        yt_client: YT client used for all operations.
        score: object with ``partner_id`` and ``name`` attributes.
        date_str: date component of the score table path.
        interval: object with ``ratio_from`` and ``ratio_to`` attributes;
            forwarded to ``_extract_interval``.

    Returns:
        Path (string) of the cid-sorted output table.
    """
    score_table = yt_wrapper.ypath_join(
        MODELS_FOLDER, score.partner_id, score.name, date_str)

    tmp_folder = yt_wrapper.ypath_join(
        TMP, 'audience',
        '{}_{}_{}'.format(score.partner_id, score.name, date_str))

    yt_utils.create_folders([tmp_folder], yt_client)

    sorted_score_table = yt_wrapper.TablePath(
        yt_wrapper.ypath_join(tmp_folder, 'sorted_scores-{}'.format(date_str)),
        schema=[
            {'name': 'score', 'type': 'double'},
            {'name': 'cid', 'type': 'string'}
        ]
    )
    # The score-sorted copy does not depend on the interval, so it is built
    # once per score/date and reused by later calls.
    if not yt_utils.check_table_exists(sorted_score_table, yt_client):
        yt_client.run_sort(
            yt_wrapper.TablePath(score_table, columns=['cid', 'score']),
            sorted_score_table,
            sort_by='score',
            spec={
                # Title matches the actual sort key (it used to say "by cid").
                'title': 'Sort score table by score'
            }
        )
    output_table = yt_wrapper.ypath_join(tmp_folder, 'cids-{}-{}'.format(
        interval.ratio_from, interval.ratio_to)
    )
    _extract_interval(yt_client, sorted_score_table, output_table, interval)
    # Re-sort in place by cid (no destination given to run_sort).
    yt_client.run_sort(
        output_table,
        sort_by='cid'
    )
    return output_table


@yt_wrapper.with_context
def _cid_to_id_value_reducer(_, recs, context):
    """Attach the score from table 0 to md5 identifiers from the other table.

    Rows with ``context.table_index == 0`` provide ``score``; every other row
    is emitted as ``{id_value, id_type, score}`` provided a score has already
    been seen in this group and its ``id_type`` is ``email_md5`` or
    ``phone_md5``.
    """
    accepted_types = ('email_md5', 'phone_md5')
    current_score = None
    for row in recs:
        if context.table_index == 0:
            current_score = row['score']
            continue
        # No score seen yet for this key, or an identifier type we don't export.
        if current_score is None or row['id_type'] not in accepted_types:
            continue
        yield {
            'id_value': row['id_value'],
            'id_type': row['id_type'],
            'score': current_score,
        }


def cid_to_id_value(yt_client, cid_table, cid_to_all_table, output_table):
    """Reduce ``cid_table`` with ``cid_to_all_table`` by ``cid``.

    Runs ``_cid_to_id_value_reducer`` over both inputs (``cid_table`` first)
    and writes the result to ``output_table`` with the schema
    (``id_value``: string, ``id_type``: string, ``score``: double).
    """
    # TODO: Sort table once and save for future, use expiration_date
    result_schema = [
        {'name': 'id_value', 'type': 'string'},
        {'name': 'id_type', 'type': 'string'},
        {'name': 'score', 'type': 'double'},
    ]
    destination = yt_wrapper.TablePath(output_table, schema=result_schema)
    source_tables = [cid_table, cid_to_all_table]
    operation_spec = {'title': 'Expand cid to id_value'}
    yt_client.run_reduce(
        _cid_to_id_value_reducer,
        source_tables,
        destination,
        reduce_by='cid',
        spec=operation_spec,
    )


def download_segment(yt_client, table):
    """Read ``table`` and render it as lines of a two-column CSV.

    Expects `table` with columns: [id_type, id_value]

    The first returned line is the header ``phone,email``. Each row with
    ``id_type == 'phone_md5'`` becomes ``<id_value>,`` and each row with
    ``id_type == 'email_md5'`` becomes ``,<id_value>``. Rows with an empty or
    missing ``id_value`` are skipped silently; rows with any other
    ``id_type`` are skipped and logged.

    Args:
        yt_client: client providing ``read_table``.
        table: table to read.

    Returns:
        List of CSV lines (strings), header first.
    """
    data = ['phone,email']
    for row in yt_client.read_table(table):
        id_type = row.get('id_type')
        id_value = row.get('id_value')
        if not id_value:
            continue
        if id_type == 'phone_md5':
            data.append('{},'.format(id_value))
        elif id_type == 'email_md5':
            data.append(',{}'.format(id_value))
        else:
            # Reached only with a non-empty id_value, so the problem is the
            # type, not the value (the old message said "Empty id value").
            logger.info('Unexpected id type: %s', id_type)
    return data


def prepare_audience_table(yt_client, score, date_str, interval, output_table):
    """Build the audience table for ``score`` / ``date_str`` / ``interval``.

    Fetches the snapshot for ``date_str`` via ``get_snapshot``, extracts the
    interval's cid table with ``extract_interval`` and expands it through the
    snapshot's ``cid_to_all`` table into ``output_table``.
    """
    snapshot = get_snapshot(yt_client, date_str)
    interval_cids = extract_interval(yt_client, score, date_str, interval)
    cid_to_id_value(
        yt_client,
        interval_cids,
        snapshot.cid_to_all,
        output_table,
    )


def build_audience_table_name(score, date_str, interval):
    """Return ``<partner_id>-<name>-<date_str>-<ratio_from>-<ratio_to>``."""
    name_parts = (
        score.partner_id,
        score.name,
        date_str,
        interval.ratio_from,
        interval.ratio_to,
    )
    return '{}-{}-{}-{}-{}'.format(*name_parts)
