# -*- coding: utf-8 -*-
import pickle
import importlib
import contextlib
import uuid

# Тут нужен именно стандартный json, чтобы проверять, что psycopg сможет
# сериализовать параметры таска.
import json as standard_json

from datetime import timedelta
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log
from intranet.yandex_directory.src.yandex_directory.common.datatools import flatten_lists

from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.core.utils import deferred
from intranet.yandex_directory.src.yandex_directory.common.utils import (
    utcnow,
    remove_sensitive_data,
    hide_sensitive_params,
    get_exponential_step,
)
from intranet.yandex_directory.src.yandex_directory.core.models import TaskModel, TaskRelationsModel
from intranet.yandex_directory.src.yandex_directory.core.models.task import DEPENDENCIES_STATE
from intranet.yandex_directory.src.yandex_directory.core.task_queue.exceptions import (
    UnknownTask,
    TaskFailed,
    TaskInProgress,
    DuplicatedTask,
    DependencyCreationError,
    Defer,
    Suspend,
)
from intranet.yandex_directory.src.yandex_directory.core.utils import only_attrs
from intranet.yandex_directory.src.yandex_directory.core.utils.ycrid_manager import ycrid_manager


def queue_name_for_current_env(name):
    """
    Генерируем имя очереди вида 'ENVIRONMENT-name'.
    :param name: имя очереди
    :rtype: str
    """
    return '{}-{}'.format(app.config['ENVIRONMENT'], name)

_DEFAULT_QUEUE = None


def get_default_queue():
    if _DEFAULT_QUEUE is None:
        raise RuntimeError('Please, call setup_task_queue')
    return _DEFAULT_QUEUE


def setup_task_queue(app):
    global _DEFAULT_QUEUE
    name = app.config['DEFAULT_QUEUE_NAME']
    _DEFAULT_QUEUE = queue_name_for_current_env(name)


class TASK_STATES:
    free = 'free'
    in_progress = 'in-progress'
    suspended = 'suspended'

    failed = 'failed'
    rollback = 'rollback'
    rollback_failed = 'rollback-failed'
    success = 'success'
    canceled = 'canceled'


TERMINATE_STATES = [
    TASK_STATES.failed,
    TASK_STATES.rollback,
    TASK_STATES.success,
    TASK_STATES.rollback_failed,
    TASK_STATES.canceled,
]

ACTIVE_STATES = [
    TASK_STATES.free,
    TASK_STATES.in_progress,
]


def load_tasks():
    """
    Регистрируем доступные задачи
    """
    for module in app.TASK_MODULES:
        with log.fields(module=module):
            try:
                importlib.import_module(module)
                log.debug('Task from module registered')
            except ImportError:
                log.error('Cant import module')


class TaskType(type):
    """
    Мета класс для автоматической регистрации задач
    в реестре задач
    """
    task_types = {}

    def __new__(cls, name, bases, *args, **kwargs):
        task_class = super(TaskType, cls).__new__(cls, name, bases, *args, **kwargs)
        task_name = task_class.get_task_name()
        # Базовый класс таска не надо добавлять в список
        if not task_name.endswith('.Task'):
            cls.task_types[task_name] = task_class
        return task_class


class AsyncResult(object):
    def __init__(self, main_connection, task_id):
        self.main_connection = main_connection
        self._task_id = task_id

        self.task_model = TaskModel(self.main_connection)
        if not self.task_model.count({'id': task_id}):
            raise UnknownTask(task_id)

    def _get_task_data(self, force=False):
        if force or not hasattr(self, '__task_data'):
            self.__task_data = self.task_model.get(self.task_id)
        return self.__task_data

    @property
    def task_id(self):
        return self._task_id

    @property
    def state(self):
        return self._get_task_data().get('state')

    @property
    def exception(self):
        exception = self._get_task_data().get('exception', '')
        if exception:
            exception = exception.encode('utf-8')
            return pickle.loads(exception)

    def get_metadata(self):
        return self._get_task_data().get('metadata')

    def get_result(self):
        data = self._get_task_data()
        state = data['state']
        if state == TASK_STATES.success:
            return pickle.loads(data['result'].encode('utf-8'))
        if state in (TASK_STATES.failed, TASK_STATES.rollback_failed):
            exception = data.get('exception')
            ex = None
            if exception:
                exception = exception.encode('utf-8')
                ex = pickle.loads(exception)
            if ex:
                raise ex
            raise TaskFailed(self.task_id)
        raise TaskInProgress(self.task_id)

    def __repr__(self):
        return 'AsyncResult ({})'.format(self.task_id)


