# -*- coding: utf-8 -*-

from __future__ import absolute_import

import traceback
import billiard
import hashlib
import logging
import resource
import sys

from billiard import current_process
from celery import Celery, Task
from celery.worker.job import Request
from celery.signals import (import_modules, task_prerun, before_task_publish, worker_ready, worker_process_init,
                            task_postrun)
from copy import copy, deepcopy
from pprint import pformat
from time import time

import mpfs.engine.process
import mpfs.engine.queue2.queue_log

from mpfs.common.errors import QueueNoAsyncTaskDataError, UserBlocked
from mpfs.common.static.tags.queue import QueueNames
from mpfs.common.util import logger
from mpfs.common.util.experiments.logic import experiment_manager
from mpfs.core.zookeeper.shortcuts import prepare_zookeeper_settings
from mpfs.engine.queue2.celery_custom import Consumer, TaskPool
from mpfs.engine.queue2.logging import setup_consumer_logging, setup_producer_logging
from mpfs.engine.queue2.request_patch import get_on_failure_patch
from mpfs.config import settings
from mpfs.engine.queue2.utils import get_task_func_name
from mpfs.metastorage.mongo.util import decompress_data, compress_data

LOGGER_MONITOR_ENABLED = settings.logger['monitor']['enabled']
QUEUE2_TASKS_MAX_RETRY = settings.queue2['tasks']['max_retry']
QUEUE2_TASKS_WHITELIST = settings.queue2['tasks']['whitelist']
QUEUE2_PUT_TASKS_IN_NEW_QUEUE = settings.queue2['put_tasks_in_new_queue']
QUEUE2_WORKER_LIMITS_MAX_TASK_PER_CHILD = settings.queue2['worker']['limits']['max_tasks_per_child']
QUEUE2_WORKER_LIMITS_MAX_RSS = settings.queue2['worker']['limits']['max_rss']
QUEUE2_WORKER_LIMITS_MAX_LIFETIME = settings.queue2['worker']['limits']['max_lifetime']
QUEUE2_WORKER_ENABLE_ACKS_LATE_FIX = settings.queue2['worker']['enable_acks_late_fix']
SERVICES_TVM_2_0_ENABLED = settings.services['tvm_2_0']['enabled']


KWARGS_DEDUPLICATION_ID_PARAMETER = '__deduplication_id'


# тут суть вот в чем - если стартуем как mpfs, или скрипт, то все тут уже инициализировано (уже заранее выполнен setup)
# и надо просто получить список урлом брокеров, а вот если стартуем как воркер celery, то получим тут None, а
# список брокеров позже получаем и устанавливаем позже - на этапе импорта после установки логгера
broker_urls = mpfs.engine.process.rabbitmq_hosts(shuffle=False)
if broker_urls:
    setup_producer_logging()


app = Celery('mpfs',
             broker=broker_urls,
             include=['mpfs.core.job_handlers.operation',
                      'mpfs.core.job_handlers.operation_failing',
                      'mpfs.core.job_handlers.indexer',
                      'mpfs.core.job_handlers.routine',
                      'mpfs.core.job_handlers.trash_clean',
                      'mpfs.core.job_handlers.user',
                      'mpfs.core.job_handlers.push',
                      'mpfs.core.job_handlers.billing',
                      'mpfs.core.job_handlers.aviary',
                      'mpfs.core.job_handlers.notifier',
                      'mpfs.core.job_handlers.yateam',
                      'mpfs.core.job_handlers.fotki',
                      'mpfs.core.albums.job_handlers',
                      'mpfs.core.lenta.job_handlers',
                      'mpfs.core.organizations.job_handlers',
                      'mpfs.core.last_files.job_handlers',
                      'mpfs.core.versioning.job_handlers',
                      'mpfs.core.versioning.logic.cleaner',
                      'mpfs.core.filesystem.cleaner.hidden',
                      'mpfs.core.job_handlers.discount',
                      'mpfs.core.job_handlers.global_gallery',
                      'mpfs.engine.queue2.control',
                      'mpfs.core.job_handlers.docs',
                      'mpfs.core.inactive_users_flow.logic',
                      'mpfs.core.public_links.tasks'])

app.conf.CELERYD_CONSUMER = Consumer
app.conf.CELERYD_POOL = TaskPool
app.conf.CELERYD_PREFETCH_MULTIPLIER = 1
app.conf.CELERY_ACKS_LATE = True
app.conf.CELERY_DEFAULT_QUEUE = 'submit'
app.conf.CELERY_TASK_SERIALIZER = 'json'
app.conf.CELERY_SEND_EVENTS = False  # http://docs.celeryproject.org/en/3.1/configuration.html#celery-send-events

