import datetime
from enum import Enum

from yt.wrapper import ypath_join, OperationsTracker, TablePath
from datacloud.dev_utils.time.patterns import FMT_DATE
from datacloud.dev_utils.logging.logger import get_basic_logger


# Module-level logger shared by all rotation helpers in this file.
logger = get_basic_logger(__name__)


class RotateLevel(Enum):
    """
    Granularity levels of history tables.

    The numeric value of a level equals the number of dash-separated parts
    in the name of a table at that level ('2018' -> 1, '2018-01' -> 2,
    '2018-01-01' -> 3). `rotate_history_tables` relies on this: it walks
    from 3 down to (but not including) the value of `last_merge_level` and
    compares each value with `len(name.split('-'))`.
    """
    YEARLY = 1  # as last_merge_level: daily -> monthly, then monthly -> yearly
    MONTHLY = 2  # as last_merge_level: only daily -> monthly
    DAILY = 3  # as last_merge_level: nothing is merged


def _merge_tables(yt, tables_dict, sort_by=None,
                  preserve_attrs=('optimize_for', 'compression_codec', 'erasure_codec', 'media', )):
    """
    Merges multiple tables into single with preserving order and some attributes.
    `tables_dict` is a dict like {'result_table_path': ['table1', 'table2']}

    For every destination the source tables are taken in sorted order; if the
    destination already exists it is put first in the input list, so its rows
    are kept. The schema and the attributes listed in `preserve_attrs` are
    copied from the last input table.

    If `sort_by` is given (and non-empty) every destination is produced with
    `yt.run_sort` by these columns. If `sort_by` is None, the `sorted_by`
    attribute of the last input table of *that* destination is used; when
    there is none, `yt.run_merge` with `combine_chunks` is used instead.

    Everything runs inside `yt.Transaction()`. Operations are started with
    `sync=False` and awaited together; after that all source tables from
    `tables_dict` are removed. Returns None.
    """
    if not tables_dict:
        return
    with yt.Transaction():
        tracker = OperationsTracker()
        # `.items()` (not `.iteritems()`) keeps this working on Python 2 and 3.
        for merged_table_path, tables in sorted(tables_dict.items()):
            input_tables = sorted(tables)
            if yt.exists(merged_table_path):
                input_tables = [merged_table_path] + input_tables
            if not input_tables:
                # No sources and no existing destination: nothing to merge,
                # and `input_tables[-1]` below would raise IndexError.
                logger.debug('Skip %s: no tables to merge', merged_table_path)
                continue

            last_input_attrs = yt.get(ypath_join(input_tables[-1], '@'))
            merged_yt_table_path = TablePath(
                merged_table_path,
                schema=last_input_attrs.get('schema'),
                attributes={k: v for k, v in last_input_attrs.items() if k in preserve_attrs}
            )
            # Resolve sort columns per destination in a local variable.
            # Overwriting the `sort_by` argument here would make columns
            # inferred for one destination leak into all following ones.
            if sort_by is None:
                table_sort_by = last_input_attrs.get('sorted_by')
            else:
                table_sort_by = sort_by
            logger.info('Merging %s tables into %s', len(tables), merged_table_path)
            if table_sort_by:
                opt = yt.run_sort(
                    input_tables,
                    merged_yt_table_path,
                    sort_by=table_sort_by,
                    sync=False
                )
            else:
                opt = yt.run_merge(
                    input_tables,
                    merged_yt_table_path,
                    spec={'combine_chunks': True},
                    sync=False
                )
            tracker.add(opt)
        # Wait for every started operation before touching the sources.
        tracker.wait_all()

        # Sources are removed only after all merges have finished.
        for tables in tables_dict.values():
            for t in tables:
                yt.remove(t)


def rotate_stream_tables(yt, stream_tables, history_root, merge_today=False, sort_by=None):
    """
    Group stream tables by date and merge every group into a daily table.

    The date of a stream table is the part of the last path component that
    precedes the first 'T'; its destination is `<history_root>/<date>`.
    Tables whose date equals today's date (local time, formatted with
    FMT_DATE) are left untouched unless `merge_today` is true.

    Returns the {daily_table_path: [stream_table, ...]} dict that was handed
    to `_merge_tables`.
    """
    today = datetime.datetime.now().strftime(FMT_DATE)
    tables_by_day = {}
    for stream_table in stream_tables:
        table_name = stream_table.rsplit('/', 1)[-1]
        table_date = table_name.partition('T')[0]

        is_today = table_date == today
        if is_today and not merge_today:
            continue

        destination = ypath_join(history_root, table_date)
        tables_by_day.setdefault(destination, []).append(stream_table)

    _merge_tables(yt, tables_by_day, sort_by=sort_by)
    return tables_by_day


def rotate_history_tables(yt, history_root, last_merge_level=RotateLevel.MONTHLY, sort_by=None):
    """
    Rotate history tables: merge daily tables into monthly ones and,
    if requested, monthly tables into yearly ones.

    `last_merge_level` is a RotateLevel member or its integer value
    (anything else makes `RotateLevel(...)` raise ValueError). Levels are
    processed from daily downwards and processing stops before the given
    level is reached. Tables of the current period (this month for daily
    tables, this year for monthly ones) are skipped.

    Assumes FMT_DATE yields a dash-separated year-month-day string
    (table names like '2018-01-01', '2018-01') - TODO confirm.
    """
    stop_level = RotateLevel(last_merge_level).value
    today_parts = datetime.datetime.now().strftime(FMT_DATE).split('-')
    # A level number is the count of '-'-separated parts in a table name.
    for merge_level in range(RotateLevel.DAILY.value, stop_level, -1):
        current_date_str = '-'.join(today_parts[:merge_level - 1])
        merged_tables = {}
        # Listed anew for each level so tables produced by the previous
        # level are visible to the next one.
        for table_path in sorted(yt.list(history_root, absolute=True)):
            time_str = table_path.split('/')[-1]
            name_parts = time_str.split('-')
            if len(name_parts) != merge_level:
                logger.debug('Skip %s len(%r.split("-")) != %s', table_path, time_str, merge_level)
                continue

            # Dropping the last part gives the name of the enclosing period.
            date_str = '-'.join(name_parts[:-1])
            if date_str == current_date_str:
                logger.debug('Skip %s %s == %s', table_path, date_str, current_date_str)
                continue

            merged_path = ypath_join(history_root, date_str)
            logger.debug('Add %s to merge into %s', table_path, merged_path)
            merged_tables.setdefault(merged_path, []).append(table_path)
        _merge_tables(yt, merged_tables, sort_by=sort_by)
