# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Iterable, List, NamedTuple, Optional, Set, Union
import inspect
import json
import os
import shutil

from yt.wrapper import JsonFormat
from yt.wrapper.client import YtClient
from yt.wrapper.errors import YtCypressTransactionLockConflict
from yt.wrapper.schema.table_schema import TableSchema
import dateutil.parser
import yt.wrapper as yt
import yt.wrapper.transaction as transaction

from travel.hotels.content_manager.lib.common import get_dc_key_fields, get_dc_yt_schema
from travel.hotels.content_manager.lib.yql_simple_client import YqlSimpleClient


# One table row: column name -> cell value.
TableRow = Dict[str, Any]


# Table contents as a plain list of rows.
TableData = List[TableRow]


# Table contents indexed by the tuple of key-field values of each row
# (the tuples are produced by get_key()).
TableIndexedData = Dict[tuple, TableRow]


# Table schema: column name -> type name (e.g. 'string', 'int64', 'any').
TableSchemaDict = Dict[str, str]


class Table(NamedTuple):
    """In-memory table: rows indexed by key tuple, together with the table schema."""

    # Rows keyed by the tuple of their key-field values.
    data: TableIndexedData
    # Column name -> type name.
    schema: TableSchemaDict


# A scalar cell value usable in a condition; None is allowed.
FieldType = Union[Optional[int], Optional[float], Optional[str], Optional[bool]]


# Right-hand side of a condition: a scalar, or a list/set of scalars
# (collections are what the 'in' / 'not in' operators test membership against).
ConditionValue = Union[FieldType, List[FieldType], Set[FieldType]]


class ConditionFields(NamedTuple):
    """One elementary comparison: ``<field> <operator> <value>``."""

    # Name of the row column to test.
    field: str
    # Operator name; must be one of the keys of CHECKERS.
    operator: str
    # Value the column is compared against.
    value: ConditionValue


class Condition(object):
    """A conjunction of elementary comparisons.

    A condition is either built from a single ``field``/``operator``/``value``
    triple, or from a ready ``condition_list``.  When ``condition_list`` is
    given, the three scalar arguments are ignored and the list object is
    stored as is (not copied).

    Two conditions are combined with ``&``; all comparisons of the result
    must hold for a row to match.
    """

    def __init__(
            self,
            field: Optional[str] = None,
            operator: Optional[str] = None,
            value: Optional[ConditionValue] = None,
            condition_list: Optional[List[ConditionFields]] = None
    ):
        if condition_list is None:
            # No explicit list: wrap the single comparison described by the scalar arguments.
            condition_list = [ConditionFields(field=field, operator=operator, value=value)]
        self.condition_list: List[ConditionFields] = condition_list

    def __and__(self, other: 'Condition') -> 'Condition':
        """Return a new condition holding the comparisons of both operands; neither operand is modified."""
        combined = self.condition_list.copy()
        combined.extend(other.condition_list.copy())
        return Condition(condition_list=combined)


# Supported condition operators: operator name -> predicate(row_value, condition_value).
# The names double as the operator text placed into generated queries (upper-cased).
CHECKERS = {
    'is': lambda actual, expected: actual is expected,
    'is not': lambda actual, expected: actual is not expected,
    'in': lambda actual, expected: actual in expected,
    'not in': lambda actual, expected: actual not in expected,
    '==': lambda actual, expected: actual == expected,
    '!=': lambda actual, expected: actual != expected,
    '>': lambda actual, expected: actual > expected,
    '>=': lambda actual, expected: actual >= expected,
    '<': lambda actual, expected: actual < expected,
    '<=': lambda actual, expected: actual <= expected,
}


class FakeTransaction(NamedTuple):
    """Stand-in transaction object for backends without real transactions."""

    # Identifier reported for the transaction.
    transaction_id: str


# A transaction is either the local stand-in or a real YT transaction object.
Transaction = Union[FakeTransaction, transaction.Transaction]


class Node(NamedTuple):
    """Description of one child node returned by ``PersistenceManager.list``."""

    # Last path component of the node.
    name: str
    # Full path of the node as produced by the backend.
    path: str
    # Node kind, e.g. 'table' or 'map_node'.
    type: str
    # Timestamp reported by the backend for the node.
    created_at: datetime


class LockMode(Enum):
    """Lock modes accepted by ``PersistenceManager.lock``; the value is the mode name passed to the backend."""

    SNAPSHOT = 'snapshot'
    SHARED = 'shared'
    EXCLUSIVE = 'exclusive'


