# -*- coding: utf-8 -*-

from typing import List, Optional, NamedTuple
from dataclasses import dataclass
import datetime
import logging
import re

import yt.wrapper as yt
from yt.wrapper import YtClient, ypath_split

from travel.hotels.lib.python3.yt.ytlib import schema_from_dict
from travel.hotels.devops.sandbox_planner.data import YtNodeInfo


class YtProxyData:
    """Mutable bookkeeping state that YtTrigger keeps for a single YT proxy."""

    def __init__(self):
        # Set once the processed-nodes table was read from this proxy without errors.
        self.accessible = False
        # Set when processed_nodes differs from what is stored in YT and must be written back.
        self.changed = False
        # Processed-node records (YtNodeInfo) of all plans on this proxy.
        self.processed_nodes = []


class NodeKey(NamedTuple):
    """Identity of a tracked node: the owning plan and the node path."""

    plan_id: str  # plan item the trigger belongs to
    node: str  # YT path of the node, as a string


class NodeData(NamedTuple):
    """A node currently seen by a trigger, plus how changes to it are treated."""

    # Record (plan_id, node path, version) to store as processed.
    yt_node_info: YtNodeInfo
    # When True, an already processed node is not reported again if its version changes.
    ignore_modifications: bool


@dataclass
class YtTriggerPath:
    """One YT location watched by YtTrigger.

    With ``as_dir=False`` the node at ``yt_path`` itself is watched. With
    ``as_dir=True`` every child of ``yt_path`` is watched, and the optional
    filters below are applied to those children.
    """

    # Name of the YT proxy the path lives on (key of YtTrigger.yt_clients).
    yt_proxy: str
    # Path of the watched node or directory.
    yt_path: str
    # Watch the children of yt_path instead of the node itself.
    as_dir: bool
    # Skip children whose creation_time is more recent than this age.
    min_age: Optional[datetime.timedelta]
    # Skip children whose creation_time is older than this age.
    max_age: Optional[datetime.timedelta]
    # Skip children whose name sorts before this string.
    name_from: Optional[str]
    # Skip children created before this naive datetime (presumably UTC - TODO confirm).
    creation_date_from: Optional[datetime.datetime]
    # Skip children whose name fully matches this regular expression.
    table_name_exclude_pattern: Optional[str]
    # Do not re-report processed children when their content revision changes.
    ignore_modifications: bool