# Logging:
app.conf.CELERYD_LOG_LEVEL = logging.NOTSET
# Предотвращает перенаправление всего `stdout` в `celery.redirected`
app.conf.CELERY_REDIRECT_STDOUTS = False
app.conf.CELERY_REDIRECT_STDOUTS_LEVEL = logging.NOTSET
# Предотвращает удаление установленных хендлеров на `root` логгер
app.conf.CELERYD_HIJACK_ROOT_LOGGER = False

# http://docs.celeryproject.org/projects/kombu/en/latest/reference/kombu.connection.html#kombu.connection.Connection.failover_strategy
app.conf.BROKER_FAILOVER_STRATEGY = 'shuffle'

if QUEUE2_WORKER_LIMITS_MAX_TASK_PER_CHILD > 0:
    app.conf.CELERYD_MAX_TASKS_PER_CHILD = QUEUE2_WORKER_LIMITS_MAX_TASK_PER_CHILD


def get_current_worker_name():
    # type: () -> str
    """
    Функция возвращает имя celery worker'а, который выглядит как worker1@mpfs1g.disk.yandex.net
    Нужна при постановке задачи при старте в очередь started, завершении в очередь completed и при фейле
    Передается в контексте и queller (java) логирует это у себя, чтобы можно было понять, куда задача уехала
    выполняться - для этого надо будет просто сделать grep по ycrid задачи на queller (в данный момент это
    /var/log/yandex/disk/queller.log)
    """
    p = current_process()
    if p.name == 'MainProcess':
        return None
    return p.initargs[1]


class TaskData(object):
    def __init__(self, uid, data):
        if not isinstance(uid, basestring):
            raise TypeError('string type for uid expected, `%s` received' % type(uid))
        if not (data and isinstance(data, (list, dict))):
            raise ValueError('data field must be non-empty dict or list')

        self.uid = uid
        self.data = data