class PersistenceManager(object):
    """Backend-independent interface to a tree of nodes and the tables stored in it.

    ``upstream_changed`` starts as ``False``; the implementations in this
    module set it to ``True`` from ``request_upsert`` / ``request_remove``.

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod``
    below is declarative only: instantiating an incomplete subclass is not
    prevented, and a stub that is not overridden silently returns ``None``.
    """

    def __init__(self):
        self.upstream_changed = False

    @abstractmethod
    def copy(self, src: str, dst: str) -> None:
        """Copy the node at ``src`` to ``dst``."""

    @abstractmethod
    def create_dir(self, path: str) -> None:
        """Create the directory ``path`` together with any missing parents."""

    @abstractmethod
    def dirname(self, path: str) -> str:
        """Return the parent part of ``path``."""

    @abstractmethod
    def exists(self, path: str) -> bool:
        """Return whether a node exists at ``path``."""

    @abstractmethod
    def delete(self, path: str) -> None:
        """Remove the node at ``path``; a directory is removed with its content."""

    @abstractmethod
    def is_dir(self, path: str) -> bool:
        """Return whether ``path`` is a directory node rather than a table."""

    @abstractmethod
    def flush(self) -> None:
        """Persist pending changes; a no-op for backends that write immediately."""

    @abstractmethod
    def get(self, path) -> Any:
        """Return the value stored at ``path``."""

    @abstractmethod
    def join(self, *parts: str) -> str:
        """Join path components using the backend's path syntax."""

    @abstractmethod
    def link(self, link: str, target: str) -> None:
        """Create a link at ``link`` that points to ``target``."""

    @abstractmethod
    def list(self, path: str) -> List[Node]:
        """Return the direct children of ``path``."""

    @abstractmethod
    def lock(self, path: str, mode: LockMode = LockMode.EXCLUSIVE) -> bool:
        """Try to lock ``path`` in ``mode``; return ``True`` only if the lock was taken."""

    @abstractmethod
    def move(self, src: str, dst: str) -> None:
        """Move the node at ``src`` to ``dst``."""

    @abstractmethod
    def read(self, path: str) -> Iterable[TableRow]:
        """Return an iterable over the rows of the table at ``path``."""

    @abstractmethod
    def realpath(self, path: str) -> str:
        """Return the path ``path`` actually refers to once a link is resolved."""

    @abstractmethod
    def reset_cache(self) -> None:
        """Drop cached table data, if the backend keeps any."""

    @abstractmethod
    def row_count(self, path: str) -> int:
        """Return the number of rows of the table at ``path``."""

    @abstractmethod
    def set(self, path, value: Dict[str, Any]) -> None:
        """Store ``value`` at ``path``."""

    @abstractmethod
    def split(self, path: str) -> List[str]:
        """Split ``path`` into its parent part and its last component."""

    @abstractmethod
    def write(self, path: str, data: Iterable[TableRow], schema: Optional[Dict[str, str]] = None) -> None:
        """Replace the table at ``path`` with the rows of ``data``.

        ``schema`` maps column names to type names; ``None`` means no schema is supplied.
        """

    @abstractmethod
    def transaction(self) -> Transaction:
        """Return the backend's transaction object."""

    @abstractmethod
    def request_select(
            self,
            src_path: str,
            dc: dataclass,
            match_conditions: List[Condition],
    ) -> TableData:
        """Return the rows of table ``src_path`` selected by ``match_conditions``.

        A row is selected when it satisfies at least one of the conditions.
        ``dc`` is the dataclass describing the table.
        """

    @abstractmethod
    def request_upsert(
            self,
            src_data: TableData,
            dst_path: str,
            dc: dataclass,
            fields_to_update: Optional[List[str]] = None,
    ) -> None:
        """Insert the rows of ``src_data`` into table ``dst_path`` or update rows with the same key.

        Rows are matched by the key fields of ``dc``.  ``fields_to_update``
        names the columns taken from ``src_data`` for rows that already exist;
        ``None`` means every column of the schema.
        """

    def copy_dir(self, src_dir: str, dst_dir: str):
        """Copy every direct child of ``src_dir`` into ``dst_dir``, creating ``dst_dir`` if it is missing."""
        if not self.exists(dst_dir):
            self.create_dir(dst_dir)
        for node in self.list(src_dir):
            self.copy(self.join(src_dir, node.name), self.join(dst_dir, node.name))

    @abstractmethod
    def request_remove(self, to_remove: TableData, remove_from: str, dc: dataclass) -> None:
        """Delete from table ``remove_from`` every row whose key equals the key of a row in ``to_remove``."""


