# -*- coding: utf-8 -*-

import pickle
import datetime
import os
import threading
import traceback

from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log

from intranet.yandex_directory.src.yandex_directory.common import json
from intranet.yandex_directory.src.yandex_directory.common.db import get_main_connection
from intranet.yandex_directory.src.yandex_directory.common.utils import (
    utcnow,
    app_context,
    log_exception,
)
from intranet.yandex_directory.src.yandex_directory.core.models import TaskModel
from intranet.yandex_directory.src.yandex_directory.core.utils import deferred
from intranet.yandex_directory.src.yandex_directory.core.task_queue.base import (
    TASK_STATES,
    TERMINATE_STATES,
    get_default_queue,
    TaskType,
    load_tasks,
)
from intranet.yandex_directory.src.yandex_directory.core.task_queue.exceptions import (
    Defer,
    Suspend,
)

load_tasks()

__all__ = ['Worker']


class TaskProcessor(object):
    def __init__(self, task_cls, task_data):
        self.task_cls = task_cls
        self.task_data = task_data

        self.task_id = task_data['id']
        if isinstance(task_data['params'], str):
            # данные из БД
            params = json.loads(self.task_data['params'])
        else:
            # для работы синхронного режима
            params = self.task_data['params']

        self.metadata = self.task_data['metadata']
        self.params = params
        self.ycrid = task_data['ycrid']
        # Значение будет установлено методом process
        self.task = None

    def process_do(self):
        """
        Выполняем задачу.
        """
        # если случится исключение в методе task.do
        # то этот случай обработается в process_do_error.
        try:
            result = self.task.do(**self.params)
        except TypeError as exc:
            message = str(exc)
            if 'got an unexpected keyword argument' in message:
                parameter = message.rsplit(' ', 1)[-1].strip("'")
                raise RuntimeError('Unexpected parameter "{0}" in task class {1}'.format(
                    parameter,
                    self.task.__class__
                ))
            else:
                raise
        return {
            'tries': self.task_data['tries'] + 1,
            'state': TASK_STATES.success,
            'finished_at': utcnow(),
            'result': pickle.dumps(result, protocol=0).decode('utf-8'),
        }

    def process_do_error(self, ex):
        """
        Обработчик ошибки возникшей при выполнении задачи.

        ВНИМАНИЕ: этот метод не должен ничего писать в базу.

        :param ex: исключение
        """
        log.warning('Processing task error')
        data = {}
        current_try = self.task_data['tries'] + 1

        try:
            data['exception'] = pickle.dumps(ex, protocol=0).decode('utf-8')
        except Exception:
            data['exception'] = None

        data['traceback'] = traceback.format_exc()

        data['tries'] = current_try
        if current_try < self.task_cls.tries:
            # есть еще попытки на выполнение задачи
            data['start_at'] = utcnow() + datetime.timedelta(seconds=self.task_cls.tries_delay)

        if current_try == self.task_cls.tries and not self.task_cls.need_rollback:
            # все попытки исчерпаны
            data['state'] = TASK_STATES.failed
            data['finished_at'] = utcnow()
            log.info('All attempts to execute task are exhausted')
        return data

    def process_rollback(self):
        """
        Откат задачи
        """
        data = {
            'rollback_tries': self.task_data['rollback_tries'] + 1,
            'state': TASK_STATES.rollback,
            'finished_at': utcnow(),
        }
        self.task.rollback(**self.params)
        return data

    def process_rollback_error(self, ex):
        """
        Обработчик ошибки возникшей при откате задачи/

        ВНИМАНИЕ: этот метод не должен ничего писать в базу.

        :param ex:
        """
        log.error('Processing task rollback error')
        data = {}
        current_rollback_tries = self.task_data['rollback_tries'] + 1
        data['rollback_tries'] = current_rollback_tries
        if current_rollback_tries == self.task_cls.rollback_tries:
            data['state'] = TASK_STATES.rollback_failed
            data['finished_at'] = utcnow()
        return data

    def get_process_func(self):
        """
        Выбираем функцию выполнения/отката задачи
        И функцию обработчик ошибок выполнения/отката задачи
        :rtype: tuple(func, func)
        """
        current_rollback_try = self.task_data['rollback_tries'] + 1
        current_try = self.task_data['tries'] + 1
        if current_try <= self.task_cls.tries:
            return self.process_do, self.process_do_error
        elif self.task_cls.need_rollback and current_rollback_try <= self.task_cls.rollback_tries:
            return self.process_rollback, self.process_rollback_error

    def process(self, main_connection):
        """
        Выполняем/откатываем выбранную задачу
        """
        with self.task_cls.log_task_params(self.task_id, self.params), \
             app_context():

            log.info('Trying to process task')

            self.task = self.task_cls(
                main_connection,
                queue=self.task_data['queue'],
                task_id=self.task_id,
                task_data=self.task_data,
            )

            task_info = {
                'worker': None,
                'locked_at': None,
                'free_lock_at': None,
            }
            process_error_func = lambda *args: {}
            try:
                # выберем функции в зависимости от состояния задачи
                # обработка или откат
                process_func, process_error_func = self.get_process_func()

                # Здесь мы стартуем подтранзакцию, чтобы в случае ошибки,
                # откатились только те изменения, которые таск сделал внутри
                # метода do. Раньше мы транзакцию не стартовали, и эти
                # изменения успешно коммитились в блоке finally, вместе
                # с обновлением данных по самому таску.
                #
                # Если интересно, то вот тут про это можно почитать:
                # https://st.yandex-team.ru/DIR-6906

                # После выполнения или отката транзакции мы применим
                # отложенные операции над метаданными так,
                # чтобы они сохранились даже в том случае, если при обработке
                # таска вылетело исключение
                with deferred.calls_at_the_end(), \
                     main_connection.begin_nested():
                    try:
                        result = process_func()
                        task_info.update(result)
                        log.debug('Task successfully executed')
                    except Defer as ex:
                        task_info.update({
                            'start_at': ex.when,
                            'state': TASK_STATES.suspended if ex.suspended else TASK_STATES.free,
                        })
                        with log.fields(retry_datetime=ex.when):
                            log.debug('Task deferred')
                    except Suspend:
                        task_info.update({
                            'state': TASK_STATES.suspended,
                        })
                        log.debug('Task suspended')
            except Exception as ex:
                log_exception(ex, 'Task error occurred')

                # Тут начинать подтранзакцию не надо, потому что функции,
                # обрабатывающие ошибки, ничего в базу не пишут
                result = process_error_func(ex)
                task_info.update(result)
            finally:
                TaskModel(main_connection).update_one(self.task_id, task_info)

                if task_info.get('state', None) in TERMINATE_STATES:
                    self.task.on_terminate(**self.params)

                    for dependent in self.task.get_dependents():
                        if task_info['state'] == TASK_STATES.success:
                            dependent.on_dependency_success()
                        else:
                            dependent.on_dependency_fail(task_info)


