# -*- coding: utf-8 -*-
import collections
import datetime
import logging

import yt.wrapper as yt

import irt.bannerland.options
import irt.common.yt as common_yt


# Logger of this module.
log = logging.getLogger(name=__name__)

# Date format of dated node names (default `dt_format` of YtCleaner); also used to render timings.
DEFAULT_DATE_FMT = irt.bannerland.options.get_option('bannerland_pocket_name_format')


def get_dir_dt_from_dir_name(d, dt_format):
    """Parse a datetime out of the last path component of ``d``.

    The whole component is parsed with ``dt_format`` first. If that fails, only
    its leading characters are parsed; the prefix length is the length of the
    current time rendered with ``dt_format``, which accepts names such as
    ``<date><suffix>``.

    :param str d: node path or bare node name
    :param str dt_format: strptime/strftime format
    :return datetime.datetime: the parsed datetime
    :raises ValueError: if neither the full name nor its prefix matches
    """
    base_name = d.rsplit('/', 1)[-1]
    try:
        parsed = datetime.datetime.strptime(base_name, dt_format)
    except ValueError:
        # the name carries something after the date: parse only the date-sized prefix
        prefix_len = len(datetime.datetime.now().strftime(dt_format))
        parsed = datetime.datetime.strptime(base_name[:prefix_len], dt_format)
    return parsed


# An explicit base class (object) is required, otherwise super() misbehaves (python 2 old-style classes)
class YtDirWorker(object):
    """Run an operation over the child nodes of a YT directory.

    Subclasses implement :meth:`do_work`. :meth:`process` selects the pending
    children of a directory, calls ``do_work`` on each of them in ascending name
    order and, after it returns, marks the node as done by setting the attribute
    ``attr_name`` to a dict of timings.
    """

    def __init__(self, attr_name, depend_attrs=(), yt_client=yt, ignore_depends=False):
        """
        :param str attr_name: name of the attribute set on a node after it has been
            processed successfully (nodes carrying it are treated as done)
        :param list[str] depend_attrs: attributes that must all be set (truthy) on a
            node before it may be processed
        :param yt_client: YT client (or the ``yt.wrapper`` module) used for all calls
        :param bool ignore_depends: if True, ``depend_attrs`` are not checked
        """
        self.attr_name = attr_name
        self.depend_attrs = depend_attrs
        self.yt_client = yt_client
        self.ignore_depends = ignore_depends

    def do_work(self, node):
        """Process a single node; must be overridden by subclasses.

        It must raise if processing fails: otherwise the node gets marked as done.

        :param str node: absolute path of the node (either a directory or a table)
        :raises NotImplementedError: always, in this base class
        """
        # NotImplementedError is a subclass of Exception, so existing handlers still catch it
        raise NotImplementedError("method do_work not implemented")

    def do_work_wrapper(self, node):
        """Indirection through which :meth:`process` calls :meth:`do_work`."""
        self.do_work(node)

    def process(self, work_dir):
        """Process every pending node of ``work_dir``, in ascending name order.

        An exception raised by ``do_work`` propagates and stops processing; the
        failing node is left unmarked.
        """
        for node in self._get_nodes(work_dir):
            log.info('call do_work for node %s', node)
            start_time = datetime.datetime.now()
            self.do_work_wrapper(node)
            self.set_dir_worker_attr_done(node, start_time)

    @staticmethod
    def get_timings(start_time):
        """Return start/end time (formatted with DEFAULT_DATE_FMT) and duration in seconds."""
        end_time = datetime.datetime.now()
        duration = end_time - start_time
        return {
            'start_time': start_time.strftime(DEFAULT_DATE_FMT),
            'end_time': end_time.strftime(DEFAULT_DATE_FMT),
            'duration': duration.total_seconds(),
        }

    def set_dir_worker_attr_done(self, node, start_time):
        """Mark ``node`` as processed by setting ``attr_name`` to its timings."""
        timings = self.get_timings(start_time)
        self.yt_client.set_attribute(node, self.attr_name, timings)

    def _get_nodes(self, work_dir):
        """Return the children of ``work_dir`` that still need processing, sorted by name.

        Children are scanned from the greatest name downwards and the scan stops at
        the first node that already carries ``attr_name``; nodes before it are not
        looked at. Unless ``ignore_depends`` is set, nodes missing any of
        ``depend_attrs`` are skipped.
        """
        all_nodes = self.yt_client.list(work_dir, absolute=True)

        pending = []
        for node in sorted(all_nodes, reverse=True):
            # an empty attr_name means there is no "done" marker to look for
            if self.attr_name and self.yt_client.get_attribute(node, self.attr_name, False):
                break
            pending.append(node)

        if not self.ignore_depends:
            pending = [node for node in pending if self._check_needed_attrs(node)]
        return sorted(pending)

    def _check_needed_attrs(self, node):
        """Return True if every attribute of ``depend_attrs`` is set (truthy) on ``node``."""
        return all(self.yt_client.get_attribute(node, attr, None) for attr in self.depend_attrs)