class LocalPersistenceManager(PersistenceManager):
    """File-system backend: paths under ``remote_root`` are mapped onto ``local_root``.

    A table is a directory containing a ``.table`` file with one JSON object
    per line; any other directory is treated as a plain directory node.
    Values handled by ``get`` / ``set`` are JSON files.
    """

    # Value used for a column missing from a written row, by schema type name.
    DEFAULT_VALUES = {
        'string': '',
        'uint32': 0,
        'uint64': 0,
        'int32': 0,
        'int64': 0,
        'double': .0,
        'boolean': False,
        'any': dict(),
    }

    def __init__(self, remote_root: str, local_root: str):
        super().__init__()

        self.remote_root: str = remote_root
        self.local_root: str = local_root

    def get_local_path(self, path):
        """Return the local path corresponding to ``path`` given relative to ``remote_root``."""
        rel_path = os.path.relpath(path, self.remote_root)
        return os.path.normpath(os.path.join(self.local_root, rel_path))

    def localize_path(*args_to_replace: str) -> Callable:
        """Build a method decorator that passes the named arguments through ``get_local_path``.

        Used only inside this class body.  The decorated method is always
        invoked with keyword arguments.
        """
        from functools import wraps

        def adapter(func) -> Callable:
            # Parameter names of the method without ``self``, to bind positional arguments by name.
            arg_names = inspect.getfullargspec(func).args[1:]

            @wraps(func)
            def wrapper(self, *args, **kwargs):
                actual_args = dict(zip(arg_names, args))
                actual_args.update(kwargs)
                for arg in args_to_replace:
                    actual_args[arg] = self.get_local_path(actual_args[arg])
                return func(self, **actual_args)

            return wrapper
        return adapter

    @localize_path('src', 'dst')
    def copy(self, src, dst):
        """Copy the file or directory ``src`` to ``dst``."""
        # Tables and directory nodes are directories here; shutil.copy() handles files only.
        if os.path.isdir(src):
            shutil.copytree(src, dst, symlinks=True)
        else:
            shutil.copy(src, dst)

    @localize_path('path')
    def create_dir(self, path):
        """Create directory ``path`` with missing parents; an existing directory is not an error."""
        os.makedirs(path, exist_ok=True)

    def dirname(self, path):
        """Return the directory part of ``path`` (the path is not localized)."""
        return os.path.dirname(path)

    @localize_path('path')
    def delete(self, path):
        """Remove the file at ``path``, or the whole tree if it is a directory."""
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

    @localize_path('path')
    def is_dir(self, path):
        """Return ``True`` unless ``path`` contains a ``.table`` file (a missing path also yields ``True``)."""
        return not os.path.exists(self.join(path, '.table'))

    @localize_path('path')
    def exists(self, path):
        """Return whether ``path`` exists locally."""
        return os.path.exists(path)

    def flush(self) -> None:
        """Nothing to flush: every change is written immediately."""
        pass

    @localize_path('path')
    def get(self, path) -> Any:
        """Return the parsed content of the JSON file at ``path``."""
        with open(path) as f:
            return json.load(f)

    def join(self, *parts):
        """Join path components with ``os.path.join``."""
        return os.path.join(*parts)

    @localize_path('link', 'target')
    def link(self, link: str, target: str) -> None:
        """Create a symbolic link ``link`` pointing to ``target``."""
        os.symlink(target, link)

    @localize_path('path')
    def list(self, path: str) -> List[Node]:
        """Return the entries of directory ``path``.

        ``Node.path`` is the local path of the entry and ``Node.created_at``
        is taken from the entry's modification time.
        """
        return [
            Node(
                name=str(entry.name),
                path=str(entry),
                type='table' if (entry / '.table').exists() else 'map_node',
                created_at=datetime.fromtimestamp(entry.stat().st_mtime),
            )
            for entry in Path(path).iterdir()
        ]

    def lock(self, path: str, mode: LockMode = LockMode.EXCLUSIVE) -> bool:
        """Locking is not implemented locally; always returns ``False``."""
        return False

    @localize_path('src', 'dst')
    def move(self, src: str, dst: str) -> None:
        """Move ``src`` to ``dst``."""
        shutil.move(src, dst)

    @localize_path('path')
    def read(self, path):
        """Yield the rows of the table at ``path``, one per line of its ``.table`` file."""
        table_file = self.join(path, '.table')
        # Take a snapshot of the lines so the file is closed before rows are handed out.
        with open(table_file) as f:
            lines = f.readlines()
        for line in lines:
            yield json.loads(line)

    def realpath(self, path: str) -> str:
        """Return ``os.path.realpath(path)`` (the path is not localized)."""
        return os.path.realpath(path)

    def reset_cache(self) -> None:
        """Nothing to reset: this backend keeps no cache."""
        pass

    @localize_path('path')
    def row_count(self, path: str) -> int:
        """Return the number of lines in the ``.table`` file of the table at ``path``."""
        table_file = self.join(path, '.table')
        with open(table_file) as f:
            return sum(1 for _ in f)

    @localize_path('path')
    def set(self, path, value: Dict[str, Any]) -> None:
        """Write ``value`` as JSON to the file at ``path``, replacing its content."""
        with open(path, 'w') as f:
            json.dump(value, f)

    def split(self, path):
        """Split ``path`` with ``os.path.split`` (the path is not localized)."""
        return os.path.split(path)

    @localize_path('path')
    def write(self, path, data, schema=None):
        """Replace the table at ``path`` with ``data``.

        Columns listed in ``schema`` but absent from a row are filled from
        ``DEFAULT_VALUES``.  With ``schema=None`` rows are written unchanged.

        Raises:
            KeyError: if ``schema`` uses a type name missing from ``DEFAULT_VALUES``.
        """
        if schema is None:
            defaults = {}
        else:
            defaults = {column: self.DEFAULT_VALUES[type_name] for column, type_name in schema.items()}
        os.makedirs(path, exist_ok=True)
        table_file = self.join(path, '.table')
        with open(table_file, 'w') as f:
            for row in data:
                complete_row = defaults.copy()
                complete_row.update(row)
                f.write(json.dumps(complete_row) + '\n')

    @contextmanager
    def transaction(self) -> Transaction:
        """Context manager yielding a ``FakeTransaction``; nothing is committed or rolled back."""
        yield FakeTransaction(transaction_id='fake_id')

    def request_select(
            self,
            src_path: str,
            dc: dataclass,
            match_conditions: List[Condition],
    ) -> TableData:
        """Return the rows of table ``src_path`` accepted by ``check_conditions``; ``dc`` is unused here."""
        return [row for row in self.read(src_path) if check_conditions(row, match_conditions)]

    def request_upsert(
            self,
            src_data: TableData,
            dst_path: str,
            dc: dataclass,
            fields_to_update: Optional[List[str]] = None,
    ) -> None:
        """Merge ``src_data`` into table ``dst_path`` by the key fields of ``dc`` and rewrite the table.

        New keys are inserted as whole rows; for existing keys only
        ``fields_to_update`` (default: all schema columns) are overwritten.
        """
        key_fields = get_dc_key_fields(dc)
        assert key_fields
        schema = get_dc_yt_schema(dc)

        if fields_to_update is None:
            fields_to_update = schema.keys()

        src_dict = get_dict(src_data, key_fields)
        # The destination is read completely before it is rewritten below.
        dst_dict = get_dict(self.read(dst_path), key_fields)
        for key, row in src_dict.items():
            update_row(key, row, dst_dict, fields_to_update)
        self.write(dst_path, dst_dict.values(), schema)

        self.upstream_changed = True

    def request_remove(self, to_remove: TableData, remove_from: str, dc: dataclass) -> None:
        """Rewrite table ``remove_from`` without the rows whose key occurs in ``to_remove``."""
        key_fields = get_dc_key_fields(dc)
        assert key_fields
        schema = get_dc_yt_schema(dc)
        to_remove_dict = get_dict(to_remove, key_fields)
        remove_from_dict = get_dict(self.read(remove_from), key_fields)

        remaining = {key: row for key, row in remove_from_dict.items() if key not in to_remove_dict}

        self.write(remove_from, remaining.values(), schema)

        self.upstream_changed = True


