import copy
import datetime
import logging
import typing
from sqlalchemy import and_
from asgiref.sync import async_to_sync
from gino import api
from yt.wrapper import response_stream, yamr_record, YtClient

from crm.agency_cabinet.common.db.models import BaseModel, YtSync
from crm.agency_cabinet.common.utils.nullcontext import NullContext

# Shared logger for every loader in this module, registered under the explicit name 'yt.loader'.
LOGGER = logging.getLogger('yt.loader')


class BaseRowLoadException(Exception):
    """Error raised while turning a single YT row into model column values.

    ``YtModelLoader.load`` logs errors of this type and skips the offending row.
    The text is available as ``message``; it falls back to ``default_message``
    when no (or an empty) message is given.
    """

    default_message = 'UNKNOWN'

    def __init__(self, message: typing.Optional[str] = None):
        self.message = message or self.default_message
        # Forward the message to Exception so ``args``/``repr`` carry it and the
        # exception survives pickling with its text intact.
        super().__init__(self.message)

    def __str__(self):
        return self.message


class CriticalRowLoadException(BaseRowLoadException):
    """Row-load error signalling an unrecoverable problem.

    Raised, for example, by ``YtModelLoader._process_new`` for a ``columns_mapper``
    entry that is neither a column name nor a ``ColumnExtractor``. Because it is a
    ``BaseRowLoadException`` subclass, handlers that want to treat it differently
    must check for it before (or instead of) ``BaseRowLoadException``.
    """


class ExpectedColumnNotFound(BaseRowLoadException):
    """Raised when a row lacks one of the loader's ``expected_columns``."""


class ColumnExtractor:
    """Strategy object that produces the value of one model column from a YT row.

    Concrete extractors override ``__call__``; this base version is abstract.
    """

    def __call__(self, yt_row, loader: 'YtModelLoader'):
        """Return the column value for ``yt_row``; ``loader`` is the calling loader instance."""
        raise NotImplementedError


class ConstantExtractor(ColumnExtractor):
    """Column extractor that ignores the row and always yields the same value."""

    def __init__(self, value: typing.Any):
        # Returned as-is (the very same object, not a copy) for every row.
        self.value = value

    def __call__(self, yt_row, loader: 'YtModelLoader'):
        """Return the stored constant; ``yt_row`` and ``loader`` are unused."""
        constant = self.value
        return constant


class DictExtractor(ColumnExtractor):
    """Column extractor reading ``yt_row[key]`` from a mapping-like row."""

    def __init__(self, key: str):
        self.key = key

    def __call__(self, yt_row, loader: 'YtModelLoader'):
        """Return the value stored under ``key``.

        :raises BaseRowLoadException: if the key is not ``in`` the row (the loader then skips the row).
        """
        wanted = self.key
        if wanted in yt_row:
            return yt_row[wanted]
        raise BaseRowLoadException(f'No {wanted} in row: {yt_row}')


class ObfuscateDictExtractor(DictExtractor):
    """``DictExtractor`` whose extracted value is passed through ``obfuscator`` before being returned."""

    def __init__(self, key: str, obfuscator: typing.Callable):
        super().__init__(key)
        # Called with the raw extracted value; its result becomes the column value.
        self.obfuscator = obfuscator

    def __call__(self, yt_row, loader: 'YtModelLoader'):
        """Extract ``key`` like ``DictExtractor`` does and return ``obfuscator(value)``."""
        return self.obfuscator(super().__call__(yt_row, loader))


class MethodExtractor(ColumnExtractor):
    """Column extractor delegating to a loader method: ``loader.<method_name>(yt_row)``."""

    def __init__(self, method_name: str):
        self.method_name = method_name

    def __call__(self, yt_row, loader: 'YtModelLoader'):
        """Call the named loader method with the row and return its result.

        :raises BaseRowLoadException: a row-load error raised by the method itself is propagated
            unchanged (so e.g. a ``CriticalRowLoadException`` keeps its type). Any other error,
            including a missing method, is wrapped with the original error as ``__cause__``.
        """
        try:
            return getattr(loader, self.method_name)(yt_row)
        except BaseRowLoadException:
            # Already a row-load error: re-wrapping it would erase its specific subclass.
            raise
        except Exception as ex:
            raise BaseRowLoadException(f'Can\'t use loader method to extract value: {ex}') from ex


