from datacloud.dev_utils.logging.logger import get_basic_logger

logger = get_basic_logger(__name__)


class Interval:
    """ Fractional range [ratio_from, ratio_to] of the rows of a table.

    Both bounds are fractions of the total row count and must satisfy
    0 <= ratio_from <= ratio_to <= 1. The bounds are validated on
    construction and again every time a table modifier is built, because
    the attributes are public and may be reassigned after construction.
    """

    def __init__(self, ratio_from, ratio_to):
        """
        Args:
            ratio_from (float): Lower bound of the range, between 0 and 1
            ratio_to (float): Upper bound of the range, between ratio_from and 1
        Raises:
            AssertionError: If the bounds do not form a valid interval
        """
        self.ratio_from = ratio_from
        self.ratio_to = ratio_to
        self.check_interval(self)

    def to_idx(self, row_count):
        """ Convert the ratios to row indexes for a table with row_count rows
        Args:
            row_count (int): N records in table
        Returns:
            tuple: (idx_from, idx_to), each product truncated to int
        """
        idx_from = int(row_count * self.ratio_from)
        idx_to = int(row_count * self.ratio_to)
        return idx_from, idx_to

    def table_modifier(self, row_count):
        """ Get table_path modifier to take part of records from table_path
        Args:
            row_count (int): N records in table
        Returns:
            str : Table path modifier to select required range of records
        Raises:
            AssertionError: If the bounds do not form a valid interval
        """
        # Re-validate: ratio_from / ratio_to may have been changed since __init__.
        self.check_interval(self)
        records_from, records_to = self.to_idx(row_count)
        modifier = '[#{}:#{}]'.format(records_from, records_to)
        return modifier

    def __str__(self):
        return '[{}, {}]'.format(self.ratio_from, self.ratio_to)

    def __repr__(self):
        return 'Interval({!r}, {!r})'.format(self.ratio_from, self.ratio_to)

    @staticmethod
    def check_interval(interval):
        """ Validate that interval satisfies 0 <= ratio_from <= ratio_to <= 1
        Args:
            interval (Interval): Interval to validate
        Raises:
            AssertionError: If any of the conditions is violated
        """
        # Raised explicitly instead of via `assert` so that the validation is
        # not stripped when Python runs with -O; the exception type and the
        # messages are the same as the assert statements would produce.
        if not (0 <= interval.ratio_from <= 1):
            raise AssertionError(
                'ratio_from must be between 0 and 1, current interval is {}'.format(interval)
            )
        if not (0 <= interval.ratio_to <= 1):
            raise AssertionError(
                'ratio_to must be between 0 and 1, current interval is {}'.format(interval)
            )
        if not (interval.ratio_from <= interval.ratio_to):
            raise AssertionError('The condition "ratio_from <= ratio_to" is not satisfied.')


def take_part(yt_client, table_in, table_out, interval, key):
    """ Take part of the table from ratio_from% to ratio_to%

    The input table must already be sorted by key (its 'sorted_by'
    attribute is compared with key). The rows selected by the table path
    modifier built from records_in_input_table * ratio_from and
    records_in_input_table * ratio_to are written to table_out with a
    sort operation by the same key.

    Args:
        yt_client (yt.wrapper.YtClient): YtClient
        table_in (str): Path to table on yt from which records will be taken
        table_out (str): Path to table on yt to which result records will be saved
        interval (Interval): Interval of the table that will be taken
        key (str or list or tuple): Column key(s) by which table_in is sorted

    Raises:
        AssertionError: If interval is not valid or table_in is not sorted by key
    """
    logger.info('Start take part from table {} interval: {}'.format(table_in, interval))
    Interval.check_interval(interval)
    if isinstance(key, str):
        key = [key]
    elif isinstance(key, tuple):
        # A tuple never compares equal to a list, so without this conversion
        # the sorted_by check below would fail for every tuple key.
        key = list(key)

    input_sorted_by = yt_client.get_attribute(table_in, 'sorted_by')
    # Raised explicitly instead of via `assert` so that the check is not
    # stripped when Python runs with -O; type and message are unchanged.
    if not input_sorted_by == key:
        raise AssertionError(
            'Input table is sorted by {}, but take part is required by {}'.format(input_sorted_by, key)
        )

    tag = 'TAKE-PART-{}-to-{}'.format(interval.ratio_from, interval.ratio_to)
    row_count = yt_client.row_count(table_in)
    yt_client.run_sort(
        str(table_in) + interval.table_modifier(row_count),
        table_out,
        sort_by=key,
        spec={'title': '[{}] Copy segment'.format(tag)}
    )


def sort_and_take_part(yt_client, table_in, table_out, interval, key,
                       tmp_dir='//projects/scoring/tmp'):
    """ Sort the table by key and take part of it from ratio_from% to ratio_to%

    Sorts table_in by key into a temporary table created under tmp_dir and
    then calls take_part on that temporary table. Both steps run inside a
    single yt transaction.

    Args:
        yt_client (yt.wrapper.YtClient): YtClient
        table_in (str): Path to table on yt from which records will be taken
        table_out (str): Path to table on yt to which result records will be saved
        interval (Interval): Interval of the table that will be taken
        key (str or list): Column key by which records will be sorted before take part
        tmp_dir (str): Path on yt under which the temporary sorted table is created

    Raises:
        AssertionError: If interval is not valid
    """
    # Fail fast: take_part would reject an invalid interval anyway, but only
    # after the whole input table has been sorted.
    Interval.check_interval(interval)
    tag = 'TAKE-PART-{}-to-{}'.format(interval.ratio_from, interval.ratio_to)
    with yt_client.Transaction():
        with yt_client.TempTable(tmp_dir) as tmp_sorted_table:
            yt_client.run_sort(
                table_in,
                tmp_sorted_table,
                sort_by=key,
                spec={'title': '[{}] Sort input table'.format(tag)}
            )
            take_part(yt_client, tmp_sorted_table, table_out, interval, key)