# этот класс должен быть строго после задания настроек app выше, т.к. он их уже использует при импорте
class BaseTask(Task):
    max_retries = QUEUE2_TASKS_MAX_RETRY  # должно быть None всегда, если мы работаем с queller
    # в development окружении переопределяется, чтобы таска не ретраилась вечно

    def apply_async(self, args=None, kwargs=None, **rest_kwargs):
        if kwargs and 'task_data' in kwargs:
            task_data = kwargs.pop('task_data')
            if not isinstance(task_data, TaskData):
                raise TypeError('`TaskData` object expected in task_data but `%s` received' % type(task_data))

            compressed_data = compress_data(task_data.data)

            from mpfs.engine.queue2.async_tasks.models import AsyncTasksData
            async_task_data = AsyncTasksData(uid=task_data.uid, data=compressed_data)
            async_task_data.save()

            kwargs['__task_data'] = {'uid': task_data.uid, 'db_id': async_task_data.id}

        return super(BaseTask, self).apply_async(args=args, kwargs=kwargs, **rest_kwargs)

    def __call__(self, *args, **kwargs):
        worker = get_current_worker_name()
        if worker:
            self._process_task_started(worker, kwargs)
        self._log_start()
        mpfs.engine.queue2.queue_log.OperationJobBinding.set_job(self)

        # Нужно на случай, если таск решит поменять агрументы, с которыми он вызывается
        # https://st.yandex-team.ru/CHEMODAN-30156
        kwargs = deepcopy(kwargs)
        # Эта копия для ретраев
        original_kwargs = deepcopy(kwargs)

        experiment_manager.update_context(uid=kwargs.get('uid'))

        try:
            from mpfs.core.user.common import CommonUser
            uid = kwargs.get('uid')
            if uid and self.name.rsplit('.', 1)[-1] not in QUEUE2_TASKS_WHITELIST:
                user = CommonUser(uid=uid)
                user.check_blocked()
            async_task_data = None
            if '__task_data' in kwargs:
                task_data = kwargs.pop('__task_data')
                uid, db_id = task_data['uid'], task_data['db_id']

                from mpfs.engine.queue2.async_tasks.controllers import AsyncTasksDataController
                async_task_data = AsyncTasksDataController().get(uid=uid, id=db_id)

                # Из RabbitMQ сообщение может прийти дважды, метод запустится дважды. Один может успеть обработать
                # сообщение и удалить данные, тогда во втором будет такая ошибка.
                if async_task_data is None:
                    raise QueueNoAsyncTaskDataError()

                kwargs['task_data'] = TaskData(uid=uid, data=decompress_data(async_task_data.data))

            res = super(BaseTask, self).__call__(*args, **kwargs)
        except UserBlocked as exc:
            stacktrace = traceback.format_exc()
            kwargs['context'].update({
                'error': unicode(exc.message).encode('utf-8'),
                'traceback': stacktrace,
                'worker': worker,
            })
            self.request.retries -= 1
            self.retry(kwargs=kwargs, throw=False, countdown=0, queue='completed')
            self._log_error(stacktrace, kwargs)
            res = None
            # Нельзя зафейлить таск без ретрая. Надо писать в лог инфу про блокировку и фейл, но заканчивать успешно
        except Exception as exc:
            if worker:
                self._process_task_failure(worker, exc, original_kwargs)
            raise
        else:
            if worker:
                self._process_task_success(worker, kwargs)

            if async_task_data:
                try:
                    async_task_data.delete()
                except Exception:
                    pass

        return res

    def _process_task_started(self, worker, kwargs):
        if not worker:
            return

        # выполняется не в главном потоке, а в дочернем (воркере)
        kwargs['context'].update({'started': int(time() * 1000),
                                  'worker': worker})

        self.request.retries -= 1
        self.retry(kwargs=kwargs, throw=False, countdown=0, queue='started')
        self.request.retries += 1

    def _process_task_success(self, worker, kwargs):
        if not worker:
            return

        kwargs['context'].update({'finished': int(time() * 1000),
                                  'worker': worker})

        # Task.retry делает retries = request.retries + 1
        # Мы не хотим увеличивать счётчик, поэтому делаем - 1
        self.request.retries -= 1
        self.retry(kwargs=kwargs, throw=False, countdown=0, queue='completed')
        self._log_success(kwargs)

    def _process_task_failure(self, worker, exc, kwargs):
        if not worker:
            return

        submit_queue_name = QueueNames.SUBMIT
        if get_task_func_name(self.name) in settings.queue2['secondary_submit']['task_func_names']:
            submit_queue_name = QueueNames.SECONDARY_SUBMIT

        if isinstance(exc, SilentTaskRetryException):
            # Task.retry делает retries = request.retries + 1
            # Мы не хотим увеличивать счётчик, поэтому делаем - 1
            self.request.retries -= 1
            self.retry(kwargs=kwargs, throw=False, countdown=exc.retry_delay, queue=submit_queue_name)
            self._log_success(kwargs)
        else:
            stacktrace = traceback.format_exc()
            kwargs['context'].update({
                'error': unicode(exc.message).encode('utf-8'),
                'traceback': stacktrace,
                'worker': worker,
            })
            self.retry(kwargs=kwargs, exc=exc, throw=False, countdown=0, queue=submit_queue_name)
            self._log_error(stacktrace, kwargs)

    def _log_start(self):
        log = mpfs.engine.process.get_default_log()
        log.info('Task %s started (try %d), name: %s' % (self.request.id, self.request.retries, self.name))

    def _log_state(self, context, status):
        process_time = None
        lifetime = None
        if context:
            finished = int(time() * 1000)
            process_time = (finished - context['started']) / 1000.0
            lifetime = (finished - context['created']) / 1000.0
        log_data = {
            'task_id': self.request.id,
            'task_type': self.name,
            'task_status': status,
            'task_retries': self.request.retries,
            'process_time': process_time,
            'lifetime': lifetime,
        }
        mpfs.engine.queue2.queue_log.log_metrics(**log_data)

    def _log_success(self, kwargs):
        context = kwargs.get('context')
        self._log_state(context, 'OK')

    def _log_error(self, stacktrace, kwargs):
        error_log = mpfs.engine.process.get_error_log()
        error_log.info('Task %s failed (try %d), name: %s' % (self.request.id, self.request.retries, self.name))
        error_log.error(stacktrace)
        task_data = copy(kwargs)
        context = task_data.pop('context', None)
        error_log.info(pformat(task_data))

        self._log_state(context, 'FAIL')


@import_modules.connect
def prepare_celery_app(sender, **kwargs):
    # Этот сигнал вызывает функцию при инициализации приложения celery - до импорта модулей с тасками - один раз для
    # всех воркеров.
    # Делаем это здесь, т.к. тут вызывается setup, который инициализирует базу, а это надо сделать до того, как
    # произойдет импорт модулей с тасками (создание объекта Celery, опция include), т.к. там коннект к базе уже
    # должен быть проинициализирован
    prepare_zookeeper_settings()

    from mpfs.engine.process import setup, pre_fork, set_register_after_fork_impl, get_default_log

    set_register_after_fork_impl(billiard.util.register_after_fork)
    setup()

    sender.conf.BROKER_URL = mpfs.engine.process.rabbitmq_hosts(shuffle=True)

    setup_producer_logging()
    setup_consumer_logging()

    pre_fork()
    if LOGGER_MONITOR_ENABLED:
        logger.enable_monitor()

    if QUEUE2_WORKER_ENABLE_ACKS_LATE_FIX:
        Request.on_failure = get_on_failure_patch(get_default_log())

    if SERVICES_TVM_2_0_ENABLED:
        from mpfs.core.services.tvm_2_0_service import tvm2
        tvm2.update_public_keys(silent_mode_on_errors=False)
        tvm2.update_service_tickets()