class _DateRange:
    """Helper class to keep track of seen dates

    Accepts range, i.e. '0-30' or '90-' and action i.e. 'all' or '1/week', 'n/month',
    Keeps track of seen dates and responds if date should be kept or deleted.
    """
    def __init__(self, rng, action):
        start, end = rng.split('-', 1)
        self.start = int(start) if start else 0
        self.end = int(end) if end else float('inf')
        self.now = datetime.datetime.now()

        self.kept = collections.defaultdict(int)

        self.action = action

        if action == 'all':
            self.get_key = self._all_key
            self.count = float('inf')
        elif action == 'none':
            self.get_key = self._all_key
            self.count = 0
        else:
            count, period = action.split('/', 1)
            self.count = int(count)
            if period == 'day':
                self.get_key = self._day_key
            elif period == 'week':
                self.get_key = self._week_key
            elif period == 'month':
                self.get_key = self._month_key

    def _day_key(self, date):
        return (date.year, date.month, date.day)

    def _week_key(self, date):
        """Returns tuple (year, week), as reported by isocalendar"""
        return date.isocalendar()[:2]

    def _month_key(self, date):
        return (date.year, date.month)

    def _all_key(self, date):
        return (True, )

    def _check(self, target):
        return self.start <= (self.now - target).days < self.end

    def keep(self, target, table_name):
        """
        Returns False if target date matches this range and shoud be deleted
        otherwise returns True
        """
        if self._check(target):
            # prepend table name to the date_key, to make sure we would not keep just one table in directory
            key = (table_name, ) + self.get_key(target)
            if self.kept[key] < self.count:
                self.kept[key] += 1
                # keep
                log.debug("Matched {}, keeping".format(self))
                return True
            else:
                # Already have kept self.count dates, delete
                log.debug("Matched {}, marking for deletion".format(self))
                return False
        # doesn't match our range: keep
        return True

    def __repr__(self):
        return "DateRange({}-{}, action '{}')".format(
            self.start, self.end, self.action)

    def __str__(self):
        return self.__repr__()

    def reset(self):
        self.kept = collections.defaultdict(int)


