# -*- coding: utf-8 -*-
import types
import random
import yt.wrapper as yt_wrapper
from datacloud.dev_utils.yt.yt_utils import get_yt_client
from datacloud.dev_utils.logging.logger import get_basic_logger


# Names exported by `from <this module> import *`.
__all__ = [
    'compress_table',
    'get_tables',
    'unique_reducer',
    'SubsampleMap',
]

# Values of the node 'type' attribute (as read via yt_client.get_attribute)
# that get_tables() distinguishes: directory, table and file nodes.
YT_DIR_TYPE = 'map_node'
YT_TABLE_TYPE = 'table'
YT_FILE_TYPE = 'file'


def compress_table(
        src,
        dst_table=None,
        yt_client=None,
        check_codecs=False,
        combine_chunks=True,
        merge_by=None,
        title_suffix=''):
    """
        Compress a table, optionally merging several source tables first.

        Single source: src is transformed into dst_table (dst_table is passed
        to yt_client.transform as is, including None).
        Several sources: they are merged (sorted merge by merge_by) into
        dst_table, the source tables are removed, and dst_table is then
        transformed in place.
        Everything runs inside one yt_client.Transaction().

        :param src: single src table (string or TablePath) or an iterable of src tables
        :param dst_table: destination table; required when src is an iterable
        :param yt_client: YT client; get_yt_client() is used when not provided
        :param check_codecs: forwarded to yt_client.transform
        :param combine_chunks: put into the spec of both the merge and the transform
        :param merge_by: merge columns; required when src is an iterable
        :param title_suffix: appended to the operation titles
        :return: result of yt_client.transform
        :raises ValueError: if src is an iterable and dst_table or merge_by is missing
    """
    yt_client = yt_client or get_yt_client()

    # types.StringTypes exists only on Python 2; fall back to the Python 3
    # string types so the check works on both interpreters.
    string_types = getattr(types, 'StringTypes', (str, bytes))

    with yt_client.Transaction():
        # None means "single table"; a list means "several tables to merge".
        src_tables = None
        if not isinstance(src, string_types):
            try:
                src_iter = iter(src)
            except TypeError:
                # Not iterable (e.g. a table path object): treat as one table.
                pass
            else:
                # Materialize once: src may be a one-shot iterator, and it is
                # needed twice (for the merge and for the cleanup below).
                src_tables = list(src_iter)

        if src_tables is None:
            src_table = src
        else:
            if dst_table is None:
                raise ValueError('Dst table should be provided with multiple input')
            if merge_by is None:
                raise ValueError('merge_by should be provided with multiple input')

            yt_client.run_merge(
                src_tables,
                dst_table,
                mode='sorted',
                spec={
                    'merge_by': merge_by,
                    'combine_chunks': combine_chunks,
                    'title': '[XPROD-COMPRESS] merge tables together ' + title_suffix
                }
            )

            src_table = dst_table
            for t in src_tables:
                # dst_table now holds the merged data; never delete it, even
                # when it was also listed among the inputs.
                if t == dst_table:
                    continue
                yt_client.remove(t, force=True)

        transform_op = yt_client.transform(
            src_table,
            dst_table,
            erasure_codec='lrc_12_2_2',
            compression_codec='brotli_6',
            check_codecs=check_codecs,
            optimize_for='scan',
            spec={
                'combine_chunks': combine_chunks,
                'title': '[XPROD-COMPRESS] ' + title_suffix
            }
        )

        return transform_op


def get_tables(path, yt_client, print_depth=0, get_files=False, logger=None):
    """
        Recursively collect table paths located under path.

        :param path: node to inspect; a table yields [path], a directory is
            walked recursively, any other node type yields []
        :param yt_client: client used for get_attribute / list calls
        :param print_depth: number of directory levels whose entries are logged;
            decreases by one per recursion level and never goes below 0
        :param get_files: when true, file nodes found inside directories are
            included in the result as well
        :param logger: logger for the listing output; a basic logger is
            created when not provided
        :return: list of paths, children visited in sorted name order
    """
    if not logger:
        logger = get_basic_logger(__name__)

    root_type = yt_client.get_attribute(path, 'type')
    if root_type == YT_TABLE_TYPE:
        return [path]
    if root_type != YT_DIR_TYPE:
        return []

    # Depth passed to nested directories; the same for every child.
    child_depth = max(print_depth - 1, 0)

    found = []
    for name in sorted(yt_client.list(path)):
        if print_depth > 0:
            logger.info(' Depth {} {}'.format(print_depth, name))

        child_path = yt_wrapper.ypath_join(path, name)
        child_type = yt_client.get_attribute(child_path, 'type')

        if child_type == YT_DIR_TYPE:
            found += get_tables(
                child_path,
                yt_client,
                child_depth,
                get_files=get_files,
                logger=logger
            )
        elif child_type == YT_TABLE_TYPE or (get_files and child_type == YT_FILE_TYPE):
            found.append(child_path)

    return found


def unique_reducer(key, recs):
    """
        Yield only the first record of recs (nothing if recs is empty).

        :param key: not used
        :param recs: iterable of records; at most one item is consumed from it
    """
    nothing = object()
    first = next(iter(recs), nothing)
    if first is not nothing:
        yield first


class SubsampleMap(object):
    """
        Callable that passes each record through with a fixed probability.

        Every call makes an independent random.random() draw and yields the
        record only when the draw is below the configured probability.
    """

    def __init__(self, prob_to_take_record):
        """
            :param prob_to_take_record: probability of keeping a record, in [0, 1]
            :raises AssertionError: if prob_to_take_record is outside [0, 1]
        """
        # Explicit raise instead of an `assert` statement so the check is not
        # stripped under `python -O`; AssertionError is kept so callers see
        # the same exception type as before.
        if not 0 <= prob_to_take_record <= 1.0:
            raise AssertionError(
                'prob_to_take_record must be in [0, 1], your prob is {}'.format(prob_to_take_record))
        self._prob_to_take = prob_to_take_record

    def __call__(self, rec):
        """
            Yield rec with probability prob_to_take_record, otherwise nothing.

            :param rec: record to keep or drop
        """
        # random.random() is in [0.0, 1.0), so the strict comparison keeps
        # exactly the requested fraction: never for 0, always for 1.
        if random.random() < self._prob_to_take:
            yield rec