class Worker(object):
    def __init__(self, main_connection, queue=None):
        self.queue = queue or get_default_queue()

    @staticmethod
    def get_worker_name():
        """
        Генерируем имя для обработчика задач
        :rtype: str
        """
        return '{}:{}:{}'.format(
            os.environ.get('DEPLOY_NODE_FQDN', os.environ.get('QLOUD_HOSTNAME', 'unknown')),
            os.getpid(),
            threading.currentThread().ident,
        )

    def lock_task(self, shard):
        """
        Захватываем задачу для выполнения воркером
        :rtype: TaskProcessor
        """
        task_processor = None

        # Лочить таск надо с помощью отдельной транзакции, чтобы она закоммитилась
        # и остальные воркеры увидели, что таск занят
        with get_main_connection(shard=shard, for_write=True) as separate_main_connection:
            task_model = TaskModel(separate_main_connection)
            task_data = task_model.lock_for_worker(self.get_worker_name(), self.queue)
            if task_data:
                task_name = task_data['task_name']
                task_id = task_data['id']
                with log.fields(task_id=task_id, task_name=task_name):
                    task_cls = TaskType.task_types.get(task_name)
                    if not task_cls:
                        log.warning('No registered task class for task')
                        task_model.release_task(task_id)
                    else:
                        log.debug('Task locked')
                        task_processor = TaskProcessor(task_cls, task_data)
        return task_processor