class YtCleaner(object):
    """Delete dated nodes of a YT directory according to retention ranges.

    Children of ``yt_dir`` are named after a date, parsed with ``dt_format``.
    Depending on ``cleaner_config['subnode_type']`` every child is either a
    directory (``'map_node'``) whose children are checked one by one, or a table
    (``'table'``) that is checked itself.

    ``cleaner_config`` keys:
      * ``subnode_type`` -- ``'map_node'`` or ``'table'``;
      * ``default_range`` -- range config for nodes without a dedicated one;
      * ``node_ranges`` (optional) -- ``{node_name: range config}``.
    A range config is either an int N (keep N days, delete older ones) or a dict
    ``{'<from>-<to>': action}`` with the actions accepted by ``_DateRange``.
    """

    def __init__(self, yt_dir, cleaner_config, dt_format=DEFAULT_DATE_FMT, yt_client=yt):
        """
        :param str yt_dir: directory to clean
        :param dict cleaner_config: see the class docstring
        :param str dt_format: date format of the children names of ``yt_dir``
        :param yt_client: YT client (or the ``yt.wrapper`` module) used for YT calls
        :raises ValueError: on an unknown ``subnode_type``
        :raises NotImplementedError: on per-period actions in 'table' mode
        """
        self.yt_dir = yt_dir
        self.dt_format = dt_format
        self.yt_client = yt_client
        self.node_ranges = {}
        subnode_type = cleaner_config['subnode_type']
        if subnode_type not in ['map_node', 'table']:
            raise ValueError('bad subnode_type = {}'.format(subnode_type))
        self.expected_subnode_type = subnode_type
        for node_name, rng in cleaner_config.get('node_ranges', {}).items():
            ranges = self._range_from_config(rng)
            self._check_ranges_supported(ranges)
            self.node_ranges[node_name] = ranges

        self.default_range = self._range_from_config(cleaner_config['default_range'])
        self._check_ranges_supported(self.default_range)

    def _check_ranges_supported(self, ranges):
        """Reject 'n/<period>' actions in 'table' mode, where only 'all'/'none' are supported."""
        # per-period counters are keyed by node name, and in 'table' mode each node is
        # named after its own date, so such limits would never take effect
        if self.expected_subnode_type == 'table' and any(r.action not in ['all', 'none'] for r in ranges):
            raise NotImplementedError('tables cleaner_mode not support 1/day config syntax')

    @staticmethod
    def _range_from_config(rng_config):
        """Build a list of ``_DateRange`` from an int or dict range config."""
        if isinstance(rng_config, int):
            # short config i.e. 'table_foo: 30', meaning keep 30 days, remove the rest
            return [_DateRange('-{}'.format(rng_config), 'all'),
                    _DateRange('{}-'.format(rng_config), 'none')]

        elif isinstance(rng_config, dict):
            # long config i.e. 'table_bar': {'-60': '1/week'}
            return [_DateRange(rng, action) for rng, action in rng_config.items()]

        raise TypeError("Bad range config {}".format(rng_config))

    def run(self, dry_run=False):
        """Delete expired nodes of ``yt_dir``.

        Date nodes are visited in ascending name order; nodes whose name does not
        parse as a date are skipped. In 'map_node' mode a directory left empty is
        removed as well.

        :param bool dry_run: only log what would be deleted
        :raises ValueError: if a date node has a type other than ``subnode_type``
        """
        log.info("Started cleaning {} (dry_run={})".format(self.yt_dir, dry_run))
        yt_client = self.yt_client
        deleted_tables = deleted_dirs = 0

        # ranges count the dates they have kept; start every run from scratch,
        # otherwise a reused cleaner would treat previously kept dates as "already kept"
        for ranges in [self.default_range] + list(self.node_ranges.values()):
            for date_range in ranges:
                date_range.reset()

        date_nodes = yt_client.list(self.yt_dir, sort=True)
        for date_node in date_nodes:
            date_node_full_name = self.yt_dir + '/' + date_node
            log.info('Processing {}'.format(date_node_full_name))

            try:
                dt = get_dir_dt_from_dir_name(date_node, self.dt_format)
            except Exception:
                log.exception("Could not parse {} into date. Skipping".format(date_node))
                continue

            subnode_type = common_yt.get_attribute(date_node_full_name, 'type', yt_client)
            if self.expected_subnode_type != subnode_type:
                raise ValueError('expected {} got {} for yt dir {}'.format(self.expected_subnode_type, subnode_type, date_node_full_name))

            if subnode_type == "map_node":
                nodes = yt_client.list(date_node_full_name, absolute=True)
            elif subnode_type == "table":
                nodes = [date_node_full_name]
            else:
                raise ValueError("bad node type = {} in dir {}".format(subnode_type, date_node_full_name))

            for node_full_name in nodes:
                node_name = yt.ypath_split(node_full_name)[1]
                range_checks = self.node_ranges.get(node_name, self.default_range)
                log.debug("Using %s as range checks", range_checks)
                log.info('Processing {}'.format(node_full_name))

                # a list, not a generator: every range's keep() must run (it updates
                # its counters) instead of all() short-circuiting on the first False
                keep = all([check.keep(dt, node_name) for check in range_checks])
                if not keep:
                    if dry_run:
                        log.info("Not deleting {}: dry run".format(node_full_name))
                    else:
                        log.info("Deleting {}".format(node_full_name))
                        yt_client.remove(node_full_name, recursive=True)
                    deleted_tables += 1

            # Only a directory can be left empty. In 'table' mode the date node is the
            # table itself (possibly just removed), so listing it would fail.
            if subnode_type == "map_node" and not dry_run and not yt_client.list(date_node_full_name):
                log.info("Deleting empty directory {}".format(date_node_full_name))
                yt_client.remove(date_node_full_name)
                deleted_dirs += 1

        if dry_run:
            log.info("Done dry_run for {}. Processed {} dirs, Would have removed {} tables".format(
                self.yt_dir, len(date_nodes), deleted_tables))
        else:
            log.info("Done cleaning {}. Processed {} dirs, removed {} tables and {} empty dirs".format(
                self.yt_dir, len(date_nodes), deleted_tables, deleted_dirs))