class SyncResult(AsyncResult):
    """
    Результат выполнения таска при синхронном режиме работы
    """
    def __init__(self, main_connection, task_id):
        super(SyncResult, self).__init__(main_connection, task_id)
        self._process_task()

    def _process_task(self):
        # задачи выполняются без задержки между попытками
        # синхронный режим нужен для тестов
        # в них можно не ждать

        from intranet.yandex_directory.src.yandex_directory.core.task_queue.worker import TaskProcessor

        with log.fields(task_id=self.task_id):
            log.debug('Task run is sync mode')

            attempt = 1
            while True:

                # На всякий случай встроим защиту от зацикливания
                if attempt == 100:
                    raise RuntimeError('Too many attempts')

                attempt += 1

                # Важно на каждой итерации получать новое состояние таска,
                # потому что на предыдущей итерации процессор мог поменять его
                # состояние и количество оставшихся попыток
                task = self.task_model.get(self.task_id)
                task_cls = TaskType.task_types.get(task['task_name'])

                # Сначала надо проверить, не выполнен ли уже этот таск
                if task['state'] in TERMINATE_STATES:
                    break

                # Далее, надо выполнить все зависимости таска
                dependencies = TaskModel(self.main_connection).get_dependencies(self.task_id)
                for dependency in dependencies:
                    if dependency['state'] not in TERMINATE_STATES:
                        SyncResult(self.main_connection, dependency['id'])

                # Теперь важно снова получить состояние таска из базы, так как
                # на него могло повлиять выполнение / или невыполнение зависимостей
                task = self.task_model.get(self.task_id)

                processor = TaskProcessor(task_cls, task)
                processor.process(
                    self.main_connection,
                )

                # Если таск был отложен во времени с помощью defer, то время следующей попытки
                # скорее всего будет больше чем в том случае, если бы тас отложили из-за ошибки.
                # В этом случае даже после SyncResult таск останется в статусе free.
                # Для для того, чтобы его выполнить до конца, надо будет снова вызвать SyncResult.
                next_try_at = utcnow() + timedelta(seconds=task_cls.tries_delay)
                task = self.task_model.get(self.task_id)
                if task['start_at'] > next_try_at:
                    break

    def __repr__(self):
        return 'SyncResult ({})'.format(self.task_id)


class OrgIDIsRequired(RuntimeError):
    def __init__(self, task):
        super(OrgIDIsRequired, self).__init__(
            'Org_id is required for tasks of class {}'.format(
                task.__class__.__name__
            )
        )