class YtPersistenceManager(PersistenceManager):
    """YT backend: node operations go through ``yt_client``, table requests through YQL queries.

    ``yql_client`` is optional at construction time, but ``request_select``,
    ``request_upsert`` and ``request_remove`` call it unconditionally.
    """

    def __init__(self, yt_client: YtClient, yql_client: Optional[YqlSimpleClient] = None):
        super().__init__()

        self.yt_client = yt_client
        self.yql_client = yql_client

    def copy(self, src, dst):
        """Copy ``src`` to ``dst``; a link is re-created as a link to the same target."""
        # The '&' suffix addresses the link node itself instead of the node it points to.
        # noinspection PyUnresolvedReferences
        node_type = self.yt_client.get_type(src + '&')
        if node_type == 'link':
            # noinspection PyUnresolvedReferences
            link_path = self.yt_client.get(src + '&/@target_path')
            self.link(dst, link_path)
        else:
            # noinspection PyUnresolvedReferences
            self.yt_client.copy(src, dst)

    def create_dir(self, path):
        """Create the map node ``path`` recursively."""
        # noinspection PyUnresolvedReferences
        self.yt_client.create('map_node', path, recursive=True)

    def dirname(self, path):
        """Return the parent part of ``path``."""
        return yt.ypath_dirname(path)

    def delete(self, path):
        """Remove the node at ``path`` recursively."""
        # noinspection PyUnresolvedReferences
        self.yt_client.remove(path, recursive=True)

    def is_dir(self, path):
        """Return whether the ``@type`` attribute of ``path`` is ``'map_node'``."""
        attribute_path = self.join(path, '@type')
        return self.get(attribute_path) == 'map_node'

    def exists(self, path):
        """Return whether a node exists at ``path``."""
        # noinspection PyUnresolvedReferences
        return self.yt_client.exists(path)

    def flush(self) -> None:
        """Nothing to flush: every change is sent to YT immediately."""
        pass

    def get(self, path) -> Any:
        """Return the value of the node or attribute at ``path``."""
        # noinspection PyUnresolvedReferences
        return self.yt_client.get(path)

    def join(self, *parts):
        """Join path components into a YPath."""
        return yt.ypath_join(*parts)

    def link(self, link: str, target: str) -> None:
        """Create (or replace, ``force=True``) the link ``link`` pointing to ``target``."""
        # noinspection PyUnresolvedReferences
        self.yt_client.link(target, link, force=True)

    def list(self, path: str) -> List[Node]:
        """Return the children of ``path`` with their type and creation time."""
        nodes = []
        for raw_node in self.yt_client.list(path, attributes=['type', 'creation_time']):
            nodes.append(Node(
                name=str(raw_node),
                path=self.join(path, raw_node),
                type=raw_node.attributes['type'],
                created_at=dateutil.parser.parse(raw_node.attributes['creation_time']),
            ))
        return nodes

    def lock(self, path: str, mode: LockMode = LockMode.EXCLUSIVE) -> bool:
        """Try to lock ``path``; return ``False`` (and log the conflict) if the lock is already held."""
        try:
            self.yt_client.lock(path, mode=mode.value)
        except YtCypressTransactionLockConflict as e:
            logging.info(e)
            return False
        return True

    def move(self, src: str, dst: str) -> None:
        """Move ``src`` to ``dst`` recursively, preserving creation and modification times."""
        self.yt_client.move(src, dst, recursive=True, preserve_creation_time=True, preserve_modification_time=True)

    def read(self, path):
        """Return the rows of the table at ``path`` as provided by ``yt_client.read_table``."""
        # noinspection PyUnresolvedReferences
        return self.yt_client.read_table(path)

    def realpath(self, path: str) -> str:
        """Return the target of ``path`` if it is a link, otherwise ``path`` itself."""
        # noinspection PyUnresolvedReferences
        node_type = self.yt_client.get_type(path + '&')
        if node_type == 'link':
            # noinspection PyUnresolvedReferences
            path = self.yt_client.get(path + '&/@target_path')
        return path

    def reset_cache(self) -> None:
        """Nothing to reset: this backend keeps no cache."""
        pass

    def row_count(self, path: str) -> int:
        """Return the row count of the table at ``path``."""
        # noinspection PyUnresolvedReferences
        return self.yt_client.row_count(path)

    def set(self, path, value: Dict[str, Any]) -> None:
        """Store ``value`` at ``path``."""
        # noinspection PyUnresolvedReferences
        self.yt_client.set(path, value)

    @staticmethod
    def get_table_schema(common_schema, sort_by=None, hide_fields=None):
        """Convert a ``{column: type}`` mapping into a list of YT column descriptions.

        A ``TableSchema`` instance is returned unchanged.  Columns from
        ``sort_by`` come first, in the given order, marked as ascending sort
        columns; columns from ``hide_fields`` are left out.  ``common_schema``
        itself is not modified.

        Raises:
            KeyError: if a ``sort_by`` column is missing from the schema or hidden.
        """
        if isinstance(common_schema, TableSchema):
            return common_schema
        if sort_by is None:
            sort_by = []
        hidden = set(hide_fields) if hide_fields is not None else set()
        # Work on a filtered copy so the caller's mapping keeps all of its columns.
        visible = {n: t for n, t in common_schema.items() if n not in hidden}
        sorted_fields = [{'name': n, 'type': visible[n], 'order': 'ascending'} for n in sort_by]
        unsorted_fields = [{'name': n, 'type': t} for n, t in visible.items() if n not in sort_by]
        return sorted_fields + unsorted_fields

    def split(self, path):
        """Split ``path`` with ``yt.ypath_split``."""
        return yt.ypath_split(path)

    def write(self, path, data, schema=None):
        """Replace the table at ``path``: drop it if present, create it (with ``schema`` if given), write ``data``."""
        # noinspection PyUnresolvedReferences
        if self.yt_client.exists(path):
            self.delete(path)
        attributes = {}
        if schema is not None:
            attributes = {'schema': self.get_table_schema(schema)}
        # noinspection PyUnresolvedReferences
        self.yt_client.create('table', path, attributes=attributes, recursive=True)
        # noinspection PyUnresolvedReferences
        self.yt_client.write_table(path, data, format=JsonFormat(attributes={'encode_utf8': False}))

    def transaction(self) -> Transaction:
        """Return a new ``yt_client.Transaction()``."""
        # noinspection PyUnresolvedReferences
        return self.yt_client.Transaction()

    @staticmethod
    def convert_type(v: ConditionValue) -> str:
        """Render a condition value as a query literal.

        ``None`` becomes ``NULL``, collections become a parenthesised list of
        their rendered items, strings are double-quoted with ``\\`` and ``"``
        escaped, anything else is rendered with ``str``.
        """
        if v is None:
            return 'NULL'
        if isinstance(v, (list, tuple, set, frozenset)):
            items = (YtPersistenceManager.convert_type(item) for item in v)
            return f'({", ".join(items)})'
        if isinstance(v, str):
            # Escape the backslash first so the escapes added for quotes are not doubled.
            escaped = v.replace('\\', '\\\\').replace('"', '\\"')
            return f'"{escaped}"'
        return str(v)

    @staticmethod
    def condition_fields_to_str(condition_fields: ConditionFields) -> str:
        """Render one comparison as ``<field> <OPERATOR> <literal>``.

        Raises:
            ValueError: if the operator is not a key of ``CHECKERS``.
        """
        if CHECKERS.get(condition_fields.operator) is None:
            raise ValueError(f'Unsupported condition operator: "{condition_fields.operator}"')

        value = YtPersistenceManager.convert_type(condition_fields.value)

        return f'{condition_fields.field} {condition_fields.operator.upper()} {value}'

    @staticmethod
    def condition_to_str(condition: Condition) -> str:
        """Render a condition as its comparisons joined with ``AND`` (parenthesised when there are several)."""
        parts = (YtPersistenceManager.condition_fields_to_str(c) for c in condition.condition_list)
        result = ' AND\n                '.join(parts)
        if len(condition.condition_list) > 1:
            result = f'({result})'
        return result

    @staticmethod
    def get_select_query(src: str, dst: str, schema: Dict[str, str], filters: List[Condition]) -> str:
        """Build the query that inserts the schema columns of matching ``src`` rows into ``dst``.

        ``filters`` are joined with ``OR``; with no filters the query has no
        ``WHERE`` clause.  Table and column names are interpolated as is.
        """
        fields = ',\n                '.join(schema.keys())
        conditions = ''
        if filters:
            conditions = (YtPersistenceManager.condition_to_str(c) for c in filters)
            conditions = ' OR\n                '.join(conditions)
            conditions = f'WHERE\n                {conditions}'
        query = f'''
            INSERT INTO `{dst}`
            SELECT
                {fields}
            FROM `{src}`
            {conditions}
        '''
        return dedent(query)

    def request_select(
            self,
            src_path: str,
            dc: dataclass,
            match_conditions: List[Condition],
    ) -> TableData:
        """Run the select query into a temporary table and return its rows as a list."""
        schema = get_dc_yt_schema(dc)
        with self.yt_client.TempTable() as temp:
            query = self.get_select_query(src_path, temp, schema, match_conditions)
            self.yql_client.run_query(query)
            # Materialise the rows while the temporary table is still inside the ``with`` block.
            return list(self.read(temp))

    @staticmethod
    def get_join_query(
            src: str,
            dst: str,
            schema: Dict[str, str],
            key_fields: List[str],
            fields_to_update: List[str],
            join_type: str,
    ) -> str:
        """Build the query that rewrites ``dst`` with the result of joining it to ``src``.

        For every schema column the value is taken from ``src`` first when the
        column is in ``fields_to_update`` and from ``dst`` first otherwise,
        falling back to the other side via ``??``.
        """
        fields_to_update = set(fields_to_update)
        fields = list()
        for field in schema.keys():
            first, second = ('src', 'dst') if field in fields_to_update else ('dst', 'src')
            field = f'{first}.{field} ?? {second}.{field} AS {field}'
            fields.append(field)
        fields = ',\n                '.join(fields)
        key_fields = ', '.join(key_fields)
        query = f'''
            INSERT INTO `{dst}` WITH TRUNCATE
            SELECT
                {fields}
            FROM `{dst}` AS dst
            {join_type} JOIN `{src}` AS src
            USING({key_fields});
        '''
        return dedent(query)

    def request_upsert(
            self,
            src_data: TableData,
            dst_path: str,
            dc: dataclass,
            fields_to_update: Optional[List[str]] = None,
    ) -> None:
        """Write ``src_data`` to a temporary table and ``FULL`` join it into ``dst_path``."""
        key_fields = get_dc_key_fields(dc)
        assert key_fields
        schema = get_dc_yt_schema(dc)

        if fields_to_update is None:
            fields_to_update = schema.keys()

        with self.yt_client.TempTable() as temp:
            self.write(temp, src_data, schema)
            query = self.get_join_query(temp, dst_path, schema, key_fields, fields_to_update, 'FULL')
            self.yql_client.run_query(query)

        self.upstream_changed = True

    def request_remove(self, to_remove: TableData, remove_from: str, dc: dataclass) -> None:
        """Write ``to_remove`` to a temporary table and keep only the ``remove_from`` rows without a match in it."""
        key_fields = get_dc_key_fields(dc)
        assert key_fields
        schema = get_dc_yt_schema(dc)
        with self.yt_client.TempTable() as temp:
            self.write(temp, to_remove, schema)
            query = self.get_join_query(temp, remove_from, schema, key_fields, [], 'LEFT ONLY')
            self.yql_client.run_query(query)

        self.upstream_changed = True