class YtCollectDay(object):
    """Merge dated tables of a directory into per-day or per-month tables."""

    def __init__(self, yt_client=yt):
        """:param yt_client: YT client (or the ``yt.wrapper`` module) used for all calls"""
        self.yt_client = yt_client

    def run_from_config(self, config):
        """Call :meth:`run` with the keys of ``config`` as keyword arguments."""
        self.run(**config)

    def run(self, from_dir, to_dir, collect_type, in_dt_format, day_limit, output_table_name_prefix=None, expiration_timeout=None):
        """Merge tables from a directory into per-day (or per-month) tables.

        Every table of ``from_dir`` whose name date (parsed with ``in_dt_format``,
        see get_dir_dt_from_dir_name) is at least ``day_limit`` calendar days old
        is merged into ``to_dir/<prefix><date>``, where ``<date>`` is ``%Y%m%d``
        for ``collect_type='day'`` and ``%Y-%m-00`` for ``'month'``. An existing
        target table is merged in as well. Source tables are removed afterwards.
        Each target is built inside its own transaction.

        :param str from_dir: directory with the input tables
        :param str to_dir: directory for the output tables
        :param str collect_type: 'day' or 'month'
        :param str in_dt_format: date format of the input table names
        :param int day_limit: tables less than this many calendar days old are skipped
        :param str output_table_name_prefix: prefix of the output table names
        :param expiration_timeout: if set, passed to common_yt.set_expiration_time for every output table
        :raises ValueError: on an unknown ``collect_type``
        """
        yt_client = self.yt_client
        output_table_name_prefix = output_table_name_prefix or ''
        types_formats = {
            'day': '%Y%m%d',
            'month': '%Y-%m-00'}

        if collect_type not in types_formats:
            # ValueError is a subclass of Exception, so existing handlers still catch it
            raise ValueError('bad collect_type! select "day" or "month"')
        out_format = types_formats[collect_type]

        tables = yt_client.list(from_dir, sort=True)
        collect_config = collections.defaultdict(list)
        today = datetime.datetime.now().toordinal()
        for table in tables:
            table_dt = get_dir_dt_from_dir_name(table, in_dt_format)
            # compare calendar days rather than 24-hour intervals
            if today - table_dt.toordinal() < day_limit:
                continue
            target_table_name = to_dir + '/' + output_table_name_prefix + table_dt.strftime(out_format)
            collect_config[target_table_name].append(from_dir + '/' + table)

        for target_table_name, input_tables in collect_config.items():
            log.warning('will merge tables: %s => %s', input_tables, target_table_name)
            with yt_client.Transaction(), \
                    yt_client.TempTable() as tmp:
                tables_for_merge = input_tables[:]
                # when to_dir == from_dir the target may itself be one of the inputs:
                # adding it again would duplicate its rows
                if target_table_name not in tables_for_merge and yt_client.exists(target_table_name):
                    tables_for_merge.append(target_table_name)

                schemas = [
                    common_yt.get_schema(table_name, yt_client)
                    for table_name in tables_for_merge
                ]
                merged_schema = common_yt.merge_schemas(schemas)

                # use a schema that fits every input table, since the schema of the
                # input tables may change over time
                yt_client.alter_table(tmp, schema=merged_schema)

                yt_client.run_merge(tables_for_merge, tmp, spec={"max_data_weight_per_job": 400 * yt.common.GB})
                yt_client.transform(tmp, target_table_name, optimize_for='scan')
                if expiration_timeout:
                    common_yt.set_expiration_time(target_table_name, expiration_timeout, yt_client)

                for table in input_tables:
                    # never drop the freshly written target (possible when to_dir == from_dir)
                    if table != target_table_name:
                        yt_client.remove(table)
