from typing import Union
import datetime
import dateutil.parser
from dataclasses import dataclass
import logging

from yt.wrapper import YPath, YtClient, create, link, list, Transaction, remove

from travel.hotels.lib.python3.lang.auto_str import auto_str
from travel.hotels.lib.python3.yt.ytlib import join, transfer_results


@dataclass
class Node:
    """Lightweight record describing one listed child node considered for cleanup."""
    # Node path, stringified from the listing entry.
    path: str
    # Value of the node's 'type' attribute (e.g. 'map_node').
    type: str
    # Parsed value of the node's 'creation_time' attribute.
    ts: datetime.datetime


@auto_str(ignore=['logger'])
class CleanupStrategy:
    """Base class for policies that prune old nodes under a directory.

    Subclasses implement :meth:`clean`; :meth:`delete_old_nodes` is the shared
    pruning routine they can delegate to.
    """

    def __init__(self):
        # Child logger named after the concrete class, so log lines show
        # which strategy produced them.
        self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)

    def clean(self, path: Union[str, YPath], yt_client: YtClient):
        """Prune nodes under ``path``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def delete_old_nodes(self, path: Union[str, YPath], keep_last_nodes_count: int, keep_weekly_milestones: bool, yt_client: YtClient, node_types=frozenset(['map_node'])):
        """Remove old children of ``path``, keeping the most recently created ones.

        Children are listed with their ``type`` and ``creation_time``
        attributes, filtered by ``node_types`` and ordered by creation time.
        The newest ``keep_last_nodes_count`` of them are never touched. Of the
        remaining (older) nodes, either all are removed, or - when
        ``keep_weekly_milestones`` is true - the oldest node of each ISO
        (year, week) is preserved and the rest are removed.

        Args:
            path: Directory whose listed children are examined.
            keep_last_nodes_count: Number of newest matching nodes to keep.
                ``0`` makes every matching node a removal candidate.
            keep_weekly_milestones: Preserve one node per ISO week among the
                older nodes.
            yt_client: Client used for listing, the transaction and removal.
            node_types: Only children whose ``type`` attribute is in this
                collection are considered.

        Raises:
            ValueError: if ``keep_last_nodes_count`` is negative.
        """
        if keep_last_nodes_count < 0:
            # A negative count would silently flip the meaning of the slice
            # below; refuse it before doing anything destructive.
            raise ValueError(f'keep_last_nodes_count must be non-negative, got {keep_last_nodes_count}')

        children = list(path, attributes=['type', 'creation_time'], absolute=True, client=yt_client)
        matching = sorted(
            (
                Node(str(child), child.attributes['type'], dateutil.parser.parse(child.attributes['creation_time']))
                for child in children
                if child.attributes['type'] in node_types
            ),
            key=lambda node: node.ts,
        )

        # Slice with an explicit end index rather than ``[:-n]``: ``[:-0]`` is
        # ``[:0]`` (empty), which would make "keep 0" delete nothing at all.
        candidates = matching[:max(len(matching) - keep_last_nodes_count, 0)]

        to_remove = []
        prev_week_nr = None
        for node in candidates:
            iso = node.ts.isocalendar()
            week_nr = (iso[0], iso[1])
            # The first node seen for a week is the milestone; later nodes of
            # the same week are redundant.
            if not keep_weekly_milestones or prev_week_nr == week_nr:
                to_remove.append(node)
            prev_week_nr = week_nr

        if not to_remove:
            # Nothing to delete - do not open a transaction for a no-op.
            return

        with Transaction(client=yt_client):
            for node in to_remove:
                self.logger.info("Remove %s %s", node.type, node.path)
                remove(node.path, recursive=True, client=yt_client)


class DummyCleanupStrategy(CleanupStrategy):
    """Cleanup strategy whose ``clean`` is a no-op."""

    def __init__(self):
        super().__init__()

    def clean(self, path: Union[str, YPath], yt_client: YtClient):
        """Do nothing; both arguments are ignored."""
        return None


class StandardCleanupStrategy(CleanupStrategy):
    """Cleanup strategy with a configurable keep-count and weekly-milestone flag."""

    def __init__(self, keep_last_nodes_count: int, keep_weekly_milestones: bool):
        super().__init__()
        # Both settings are forwarded unchanged to delete_old_nodes() by clean().
        self.keep_weekly_milestones = keep_weekly_milestones
        self.keep_last_nodes_count = keep_last_nodes_count

    def clean(self, path: Union[str, YPath], yt_client: YtClient):
        """Run ``delete_old_nodes`` on ``path`` with the stored settings."""
        self.delete_old_nodes(
            path,
            self.keep_last_nodes_count,
            self.keep_weekly_milestones,
            yt_client=yt_client,
        )


# Shared default strategy: keep the 7 newest nodes plus weekly milestones.
DEFAULT_CLEANUP_STRATEGY = StandardCleanupStrategy(
    keep_last_nodes_count=7,
    keep_weekly_milestones=True,
)


class KeepLastNCleanupStrategy(CleanupStrategy):
    """Cleanup strategy that keeps a fixed number of newest nodes, without weekly milestones."""

    def __init__(self, keep_last_nodes_count: int):
        super().__init__()
        self.keep_last_nodes_count = keep_last_nodes_count

    def clean(self, path: Union[str, YPath], yt_client: YtClient):
        """Run ``delete_old_nodes`` on ``path`` with weekly milestones disabled."""
        keep_weekly_milestones = False
        self.delete_old_nodes(
            path,
            self.keep_last_nodes_count,
            keep_weekly_milestones,
            yt_client=yt_client,
        )


def _parse_bool_param(raw: str) -> bool:
    """Interpret a textual strategy parameter as a boolean.

    ``bool(raw)`` is not usable here: every non-empty string - including
    ``'False'`` and ``'0'`` - is truthy.
    """
    token = raw.strip().lower()
    if token in ('true', '1', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, as it was under plain ``bool()``.
    if token in ('false', '0', 'no', 'n', 'off', ''):
        return False
    raise ValueError(f'Cannot interpret {raw!r} as a boolean strategy parameter')


def parse_cleanup_strategy(value: str):
    """Build a cleanup strategy from its textual description.

    Format: ``{StrategyName}[:{param1}[:{param2}]]``

    Supported descriptions:
        * ``Standard:{keep_last_nodes_count}:{keep_weekly_milestones}``
        * ``Default`` - returns the shared ``DEFAULT_CLEANUP_STRATEGY``
        * ``KeepLastN:{keep_last_nodes_count}``
        * ``Dummy``

    Boolean parameters accept ``true/false``, ``1/0``, ``yes/no``, ``y/n`` and
    ``on/off`` (case-insensitive).

    Raises:
        ValueError: if a required parameter is missing or cannot be parsed.
        Exception: if the strategy name is unknown.
    """
    strategy_name, *strategy_params = value.split(':')
    if strategy_name == 'Standard':
        if len(strategy_params) < 2:
            raise ValueError(f'Strategy Standard requires 2 parameters, got {len(strategy_params)}: {value!r}')
        keep_last_nodes_count = int(strategy_params[0])
        keep_weekly_milestones = _parse_bool_param(strategy_params[1])
        return StandardCleanupStrategy(keep_last_nodes_count, keep_weekly_milestones)
    elif strategy_name == 'Default':
        return DEFAULT_CLEANUP_STRATEGY
    elif strategy_name == 'KeepLastN':
        if not strategy_params:
            raise ValueError(f'Strategy KeepLastN requires 1 parameter, got 0: {value!r}')
        keep_last_nodes_count = int(strategy_params[0])
        return KeepLastNCleanupStrategy(keep_last_nodes_count)
    elif strategy_name == 'Dummy':
        return DummyCleanupStrategy()
    else:
        raise Exception(f'Unknown strategy name {strategy_name}')


class VersionedPath:
    """Context manager around a per-run directory under a base path.

    Entering creates the run directory and returns its path. Leaving without
    an exception re-points the ``latest`` link at the run directory and then
    invokes the cleanup strategy on the base path.
    """

    def __init__(self, path: Union[str, YPath], yt_client: YtClient, cleanup_strategy: CleanupStrategy = DEFAULT_CLEANUP_STRATEGY):
        self.base_path = YPath(path)
        self.yt_client = yt_client
        self.cleanup_strategy = cleanup_strategy
        # Resolved lazily by get_run_dir().
        self.run_path = None

        self.timestamp = datetime.datetime.utcnow()
        # Second-precision label: drop the fractional part, append "Z".
        label_without_fraction = self.timestamp.isoformat().partition(".")[0]
        self.time_label = label_without_fraction + "Z"

        self.logger = logging.getLogger(__name__)

    def __enter__(self):
        """Create the run directory (if missing) and return its path."""
        create('map_node', self.get_run_dir(), recursive=True, ignore_existing=True, client=self.yt_client)
        run_dir = self.get_run_dir()
        return run_dir

    def __exit__(self, exc_type, exc_val, exc_tb):
        """On a clean exit update ``latest`` and run cleanup; otherwise do nothing."""
        if exc_type is not None:
            # The block failed: leave "latest" and the old nodes untouched.
            return
        self.update_latest()
        self.cleanup_strategy.clean(self.base_path, self.yt_client)

    def get_run_dir(self):
        """Return the run directory path, choosing it on first call.

        The date-only name (first 10 characters of the time label) is used
        unless one of the last five listed children of the base path is a
        ``map_node`` whose path contains that date-only path; in that case the
        full time label is used instead. The result is cached.
        """
        if self.run_path is not None:
            return self.run_path

        date_only_path = join(self.base_path, self.time_label[:10])
        children = list(self.base_path, attributes=['type', 'creation_time'], absolute=True, client=self.yt_client)

        # Evaluate every inspected child (no short-circuit) before deciding.
        date_only_text = str(date_only_path)
        hits = [
            child.attributes['type'] == 'map_node' and date_only_text in str(child)
            for child in children[-5:]
        ]

        if any(hits):
            self.run_path = join(self.base_path, self.time_label)
        else:
            self.run_path = date_only_path
        return self.run_path

    def get_latest_path(self):
        """Return the path of the ``latest`` entry under the base path."""
        latest_path = join(self.base_path, "latest")
        return latest_path

    def update_latest(self):
        """Point the ``latest`` link at the run directory, replacing any existing link."""
        run_dir = self.get_run_dir()
        latest_path = self.get_latest_path()
        link(run_dir, latest_path, client=self.yt_client, force=True)

    def transfer_results(self, destination: str, yt_token: str, yt_proxy: str):
        """Delegate to the module-level ``transfer_results`` helper.

        The run directory, proxy, destination, token and ``latest`` path are
        passed positionally in that order.
        """
        run_dir = self.get_run_dir()
        latest_path = self.get_latest_path()
        transfer_results(run_dir, yt_proxy, destination, yt_token, latest_path)