class YtCachePersistenceManager(YtPersistenceManager):
    """YT backend that serves table requests from an in-memory copy of each table.

    A table is loaded on first use and indexed by its key fields.  Selects,
    upserts and removals then work on the cached copy; ``flush`` writes the
    cached tables back.
    """

    def __init__(self, yt_client: YtClient, yql_client: YqlSimpleClient) -> None:
        super(YtCachePersistenceManager, self).__init__(yt_client, yql_client)
        # Table path -> cached table content and schema.
        self.cache: Dict[str, Table] = dict()

    def ensure_cached(self, path: str, schema: TableSchemaDict, key_fields: List[str]) -> None:
        """Load the table at ``path`` into the cache unless it is already there.

        Rows are indexed by ``get_key``; of several rows with the same key
        only the last one read is kept.
        """
        if path in self.cache:
            return

        data = {get_key(row, key_fields): row for row in self.read(path)}
        self.cache[path] = Table(data, schema)

    def flush(self) -> None:
        """Write every cached table back if ``upstream_changed`` is set.

        All cached tables are rewritten, including those that were only read.
        The flag is left as it is.
        """
        if not self.upstream_changed:
            return
        for path, table in self.cache.items():
            self.write(path, table.data.values(), table.schema)

    def row_count(self, path: str) -> int:
        """Return the row count of the cached copy if present, otherwise ask YT."""
        table = self.cache.get(path)
        if table is None:
            return super(YtCachePersistenceManager, self).row_count(path)
        return len(table.data)

    def request_select(
            self,
            src_path: str,
            dc: dataclass,
            match_conditions: List[Condition],
    ) -> TableData:
        """Return the cached rows of ``src_path`` accepted by ``check_conditions``."""
        schema = get_dc_yt_schema(dc)
        key_fields = get_dc_key_fields(dc)

        self.ensure_cached(src_path, schema, key_fields)
        input_data: TableIndexedData = self.cache[src_path].data
        return [item for item in input_data.values() if check_conditions(item, match_conditions)]

    def request_upsert(
            self,
            src_data: TableData,
            dst_path: str,
            dc: dataclass,
            fields_to_update: Optional[List[str]] = None,
    ) -> None:
        """Merge ``src_data`` into the cached copy of ``dst_path``; nothing is written until ``flush``."""
        schema = get_dc_yt_schema(dc)
        key_fields = get_dc_key_fields(dc)
        assert key_fields

        if fields_to_update is None:
            fields_to_update = schema.keys()

        self.ensure_cached(dst_path, schema, key_fields)
        dst_dict: TableIndexedData = self.cache[dst_path].data
        src_dict = get_dict(src_data, key_fields)

        for key, row in src_dict.items():
            update_row(key, row, dst_dict, fields_to_update)

        self.upstream_changed = True

    def request_remove(self, to_remove: TableData, remove_from: str, dc: dataclass) -> None:
        """Drop the rows whose key occurs in ``to_remove`` from the cached table and flush the cache."""
        schema = get_dc_yt_schema(dc)
        key_fields = get_dc_key_fields(dc)
        assert key_fields

        self.ensure_cached(remove_from, schema, key_fields)

        dst_dict: TableIndexedData = self.cache[remove_from].data
        src_dict = get_dict(to_remove, key_fields)
        remaining: TableIndexedData = {key: row for key, row in dst_dict.items() if key not in src_dict}
        self.cache[remove_from] = Table(remaining, schema)

        # The flag must be set before flushing: flush() does nothing while it is unset,
        # so the removal would otherwise not reach YT on a so-far unchanged manager.
        self.upstream_changed = True
        self.flush()

    def reset_cache(self) -> None:
        """Forget all cached tables; unflushed changes are discarded."""
        self.cache.clear()