def set_command_context(command):
    # Тут мы патчим агрумент context у каждой таски перед тем, как отправить ее в брокер. Добавляем ей ycrid, если у нее
    # его еще нет. Надо отметить, что этот код попадает в java, которая использует этот параметр для логирования и
    # плюс ко всему добавляет в context еще кучу всего перед тем, как таска попадает обратно на обработку.
    # ВАЖНО! Я тут еще заметил, что если подавать параметр произвольный, не ycrid, а с другим именем, то java его
    # фильтрует, поэтому если надо добавить еще один параметр, который должен сохраниться, надо договариваться с
    # джавистами
    if 'context' not in command:
        command['context'] = {}
    if 'ycrid' not in command['context']:
        command['context']['ycrid'] = mpfs.engine.process.get_cloud_req_id()
    if 'created' not in command['context']:
        command['context']['created'] = int(time() * 1000)

    if mpfs.engine.process.is_uwsgi_process():
        command['context']['host'] = mpfs.engine.process.hostname()
    else:
        command['context']['host'] = mpfs.engine.process.get_async_task_uwsgi_submitter()

    if KWARGS_DEDUPLICATION_ID_PARAMETER in command:
        deduplication_id = command.pop(KWARGS_DEDUPLICATION_ID_PARAMETER)
        command['context']['activeUid'] = hashlib.md5(deduplication_id).hexdigest()


@before_task_publish.connect
def setup_task_context_params(body, *args, **kwargs):
    set_command_context(body['kwargs'])


@task_prerun.connect
def setup_task(task_id, task, *args, **kwargs):
    # А тут мы перед тем, как отдать таску на выполнение, вытаскиваем из kwargs аргументов таски переменную context
    # и cloud_request_id aka ycrid, чтобы в логе записи писались уже с ним, ну и формируем id таски также для лога
    ycrid = kwargs.get('kwargs', {}).get('context', {}).get('ycrid', None)
    task_short_name = task.name.rsplit('.', 1)[-1]

    mpfs.engine.process.set_cloud_req_id(ycrid)
    mpfs.engine.process.set_req_id('%s-%s' % (task_id, task_short_name))
    if get_current_worker_name():
        mpfs.engine.process.reset_cached()  # только в воркерах
        mpfs.engine.process.reset_connections()

    uwsgi_host = kwargs.get('kwargs', {}).get('context', {}).get('host', None)
    if uwsgi_host:
        mpfs.engine.process.set_async_task_uwsgi_submitter(uwsgi_host)


@task_postrun.connect
def shutdown_check(*args, **kwargs):
    worker = get_current_worker_name()
    from mpfs.core.zookeeper.hooks import update_settings_in_current_worker
    update_settings_in_current_worker(worker_id=worker)
    if worker:
        rss_memory = float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / 1024  # переводим килобайты в мегабайты
        if rss_memory > QUEUE2_WORKER_LIMITS_MAX_RSS:
            mpfs.engine.process.get_default_log().info('Process consumed too much %d MB (> %d), exit.',
                                                       rss_memory, QUEUE2_WORKER_LIMITS_MAX_RSS)
            sys.exit()
        elapsed = time() - mpfs.engine.process.get_process_ctime()
        if elapsed > QUEUE2_WORKER_LIMITS_MAX_LIFETIME:
            mpfs.engine.process.get_default_log().info('Process lived too long %d sec (> %d), exit.',
                                                       elapsed, QUEUE2_WORKER_LIMITS_MAX_LIFETIME)
            sys.exit()


@worker_ready.connect
def setup_worker_master_process(*args, **kwargs):
    mpfs.engine.process.set_app_name('queue2.%s' % current_process().name)


@worker_process_init.connect
def setup_worker_child_process(*args, **kwargs):
    mpfs.engine.process.set_app_name('queue2')

    if LOGGER_MONITOR_ENABLED:
        logger.enable_monitor()

    mpfs.engine.process.set_process_ctime(time())

    worker_id = get_current_worker_name()

    from mpfs.core.zookeeper.thread import preload_settings_and_start_zk_init_function
    preload_settings_and_start_zk_init_function(worker_id=worker_id)

    from mpfs.core.zookeeper.shortcuts import push_new_settings_to_queue
    push_new_settings_to_queue(worker_id=worker_id)


class SilentTaskRetryException(Exception):
    """The task is to be retried later without incrementing retry counter. Task is logged as success."""
    def __init__(self, retry_delay=0):
        super(SilentTaskRetryException, self).__init__()
        self.retry_delay = retry_delay