class YtTrigger:
    """Find new or modified YT nodes for plan items and remember the processed ones.

    For every proxy in ``yt_clients`` a list of already processed nodes
    (YtNodeInfo records) is kept in the table ``<yt_root>/processed_nodes``.
    ``actualize_and_find_new_nodes`` compares the current state of the
    configured YT paths with that list and returns the nodes not processed yet.
    ``flush`` writes the updated lists back to YT.
    """

    def __init__(self, yt_clients, yt_root):
        """
        :param yt_clients: mapping of YT proxy name -> YtClient.
        :param yt_root: YT directory that holds the ``processed_nodes`` table.
        """
        self.yt_clients = yt_clients
        self.processed_nodes_yt_path = yt.ypath_join(yt_root, 'processed_nodes')
        self.proxy2data = dict()  # Proxy -> YtProxyData
        self._read_processed_nodes()

    def _read_processed_nodes(self):
        """Load processed nodes of every proxy into ``self.proxy2data``.

        A proxy whose table cannot be read stays marked as inaccessible, and
        triggers on it are then skipped by ``actualize_and_find_new_nodes``.
        """
        # NOTE(review): not used inside this class; kept in case external code reads it.
        self.processed_nodes = dict()
        for yt_proxy, client in self.yt_clients.items():
            proxy_data = YtProxyData()
            self.proxy2data[yt_proxy] = proxy_data
            try:
                logging.info(f'Reading processed nodes from {yt_proxy}.{self.processed_nodes_yt_path}')
                if client.exists(self.processed_nodes_yt_path):
                    for row in client.read_table(self.processed_nodes_yt_path):
                        proxy_data.processed_nodes.append(YtNodeInfo(**row))
                proxy_data.accessible = True
            except Exception:
                # logging.exception attaches the active traceback itself. Passing the
                # exception as a %-format argument (the message has no placeholder)
                # used to make the logging call itself fail.
                logging.exception(f'Failed to read processed nodes from {yt_proxy}.{self.processed_nodes_yt_path}')

    @staticmethod
    def _parse_time_attr(attr) -> datetime.datetime:
        """Parse a YT time attribute string into a naive datetime.

        The last 8 characters are cut off before parsing with
        ``%Y-%m-%dT%H:%M:%S``. Presumably they are the ``.ffffffZ`` suffix of
        YT timestamps such as ``2021-01-01T12:00:00.123456Z`` (TODO confirm).
        """
        return datetime.datetime.strptime(attr[:-8], '%Y-%m-%dT%H:%M:%S')

    @staticmethod
    def _node_passes_filters(trigger_path: YtTriggerPath, node, now: datetime.datetime) -> bool:
        """Return True if directory child ``node`` satisfies every filter of ``trigger_path``.

        ``creation_time`` is parsed, and the node name extracted, only when a
        filter that needs them is configured.
        """
        if (trigger_path.min_age is not None or trigger_path.max_age is not None
                or trigger_path.creation_date_from is not None):
            created = YtTrigger._parse_time_attr(node.attributes['creation_time'])
            age = now - created
            if trigger_path.min_age is not None and age < trigger_path.min_age:
                return False
            if trigger_path.max_age is not None and age > trigger_path.max_age:
                return False
            if trigger_path.creation_date_from is not None and created < trigger_path.creation_date_from:
                return False
        if trigger_path.name_from is not None or trigger_path.table_name_exclude_pattern is not None:
            name = ypath_split(str(node))[1]
            if trigger_path.name_from is not None and name < trigger_path.name_from:
                return False
            if (trigger_path.table_name_exclude_pattern is not None
                    and re.fullmatch(trigger_path.table_name_exclude_pattern, name) is not None):
                return False
        return True

    def _collect_dir_nodes(self, yt_client, plan_id: str, trigger_path: YtTriggerPath) -> dict:
        """Return {NodeKey: NodeData} for children of a directory trigger that pass its filters."""
        # Ages are measured against UTC "now"; creation_time is parsed as a naive datetime.
        now = datetime.datetime.utcnow()
        nodes = dict()
        for node in yt_client.list(trigger_path.yt_path, attributes=['content_revision', 'creation_time'],
                                   absolute=True):
            if not self._node_passes_filters(trigger_path, node, now):
                continue
            yt_node_info = YtNodeInfo(
                plan_id=plan_id,
                node=str(node),
                version=node.attributes['content_revision'],
            )
            nodes[NodeKey(plan_id, yt_node_info.node)] = NodeData(yt_node_info, trigger_path.ignore_modifications)
        return nodes

    def actualize_and_find_new_nodes(self, plan_id: str,
                                     yt_trigger_paths: List[YtTriggerPath], limit: int) -> List[YtNodeInfo]:  # new nodes
        """Sync processed-node records of ``plan_id`` with YT and return the new nodes.

        Each trigger path resolves to a set of current nodes. For ``as_dir``
        paths that is the filtered children; otherwise it is the path itself.

        For each proxy that has a trigger path:

        * Records of ``plan_id`` are dropped when they have no matching current
          node, or when the version changed and ``ignore_modifications`` is not set.
        * Current nodes without a record are returned as new, sorted by path.
          At most ``limit`` are returned per proxy, and those are recorded as
          processed.

        Returns an empty list, and leaves all records untouched, if any trigger
        path refers to an inaccessible proxy or to a non-existent YT path.
        Changes stay in memory until ``flush`` is called.

        NOTE(review): ``limit`` is applied per proxy, so the result may exceed it
        when the trigger paths span several proxies.
        """
        current_nodes = dict()  # yt_proxy -> {NodeKey: NodeData}
        for trigger_path in yt_trigger_paths:
            yt_path = trigger_path.yt_path
            yt_proxy = trigger_path.yt_proxy
            proxy_data = self.proxy2data[yt_proxy]
            if not proxy_data.accessible:
                logging.error(f"Skip reading at inaccessible proxy {yt_proxy}")
                return list()
            yt_client = self.yt_clients[yt_proxy]
            if not yt_client.exists(yt_path):
                logging.warning(f"Skip inexistent YT path {yt_path}")
                return list()

            # Register the proxy even when a directory yields no nodes, so that stale
            # records of this plan are pruned below just like for a non-empty directory.
            proxy_nodes = current_nodes.setdefault(yt_proxy, dict())
            if trigger_path.as_dir:
                proxy_nodes.update(self._collect_dir_nodes(yt_client, plan_id, trigger_path))
            else:
                yt_node_info = YtNodeInfo(
                    plan_id=plan_id,
                    node=yt_path,
                    version=self.get_node_version(yt_client, yt_path),
                )
                # NOTE(review): trigger_path.ignore_modifications is not applied to single-node triggers.
                proxy_nodes[NodeKey(plan_id, yt_path)] = NodeData(yt_node_info, False)

        all_new_nodes = list()
        for yt_proxy, proxy_nodes in current_nodes.items():
            proxy_data = self.proxy2data[yt_proxy]
            kept = list()
            for old in proxy_data.processed_nodes:
                if old.plan_id != plan_id:
                    # Records of other plans are left untouched.
                    kept.append(old)
                    continue
                key = NodeKey(old.plan_id, old.node)
                current = proxy_nodes.get(key)
                if current is not None and (current.ignore_modifications
                                            or current.yt_node_info.version == old.version):
                    # The node still exists (unchanged, or changes are ignored): keep it processed.
                    kept.append(current.yt_node_info)
                    del proxy_nodes[key]
                else:
                    # The node is gone or its version changed: drop the record.
                    # A changed node remains in proxy_nodes and is reported as new.
                    proxy_data.changed = True
            proxy_data.processed_nodes = kept

            if proxy_nodes:
                # Nodes that have not been processed before.
                new_nodes = sorted((data.yt_node_info for data in proxy_nodes.values()), key=lambda info: info.node)
                proxy_data.changed = True
                if len(new_nodes) > limit:
                    logging.warning(f"Too much new nodes ({len(new_nodes)}) for yt trigger for plan item {plan_id}, "
                                    f"will use only first {limit} ")
                    # The cut-off nodes stay unrecorded and are reported by a later call.
                    new_nodes = new_nodes[:limit]
                proxy_data.processed_nodes.extend(new_nodes)
                all_new_nodes.extend(new_nodes)

        return all_new_nodes

    def flush(self):
        """Persist processed-node lists of all changed proxies to YT."""
        self._write_processed_nodes()

    def _write_processed_nodes(self):
        """Rewrite the processed-nodes table on every proxy whose data changed.

        Inside one transaction per proxy, the table is removed, recreated with a
        fixed schema and filled. After a successful write the proxy is marked
        unchanged, so a repeated flush does not rewrite identical data.
        """
        schema = schema_from_dict({
            'plan_id': 'string',
            'node': 'string',
            'version': 'uint64',
        })
        for yt_proxy, proxy_data in self.proxy2data.items():
            if not proxy_data.changed:
                continue
            yt_client = self.yt_clients[yt_proxy]
            data = (ni._asdict() for ni in proxy_data.processed_nodes)
            logging.info(f'Writing changed processed nodes to {yt_proxy}.{self.processed_nodes_yt_path}')
            with yt_client.Transaction():
                yt_client.remove(self.processed_nodes_yt_path, force=True)
                yt_client.create('table', self.processed_nodes_yt_path, attributes={'schema': schema})
                yt_client.write_table(self.processed_nodes_yt_path, data)
            proxy_data.changed = False

    @staticmethod
    def get_node_version(yt_client: YtClient, path: str):
        """Return the ``content_revision`` attribute of the node at ``path``."""
        return yt_client.get(yt.ypath_join(path, '@content_revision'))