def check_condition_item(row: TableRow, condition: ConditionFields) -> bool:
    """Evaluate one elementary comparison against ``row``.

    Raises:
        ValueError: if the operator is not a key of ``CHECKERS``.
        KeyError: if ``row`` has no column named ``condition.field``.
    """
    predicate = CHECKERS.get(condition.operator)
    if predicate is not None:
        return predicate(row[condition.field], condition.value)
    raise ValueError(f'Unsupported condition operator: "{condition.operator}"')


def check_condition(row: TableRow, condition: Condition) -> bool:
    """Return ``True`` when ``row`` passes every comparison of ``condition``; evaluation stops at the first failure."""
    for item in condition.condition_list:
        if not check_condition_item(row, item):
            return False
    return True


def check_conditions(row: TableRow, conditions: List[Condition]):
    """Return ``True`` when ``row`` satisfies at least one of ``conditions``.

    An empty ``conditions`` list matches nothing.
    NOTE(review): ``YtPersistenceManager.get_select_query`` omits the WHERE
    clause for an empty filter list and therefore selects every row -- confirm
    which of the two behaviours is intended.
    """
    for condition in conditions:
        if check_condition(row, condition):
            return True
    return False


def get_key(row: TableRow, key_fields: List[str]) -> tuple:
    """Return the values of ``key_fields`` in ``row`` as a tuple, in the order of ``key_fields``.

    Raises:
        KeyError: if ``row`` lacks one of the key fields.
    """
    return tuple(map(row.__getitem__, key_fields))


def get_dict(data: Iterable[TableRow], key_fields: List[str]) -> TableIndexedData:
    """Index ``data`` by the key tuple of each row.

    The row objects themselves are stored, not copies.

    Raises:
        RuntimeError: if two rows share the same key.
    """
    indexed: TableIndexedData = {}
    for row in data:
        row_key = get_key(row, key_fields)
        if row_key in indexed:
            raise RuntimeError(f'Key duplicates {row_key}')
        indexed[row_key] = row
    return indexed


def update_row(
        key: tuple,
        row: TableRow,
        data: TableIndexedData,
        fields_to_update: List[str]
) -> None:
    """Upsert ``row`` into ``data`` under ``key``.

    A missing key receives a copy of the whole ``row``; for an existing key
    only the columns named in ``fields_to_update`` are overwritten in place.

    Raises:
        KeyError: if ``row`` lacks a column from ``fields_to_update``
            (raised before ``data`` is touched).
    """
    # Build the patch first so a missing column fails before anything is inserted.
    new_values = {name: row[name] for name in fields_to_update}
    if key not in data:
        data[key] = row.copy()
    data[key].update(new_values)