class YqlColumnExtractor(ColumnExtractor):
    """Column extractor for positional rows (tuples/lists, e.g. YQL query results).

    The position of the requested column name inside ``columns`` is used as the index
    into the row.
    """

    def __init__(self, columns: typing.Optional[list[str]]):
        # May be None when the loader has not filled in its column list; reported in __call__.
        self.columns = columns

    def __call__(self, yt_row, loader: 'YtModelLoader', column: typing.Optional[str] = None):
        """Return ``yt_row[columns.index(column)]``.

        All lookup problems are reported as ``BaseRowLoadException`` so the loader skips the
        row instead of aborting the whole load.

        :raises BaseRowLoadException: if ``column`` is unknown, the row is too short, or
            ``columns`` is ``None``.
        """
        try:
            return yt_row[self.columns.index(column)]
        except (KeyError, ValueError) as ex:
            # list.index signals a missing item with ValueError (not KeyError).
            raise BaseRowLoadException(f'Something went wrong: Unrecognized column name.'
                                       f' \'{column}\' is not in {self.columns} - {ex}') from ex
        except IndexError as ex:
            raise BaseRowLoadException(f'Something went wrong: row {yt_row} is too short'
                                       f' for column \'{column}\' - {ex}') from ex
        except AttributeError as ex:
            raise BaseRowLoadException(
                f'Something went wrong: property \'columns\' has value of \'None\' - {ex}') from ex