class Task(object, metaclass=TaskType):
    tries = 3  # количество попыток выполнения задачи
    tries_delay = 60  # задержка между выполнением (секунды)
    lock_ttl = tries_delay * (tries + 1)  # время, на которое воркер может залочить таск
    need_rollback = False  # нужно ли откатывать задачу и вызывать метод rollback
    rollback_tries = 3  # количество попыток отката задачи
    singleton = True  # задача не в статусе TERMINATE_STATES должна существовать в единственном экземпляре в очереди
    # Большинство наших тасков должно быть связано с организацией.
    # В случаях, когда это не так, надо явно указать в классе таска.
    # Если этот атрибут True, а в delay метод не передан org_id,
    # то вместо запуска таска будет брошено исключение OrgIDIsRequired.
    # Это позволит избежать ошибок, подобной:
    # https://st.yandex-team.ru/DIR-6674
    org_id_is_required = True
    # Для задач по миграции ящиков, мы назначаем приоритеты
    # в дипазоне от 0 до 100. Чтобы они при этом не тормозили
    # остальные таски, сделаем так, чтобы по умолчанию таски
    # получали больший приоритет.
    # https://st.yandex-team.ru/DIR-4224
    # Приоритет по-умолчанию можно переопределить на уровне класса задчачи.
    default_priority = 100
    sensitive_params = []
    singleton_params = []

    def __init__(self, main_connection, queue=None, task_id=None, parent_task_id=None,
                 depends_on=None, priority=None, task_data=None):
        """
        :param depends_on: список uuid задач, от которых зависит создаваемая задача.
        :type depends_on: list of uuid
        """
        self.queue = queue or get_default_queue()
        self.main_connection = main_connection
        self.task_id = task_id
        self.parent_task_id = parent_task_id
        self.priority = priority or self.default_priority
        self.depends_on = depends_on
        # Здесь будут такие параметры, как started_at и тд.
        # Не знаю почему мы не сделали так сразу, но пока для
        # обратной совместимости сохраняем эти параметры
        # отдельным атрибутом.
        self.task_data = task_data or {}

    def _get_base_filter(self):
        return {
            'state__notequal': TERMINATE_STATES,
            'task_name': self.get_task_name(),
            'queue': self.queue,
        }

    def _get_duplicate_task_id(self, kwargs):
        task_model = TaskModel(self.main_connection)
        filter_data = self._get_base_filter()

        filter_params = {}

        search_keys = self.singleton_params if self.singleton_params else list(kwargs.keys())
        # если определены singleton_params или sensitive_params, то ищем неполное сравнение params__contains
        # иначе сравнивание все параметры
        search_field = 'params__contains' if self.singleton_params or self.sensitive_params else 'params'

        # убираем приватные ключи из фильтра
        # как правило это пароли, фильтровать по ним не будем, там могут быть спец символы
        for key in set(search_keys) - set(self.sensitive_params):
            filter_params[key] = kwargs[key]

        filter_data[search_field] = filter_params

        try:
            task = task_model.filter(**filter_data).one()
        except TypeError:
            with log.fields(filter_data=filter_data):
                log.trace().error('Unable to check if duplicate task exists')
            raise
        if task:
            return task['id']
        return None

    def delay(self, **kwargs):
        """
        Запускаем отложенную задачу
        :rtype: AsyncResult
        """
        if self.org_id_is_required and 'org_id' not in kwargs:
            raise OrgIDIsRequired(self)

        with log.fields(queue=self.queue, task_name=self.get_task_name()):
            log.debug('Trying to create task')
            start_in = timedelta()
            if 'start_in' in kwargs:
                start_in = kwargs['start_in']
                del kwargs['start_in']      # потому что json не обрабатывает timedelta

            metadata = None
            if 'metadata' in kwargs:
                metadata = kwargs['metadata']
                del kwargs['metadata']      # мета данные записываем отдельным параметром

            try:
                standard_json.dumps(kwargs)
            except TypeError:
                raise RuntimeError('Some task parameters are not serializable to JSON')

            if self.singleton:
                duplicate_task_id = self._get_duplicate_task_id(kwargs)
                if duplicate_task_id:
                    raise DuplicatedTask(duplicate_task_id)

            return self.place_into_the_queue(start_in, metadata, **kwargs)

    def get_params(self):
        return TaskModel(self.main_connection).get_params(self.task_id)

    def place_into_the_queue(self, start_in, metadata, **params):
        """Этот метод раньше был частью delay, но был выделен, чтобы в тестах можно было
           замокать само добавление таска в очередь, оставив при этом функционал по проверке
           аргументов.
        """
        task_model = TaskModel(self.main_connection)
        task_name = self.get_task_name()
        task = task_model.create(
            task_name=task_name,
            params=params,
            queue=self.queue,
            ttl=self.lock_ttl,
            start_in=start_in,
            priority=self.priority,
            parent_task_id=self.parent_task_id,
            depends_on=self.depends_on,
            metadata=metadata,
            ycrid=ycrid_manager.get()[0]
        )
        self.task_id = task['id']

        if self.depends_on:
            self.suspend()
            TaskModel(self.main_connection).increment_dependencies_count(
                self.task_id,
                DEPENDENCIES_STATE.new,
                len(self.depends_on),
            )

        with self.log_task_params(self.task_id, params):
            log.debug('Task was created')

        return self._get_task_result(self.task_id)

    @classmethod
    @contextlib.contextmanager
    def log_task_params(cls, task_id, task_params):
        task_params = cls.clean_log_data(**task_params)
        task_name = cls.get_task_name()
        log_params = dict(
            task_id=task_id,
            task_name=task_name,
            task_params=task_params,
        )
        for param_name in ('org_id', 'domain', 'user_id'):
            if param_name in task_params:
                log_params[param_name] = task_params[param_name]

        with log.fields(**log_params):
            yield

    @classmethod
    def clean_log_data(cls, **kwargs):
        # Скрываем приватные данные при логировании и из параметров в БД
        if cls.sensitive_params:
            return remove_sensitive_data(kwargs, cls.sensitive_params)
        else:
            return kwargs

    @classmethod
    def get_task_name(cls):
        return '{}.{}'.format(cls.__module__, cls.__name__)

    def get_metadata(self):
        task = TaskModel(self.main_connection).get(self.task_id, fields=['metadata'])
        if task:
            return task['metadata']

    def set_metadata(self, metadata):
        deferred.add_call(self.real_set_metadata, metadata)

    def update_metadata(self, **kwargs):
        deferred.add_call(self.real_update_metadata, **kwargs)

    def real_set_metadata(self, metadata):
        TaskModel(self.main_connection).save_metadata(self.task_id, metadata)

    def real_update_metadata(self, **kwargs):
        metadata = self.get_metadata()
        if not metadata:
            metadata = kwargs
        elif isinstance(metadata, dict):
            metadata.update(kwargs)
        else:
            raise TypeError('metadata should be a dict')
        self.set_metadata(metadata)

    def suspend(self):
        """
        Приостановим задачу
        """
        if TaskModel(self.main_connection).filter(id=self.task_id, state__notequal=TERMINATE_STATES).count():
            TaskModel(self.main_connection).update_one(
                task_id=self.task_id,
                update_data={'state': TASK_STATES.suspended}
            )

    def resume(self):
        """
        Возобновим приостановленую задачу и перведём её в состояние free
        """
        TaskModel(self.main_connection).update(
            update_data={
                'state': TASK_STATES.free,
                'worker': None,
                'locked_at': None,
                'free_lock_at': None,
            },
            filter_data={
                'id': self.task_id,
                'state': TASK_STATES.suspended,
            }
        )

    def do(self, **kwargs):
        raise NotImplementedError

    def rollback(self, **kwargs):
        pass

    def get_child_tasks_results(self):
        child_task_ids = TaskModel(self.main_connection).filter(parent_task_id=self.task_id).fields('id').scalar()
        return list(map(self._get_task_result, child_task_ids))

    def _get_task_result(self, task_id):
        return AsyncResult(self.main_connection, task_id)

    def wait_for(self, *tasks):
        """
        Добавляем зависимости в таблицу tasks_relations
        Инкрементируем dependencies_count.

        На вход можно передавать как таски, так и их id,
        при чём, как отдельно, так и в виде списка или списков.

        ВНИМАНИЕ: после вызова этого метода таск который его вызвал, переходит в спящее состояние.
        """
        def ensure_id(task):
            if isinstance(task, (int, uuid.UUID)):
                return task
            else:
                return task.task_id

        try:
            task_ids = list(map(ensure_id, flatten_lists(tasks)))
            TaskRelationsModel(self.main_connection).bulk_create(
                [{'task_id': self.task_id, 'dependency_task_id': dep_id} for dep_id in task_ids])
            TaskModel(self.main_connection).increment_dependencies_count(
                self.task_id,
                DEPENDENCIES_STATE.new,
                len(task_ids),
            )
        except Exception as e:
            raise DependencyCreationError(e)

        if task_ids:
            log.info('Waiting for dependencies')
            raise Suspend()
        else:
            # Это странная ситуация которая может возникнуть только в случае нарушения логики.
            # Потому что зачем вызывать wait_for, если тасков которые нужно подождать - нет?
            raise RuntimeError('No tasks to wait')

    def get_dependencies(self, task_id):
        """
        Получаем все зависимости задачи
        """
        dependencies = only_attrs(
            TaskRelationsModel(self.main_connection) \
                .filter(task_id=task_id) \
                .fields('dependency_task_id'),
            'dependency_task_id'
        )
        return list(map(self._get_task_result, dependencies))

    def on_dependency_success(self):
        """
        Инкрементируем количество тасок, завершенных со статусом success
        Если нужно другое поведение - переопределить в подклассе конкретно таски.

        ВНИМАНИЕ: В любом случае, не пытайтесь делать какую-то полезную работу в этом каллбэке.
        Всё, что он должен делать, это изменять состояние зависимого таска с помощью одного из методов:

        * resume - чтобы продолжить выполнение зависимого таска.
        * cancel - чтобы отменить его.

        Таким образом можно управлять тем, как таск реагирует на ошибки.

        По умолчанию, таск будет просыпаться и продолжать свою работу
        только после того, как все зависимости завершились успешно или с ошибкой.
        """
        TaskModel(self.main_connection).increment_dependencies_count(
            self.task_id,
            DEPENDENCIES_STATE.successful,
        )
        if self.is_all_dependencies_completed():
            self.resume()

    def on_dependency_fail(self, task_info):
        """
        Инкрементируем количество тасок, завершенных не со статусом success.

        Более подробное описание того, когда этот метод можно переопределять,
        написано в докстринге on_dependency_success.
        """
        TaskModel(self.main_connection).increment_dependencies_count(
            self.task_id,
            DEPENDENCIES_STATE.failed,
        )
        if self.is_all_dependencies_completed():
            self.resume()

    def on_terminate(self, **kwargs):
        self.clean_private_data_in_tasks()

    def cancel(self):
        """
        Проставляет задаче статус canceled
        и finished_at, чтобы воркер больше не брал эту задачу
        """
        TaskModel(self.main_connection).update(
            update_data={
                'state': TASK_STATES.canceled,
                'finished_at': utcnow(),
            },
            filter_data={
                'id': self.task_id,
            }
        )
        for dependent in self.get_dependents():
            dependent.cancel()

        self.clean_private_data_in_tasks(with_dependent=False)

    def get_age(self):
        """Возвращает количество секунд, прошедших с момента старта."""
        started_at = self.task_data['start_at']
        return (utcnow() - started_at).total_seconds()

    def exponential_defer(self, min_interval, max_interval, const_time):
        """Откладывает таск так, что время засыпания каждый раз увеличивается по экспоненте
           от min_interval до max_interval, но по прошествии const_time интервал
           всегда будет равным max_interval.
        """
        time_since_start = self.get_age()
        countdown = get_exponential_step(
            time_since_start,
            min_interval,
            max_interval,
            const_time,
        )
        self.defer(countdown=countdown)

    def defer(self, retry_at=None, countdown=None, suspended=False):
        """Откладывает таск так, либо до определенного момента времени, либо на определённое количество секунд.
        """
        raise Defer(retry_at=retry_at, countdown=countdown, suspended=False)

    def get_dependents(self):
        result = []
        for dependent in TaskModel(self.main_connection)._get_dependents(self.task_id, fields=('task_name', 'id')):
            task_name = dependent['task_name']
            dependent_cls = TaskType.task_types.get(task_name)
            if not dependent_cls:
                with log.fields(dependent_task_name=task_name, dependent_task_id=dependent['id']):
                    raise RuntimeError('No registered task class for task')
            result.append(dependent_cls(self.main_connection, task_id=dependent['id']))
        return result

    def is_all_dependencies_completed(self):
        """
        Возвращает True, когда все задачи от которых зависит текущая, завершились
        """
        task = TaskModel(self.main_connection).get(self.task_id)
        return task['dependencies_count'] == \
               task['successful_dependencies_count'] + task['failed_dependencies_count']

    def has_failed_dependencies(self):
        return self.task_data['failed_dependencies_count'] > 0

    def is_all_dependents_completed(self):
        """
        Возвращает True, когда все задачи которые зависят от текущей, завершились
        :param task_id:
        :return: bool
        """
        dependents = TaskModel(self.main_connection)._get_dependents(self.task_id, fields=('state',))
        for i in dependents:
            if i.get('state', None) not in TERMINATE_STATES:
                return False
        return True

    def clean_private_data_in_tasks(self, task=None, with_dependent=True):
        """
        Отчищаем пароли в данных таска и всех тасков от которых он зависит, если
        все зависимые выполнены
        :param task:
        """
        task_model = TaskModel(self.main_connection)

        if not task:
            task = task_model.filter(id=self.task_id).one()
        if task.get('state', None) in TERMINATE_STATES and self.is_all_dependents_completed():
            meta = self.get_metadata()
            if hide_sensitive_params(meta):
                with log.name_and_fields('security', task_id=self.task_id):
                    log.info('Password was removed from task metadata')
                self.set_metadata(meta)

            params = task_model.get_params(self.task_id)
            if hide_sensitive_params(params):
                with log.name_and_fields('security', task_id=self.task_id):
                    log.info('Password was removed from task parameters')
                    task_model.save_params(self.task_id, params)

            if with_dependent:
                dependents = TaskModel(self.main_connection).get_dependencies(self.task_id)
                for i in dependents:
                    task = Task(self.main_connection, task_id=i['id'])
                    task.clean_private_data_in_tasks(task=i)


def get_short_name(task_name):
    """Убирает из длинного имени таска модуль и взвращает только имя класса.
    """
    return task_name.rsplit('.', 1)[-1]


def get_task_names_map():
    """Возвращает map короткого имени таска в полное."""
    return {
        get_short_name(full_name): full_name
        for full_name in list(TaskType.task_types.keys())
    }