class YtModelLoader:
    """Load the rows of a YT table into a database model.

    ``load`` drives the process. It consults the ``YtSync`` record of (table path, loader tag)
    to decide whether the table changed since the last successful run, then reads the table.
    Every row is turned into model columns via ``columns_mapper``/``default_columns``, the
    model rows are created, and finally the table's update attribute is recorded in ``YtSync``.
    Subclasses customise the process through the ``_``-prefixed hook methods.
    """

    # Column names every row must contain (checked by _preprocess_yt_row); may be overridden per instance.
    expected_columns = ()

    def __init__(self, table_path: str,
                 model: typing.Type[BaseModel],
                 columns_mapper: typing.Dict[str, typing.Union[str, ColumnExtractor]],
                 default_columns: typing.Dict[str, typing.Any],
                 client_config: dict,
                 db_bind: typing.Optional[api._BindContext] = None,
                 update_attribute: str = 'modification_time',
                 limit=None,
                 process_duplicates_after: bool = False,
                 force_load: bool = False,
                 expected_columns: typing.Optional[typing.List] = None,
                 use_bulk_insert: bool = False,
                 bulk_insert_step: int = 3000,
                 **kwargs):
        """Create the YT client, normalise the table path and fetch the table's update attribute.

        :param table_path: YT table path; normalised to start with ``//`` and end with ``/``
            (``_retrieve_table_updated_at`` appends ``@<attribute>`` to it, ``_read_table``
            strips the trailing slash again).
        :param model: model class whose rows are created.
        :param columns_mapper: model column name -> source column name (``str``) or ``ColumnExtractor``.
        :param default_columns: initial column values of every new row (deep-copied per row).
        :param client_config: ``cluster`` and ``token`` are required, ``config`` is optional.
        :param db_bind: async context manager entered around every DB access; defaults to a
            ``NullContext``, in which case the bind is expected to be set by the caller.
        :param update_attribute: YT table attribute used to detect table changes.
        :param limit: stop reading once at least this many new rows were collected (``None`` = no limit).
        :param process_duplicates_after: pass all collected duplicates to ``_process_duplicates``.
        :param force_load: load even when the ``YtSync`` record says the table is unchanged.
        :param expected_columns: overrides the class-level ``expected_columns`` when not ``None``.
        :param use_bulk_insert: create rows via ``model.bulk_insert`` instead of one ``model.create`` per row.
        :param bulk_insert_step: second argument of ``model.bulk_insert`` (presumably its batch size).
        :param kwargs: forwarded to the ``_init`` hook.
        """
        self.client = YtClient(proxy=client_config['cluster'],
                               token=client_config['token'],
                               config=client_config.get('config'))
        if not table_path.startswith('//'):
            table_path = f'//{table_path}'
        if not table_path.endswith('/'):
            table_path = f'{table_path}/'
        self.table_path = table_path
        # Requests the attribute from YT; later compared with YtSync.last_used_update.
        self.updated_at = self._retrieve_table_updated_at(update_attribute)
        self.model = model
        self.db_bind = db_bind if db_bind is not None else NullContext()  # expected set_bind on upper level
        self.columns_mapper = columns_mapper
        self.default_columns = default_columns
        self.limit = limit
        # Row source iterated by load(); read lazily by _read_table unless a subclass sets it first.
        self._table: \
            typing.Optional[typing.Union[response_stream.ResponseStream, typing.Iterator[yamr_record.Record]]] = None
        self.process_duplicates_after = process_duplicates_after
        self.force_load = force_load
        self.is_loaded_last_time = False
        if expected_columns is not None:
            self.expected_columns = expected_columns
        self.use_bulk_insert = use_bulk_insert
        self.bulk_insert_step = bulk_insert_step

        # used by YqlColumnExtractor for extracting column values from list by index using column names
        # (presumably filled in by subclasses that read positional rows, e.g. YQL results)
        self._columns = None

        self._init(**kwargs)

    def _init(self, **kwargs):
        """Hook for subclass-specific initialisation; receives the extra ``__init__`` keyword arguments."""
        pass

    def _before_start(self):
        """Hook called by ``load`` right before the table rows are iterated."""
        pass

    def _after_finish(self):
        """Hook called by ``load`` after all rows were read without a critical error."""
        pass

    def _retrieve_table_updated_at(self, update_attribute: str) -> typing.Optional[datetime.datetime]:
        """Return the table attribute ``update_attribute`` as a timezone-aware UTC datetime.

        The value is parsed with the ``%Y-%m-%dT%H:%M:%S.%fZ`` format; the trailing ``Z`` marks UTC.
        """
        # TODO: use try except/util.parse
        raw_value = self.client.get(f'{self.table_path}@{update_attribute}')
        parsed = datetime.datetime.strptime(raw_value, '%Y-%m-%dT%H:%M:%S.%fZ')
        # strptime yields a naive value that is already UTC; attach the zone instead of calling
        # astimezone(), which would interpret the naive value as host-local time and shift it.
        return parsed.replace(tzinfo=datetime.timezone.utc)

    def _read_table(self, **kwargs):
        """Open the table (path without the trailing slash); kwargs go to ``YtClient.read_table``."""
        # kwargs.setdefault('unordered', True) -- can be faster if turn off retries
        self._table: \
            typing.Optional[typing.Union[response_stream.ResponseStream, typing.Iterator[yamr_record.Record]]] \
            = self.client.read_table(self.table_path[:-1], **kwargs)

    def _find_duplicate(self, yt_row) -> typing.Optional[BaseModel]:
        """Return the existing DB row matching ``yt_row``, or ``None``; the base version treats every row as new."""
        return None

    def _process_duplicate(self, yt_row, db_row: BaseModel):
        """Hook called for every row for which ``_find_duplicate`` returned a DB row."""
        pass

    def _process_duplicates(self, duplicates: typing.List[typing.Tuple[dict, BaseModel]]):
        """Hook receiving all ``(yt_row, db_row)`` duplicate pairs when ``process_duplicates_after`` is set."""
        pass

    def _create_db_rows(self, db_rows) -> typing.List[BaseModel]:
        """Create model rows from column dicts and return the created instances.

        Uses ``model.bulk_insert`` when ``use_bulk_insert`` is set, otherwise one ``model.create``
        per dict. Creation errors are logged and not raised. In bulk mode a failure yields ``[]``;
        in per-row mode the failing row is simply missing from the result.
        """
        if not db_rows:
            # Nothing to insert: skip entering the DB bind and calling bulk_insert with an empty batch.
            return []

        @async_to_sync
        async def create_db_row_bulk(_db_rows) -> typing.List[BaseModel]:
            instances = []
            LOGGER.info('Try to create %s rows', len(_db_rows))
            async with self.db_bind:
                try:
                    instances = await self.model.bulk_insert(_db_rows, self.bulk_insert_step)
                except Exception as ex:  # noqa
                    # NOTE(review): only logged; load() still records the sync afterwards, so the table
                    # is not retried until its update attribute changes (or force_load) — confirm intended.
                    LOGGER.exception('Something went wrong: %s', ex)
            return instances

        @async_to_sync
        async def create_db_row(_db_rows) -> typing.List[BaseModel]:
            instances = []
            async with self.db_bind:
                for columns in _db_rows:
                    try:
                        m = await self.model.create(**columns)
                        instances.append(m)
                    except Exception as ex:  # noqa
                        LOGGER.exception('Something went wrong: %s', ex)
            return instances

        return create_db_row_bulk(db_rows) if self.use_bulk_insert else create_db_row(db_rows)

    def _process_new(self, yt_row) -> typing.List[dict]:
        """Map one YT row to a list of column dicts for new model rows (exactly one dict here).

        ``str`` mapper values name a source column. Positional rows (tuple/list, e.g. YQL results)
        are resolved through ``self._columns``; other rows are looked up by key.
        ``ColumnExtractor`` values are called with the row and this loader.

        :raises CriticalRowLoadException: for a mapper value that is neither ``str`` nor ``ColumnExtractor``.
        """
        # TODO: check relationship/get foreign keys?
        columns = copy.deepcopy(self.default_columns)
        # One extractor per row is enough: it depends only on self._columns.
        positional_extractor = YqlColumnExtractor(self._columns) if isinstance(yt_row, (tuple, list)) else None
        for key, value in self.columns_mapper.items():
            if isinstance(value, str):
                if positional_extractor is not None:
                    columns[key] = positional_extractor(yt_row=yt_row, loader=self, column=value)
                else:
                    columns[key] = DictExtractor(value)(yt_row, self)
            elif isinstance(value, ColumnExtractor):
                columns[key] = value(yt_row, self)
            else:
                raise CriticalRowLoadException(f'Can\'t extract with: {value}')
        return [columns]

    def _preprocess_yt_row(self, yt_row):
        """Validate a row before it is mapped.

        :raises ExpectedColumnNotFound: if a name from ``expected_columns`` is not ``in`` the row.
        """
        # NOTE(review): for positional (tuple/list) rows this tests row values, not column names.
        for value in self.expected_columns:
            if value not in yt_row:
                raise ExpectedColumnNotFound(f'No {value} in row: {yt_row}')

    def _get_loader_tag(self) -> str:
        """Tag identifying this loader in ``YtSync``; defaults to the class name."""
        return self.__class__.__name__

    def _get_sync_path(self) -> str:
        """Return the table path as stored in ``YtSync.path``: ``table_path`` without the leading ``//``."""
        if self.table_path.startswith('//'):
            return self.table_path[2:]
        return self.table_path

    def _sync_after_execution(self):
        """Create or update the ``YtSync`` record of (table path, loader tag) with ``updated_at``.

        Best-effort: failures are logged as warnings and not raised.
        """
        @async_to_sync
        async def _update_table_path():
            table_path = self._get_sync_path()
            tag = self._get_loader_tag()
            async with self.db_bind:
                row = await YtSync.query.where(
                    and_(
                        YtSync.path == table_path,
                        YtSync.tag == tag
                    )
                ).gino.first()
                if row is None:
                    await YtSync.create(path=table_path, tag=tag, last_used_update=self.updated_at)
                else:
                    await row.update(last_used_update=self.updated_at,
                                     updated_at=datetime.datetime.now(tz=datetime.timezone.utc)).apply()

        try:
            _update_table_path()
        except Exception as ex:
            LOGGER.warning('Couldn\'t update record about sync: %s', ex)

    def check_if_should_execute(self) -> bool:
        """Decide from the ``YtSync`` record whether the table has to be loaded.

        :returns: ``True`` when no record exists for (table path, loader tag); ``False`` when the
            record is disabled; otherwise whether the table's ``updated_at`` differs from the
            record's ``last_used_update``.
        """
        @async_to_sync
        async def _check_if_should_sync():
            table_path = self._get_sync_path()
            async with self.db_bind:
                row: YtSync = await YtSync.query.where(
                    and_(YtSync.path == table_path, YtSync.tag == self._get_loader_tag())
                ).gino.first()

                # there is no record with provided table path, so we should add it
                if row is None:
                    return True
                if not row.enabled:
                    return False

                LOGGER.debug('Make check for %s with last_updated_at: %s vs %s',
                             self.table_path, self.updated_at, row.last_used_update)
                return row.last_used_update != self.updated_at

        return _check_if_should_sync()

    def _on_exception(self):
        """Hook called by ``load`` when the load is aborted by an error."""
        pass

    def _after_creation(self):
        """Hook called by ``load`` after new rows were created, before the sync record is updated."""
        pass

    def _check_if_should_skip_row(self, yt_row) -> bool:
        """Return ``True`` to ignore ``yt_row`` entirely (it is not counted in the progress log either)."""
        return False

    def load(self, **kwargs) -> typing.List[BaseModel]:
        """Load the table into the model and return the created instances.

        Returns ``[]`` without reading when ``check_if_should_execute`` says the table is unchanged,
        unless ``force_load`` is set. Rows raising a ``BaseRowLoadException`` are skipped. A
        ``CriticalRowLoadException`` or any other error stops the load: nothing is created and the
        ``YtSync`` record is not updated. ``is_loaded_last_time`` is ``True`` only after a
        successful load.

        :param kwargs: forwarded to ``_read_table`` when the table has not been opened yet.
        """
        if not self.force_load and not self.check_if_should_execute():
            self.is_loaded_last_time = False

            LOGGER.debug('Skip loading')
            return []

        # Reset up front so a failed run does not keep the flag of an earlier successful one.
        self.is_loaded_last_time = False
        if self._table is None:
            self._read_table(**kwargs)
        success = True
        new_db_rows = []
        duplicates = []
        self._before_start()
        cnt = 0
        for row in self._table:
            if self._check_if_should_skip_row(row):
                continue

            cnt += 1
            if cnt % 1000 == 0:
                LOGGER.info('Processed %s rows', cnt - 1)
            try:
                # TODO: add possibility to bulk insert by chunks during processing?
                self._preprocess_yt_row(row)
                db_duplicate = self._find_duplicate(row)
                if db_duplicate is None:
                    new_db_rows.extend(self._process_new(row))
                else:
                    duplicates.append((row, db_duplicate))
                    self._process_duplicate(row, db_duplicate)
                if self.limit is not None and len(new_db_rows) >= self.limit:
                    break
            except Exception as ex:  # noqa
                # CriticalRowLoadException subclasses BaseRowLoadException, so it has to be excluded
                # explicitly; otherwise critical errors would be treated as skippable rows.
                if isinstance(ex, BaseRowLoadException) and not isinstance(ex, CriticalRowLoadException):
                    LOGGER.warning('Skip row because: %s', ex.message)
                    continue
                LOGGER.exception('Something went wrong, so sync was stopped and can be broken: %s', ex)
                new_db_rows = []  # don't create anything if critical exception occurred
                success = False
                self._on_exception()
                break
        res = []

        if success:
            self._after_finish()
            if self.process_duplicates_after:
                self._process_duplicates(duplicates)
            res = list(self._create_db_rows(new_db_rows))
            self.is_loaded_last_time = True
            LOGGER.info(
                'Finished creation for loader %s',
                self._get_loader_tag()
            )
            self._after_creation()
            self._sync_after_execution()
        return res
