import json
import functools
import contextlib
import logging
import traceback
import ylock

from django.conf import settings

from celery.app.task import TaskType
from django_celery_beat.schedulers import DatabaseScheduler
from django_celery_beat.models import PeriodicTask

from infra.cauth.server.common.middleware import SqlAlchemySessionMiddleware

from infra.cauth.server.master.celery_app import app
from infra.cauth.server.master.utils.mongo import get_mongo_database
from infra.cauth.server.master.utils.subtasks import SubtaskPool

from .config import service_is_readonly


# ylock manager configured from the project's YLOCK settings.
lock_manager = ylock.backends.create_manager(**settings.YLOCK)

# Lock factory that never waits: ``locked_context(name)`` is
# ``lock_manager.lock(name, block=False)``.
locked_context = functools.partial(lock_manager.lock, block=False)


@contextlib.contextmanager
def fake_locked_context(*_args, **_kwargs):
    """Drop-in replacement for ``locked_context`` when locking is disabled.

    Any positional/keyword arguments are ignored; the context always
    reports the lock as acquired.
    """
    acquired = True
    yield acquired


class CauthScheduler(DatabaseScheduler):
    """Database-backed beat scheduler that sends nothing in read-only mode.

    While ``service_is_readonly()`` is true, no entry is treated as due, so
    beat keeps computing wake-up times but does not enqueue any task.
    """

    def __init__(self, *args, **kwargs):
        # Hack: without resetting ``last_run_at`` beat stops scheduling
        # tasks after a redeploy.
        PeriodicTask.objects.update(last_run_at=None)
        super(CauthScheduler, self).__init__(*args, **kwargs)

    def is_due(self, entry):
        """Return ``(is_due, next_time_to_run)`` for ``entry``.

        In Celery 4+ ``Scheduler.tick`` decides whether to send an entry via
        ``is_due`` (and then ``apply_entry``) rather than ``maybe_due``, so the
        read-only guard has to be applied here as well.
        """
        state = entry.is_due()
        if service_is_readonly():
            _, next_time_to_run = state
            return False, next_time_to_run
        return state

    def maybe_due(self, entry, publisher=None):
        """Send ``entry`` if it is due, unless the service is read-only.

        Kept for Celery versions whose ``tick`` still calls ``maybe_due``.
        Returns the delay until the entry should be checked again.
        """
        is_due, next_time_to_run = entry.is_due()

        if service_is_readonly():
            return next_time_to_run
        # Passed positionally: the base-class parameter is ``publisher`` in
        # older Celery and ``producer`` in newer releases.
        return super(CauthScheduler, self).maybe_due(entry, publisher)


def remove_lock_by_name(lock_name):
    """Remove the lock named ``lock_name``.

    Currently a no-op: nothing is removed and ``None`` is returned.
    """
    # TODO: CAUTH-1689: restore the ability to remove a lock.
    return None


class TaskDecorator(object):
    """Decorator that registers a function as a Celery task with extra plumbing.

    Every non-direct invocation of the produced task runs inside these
    contexts, outermost first: logging, error reporting, SQLAlchemy session
    finalisation, and an optional non-blocking lock.

    Args:
        use_lock: take the task lock before running; forced to ``False`` when
            ``settings.DISABLE_TASK_LOCKING`` is set.
        lock_name: lock name override (``None``, a string or a callable),
            stored on the task's ``TaskInfo``.
        bind: pass the Celery task instance as the first argument to ``f``.
        is_subtask: run as a member of a ``SubtaskPool``; callers must then
            supply a ``pool_id`` keyword argument.
        dedicated_logger: stored but not read by this class.
        **kwargs: forwarded to ``app.task``.
    """

    def __init__(self, use_lock=True, lock_name=None, bind=False,
                 is_subtask=False, dedicated_logger=True, **kwargs):
        # The global switch overrides any per-task ``use_lock``.
        if settings.DISABLE_TASK_LOCKING:
            self._use_lock = False
        else:
            self._use_lock = use_lock
        self._lock_name = lock_name
        self._bind = bind
        self._is_subtask = is_subtask
        # Pool of the current subtask invocation; reset on every run.
        self._pool = None
        self._task_kwargs = kwargs  # forwarded to ``app.task``
        self._dedicated_logger = dedicated_logger

        # Created lazily (named after the task) the first time
        # ``error_context`` runs.
        self._error_logger = None

    @contextlib.contextmanager
    def lock_context(self, instance):
        """Try to take the task's lock without blocking.

        Yields the "acquired" value of the lock context; with locking
        disabled this is always ``True``.
        """
        if self._use_lock:
            context_func = locked_context
        else:
            context_func = fake_locked_context

        bound_info = instance.info.bind(instance)
        with context_func(bound_info.lock_name) as acquired:
            yield acquired

    @contextlib.contextmanager
    def error_context(self, instance):
        """Log and re-raise any exception; report it to the pool if one is set."""
        if not self._error_logger:
            self._error_logger = logging.getLogger(instance.name)
        try:
            yield
        except Exception:
            if self._pool:
                error = traceback.format_exc()
                self._pool.report_error(instance.request.id, error)

            self._error_logger.exception(
                "Exception in task %s", instance.name)
            raise

    @contextlib.contextmanager
    def log_context(self, instance):
        """Yield the logger used for task progress messages (the root logger)."""
        yield logging.getLogger()

    def attach_info(self, task_):
        """Attach this decorator's lock configuration as ``task_.info``."""
        task_.info = TaskInfo(self._use_lock, self._lock_name)
        return task_

    @staticmethod
    @contextlib.contextmanager
    def alchemy_context():
        """Finalise the SQLAlchemy session around the task body.

        Calls ``process_response`` on success. If the body raises, or if
        ``process_response`` itself raises, it calls ``process_exception``
        and re-raises.
        """
        try:
            yield
            SqlAlchemySessionMiddleware.process_response(None, None)
        except Exception:
            SqlAlchemySessionMiddleware.process_exception(None, None)
            raise

    def __call__(self, f):
        """Wrap ``f`` and register it as a bound Celery task.

        Returns the Celery task object with a ``TaskInfo`` attached as
        ``.info``.
        """
        @functools.wraps(f)
        def wrapper(instance, *args, **kwargs):
            # Direct (in-process) calls of ordinary tasks skip all plumbing.
            if not self._is_subtask and instance.request.called_directly:
                if self._bind:
                    return f(instance, *args, **kwargs)
                return f(*args, **kwargs)

            # This decorator object is shared by every invocation of the task.
            # Drop the pool left over from a previous run. Otherwise an error
            # raised before the new pool is built (e.g. a missing ``pool_id``
            # or a failure while taking the lock) would be reported by
            # ``error_context`` against the stale pool.
            self._pool = None

            with self.log_context(instance) as logger, \
                    self.error_context(instance), \
                    self.alchemy_context(), \
                    self.lock_context(instance) as acquired:

                if not acquired:
                    logger.info("Task is locked: %s", instance.name)
                    return None

                if self._is_subtask:
                    self._pool = SubtaskPool(
                        id=kwargs.pop('pool_id'),
                        suite_run_id=kwargs.get('suite_run_id'),
                    )
                    pool_lock = self._pool.get_lock()
                    # With task locking disabled the pool lock is not checked.
                    if not (settings.DISABLE_TASK_LOCKING
                            or pool_lock.check_acquired()):
                        logger.warning("Subtask pool lock for %s has already "
                                       "been released. Skipping.", self._pool.id)
                        return None

                    self._pool.report_started(instance.request.id)
                    logger.info('marked task %s of pool %s as started',
                                instance.request.id, self._pool.id)

                if self._bind:
                    result = f(instance, *args, **kwargs)
                else:
                    result = f(*args, **kwargs)

                # Lazy formatting: repr(result) is only built if INFO is enabled.
                logger.info('Task finished: %r', result)

                if self._pool:
                    self._pool.report_finished(instance.request.id)
                    logger.info('marked task %s of pool %s as finished',
                                instance.request.id, self._pool.id)

                return result

        task_ = app.task(wrapper, bind=True, **self._task_kwargs)
        return self.attach_info(task_)


class LockInfo(object):
    """Lookup of lock state by lock name.

    Currently a stub: ``get`` always returns ``None``.
    """

    def __init__(self, name):
        self._name = name  # name of the lock this object describes

    def get(self):
        """Return information about the lock (always ``None`` for now)."""
        # TODO: CAUTH-1689: rework this so it works with Locke.
        lock_state = None
        return lock_state


class TaskInfo(object):
    """Lock configuration attached to a task as ``task.info``."""

    def __init__(self, use_lock, lock_name):
        self.lock_name = lock_name  # None, a string or a callable
        self.use_lock = use_lock

    def bind(self, task):
        """Return a ``TaskInfoBound`` for a concrete task invocation.

        ``task`` may be a ``PeriodicTask`` row, whose args/kwargs are decoded
        from its JSON fields. It may also be a Celery task instance, whose
        args/kwargs are taken from its current request (``None`` becomes
        ``[]`` / ``{}``).

        Raises:
            ValueError: ``task`` is neither of the supported kinds.
        """
        if isinstance(task, PeriodicTask):
            return TaskInfoBound(self, task.task, json.loads(task.args), json.loads(task.kwargs))
        if isinstance(type(task), TaskType):
            return TaskInfoBound(self, task.name, task.request.args or [], task.request.kwargs or {})
        raise ValueError(
            'Cannot bind task info to object of type {}'.format(type(task).__name__))


class TaskInfoBound(object):
    """``TaskInfo`` bound to a specific task name and call arguments."""

    def __init__(self, info, task_name, args, kwargs):
        self.info = info
        self.task_name = task_name
        self.args = args
        self.kwargs = kwargs
        # Resolved eagerly: a callable ``lock_name`` is invoked here.
        self.lock_info = LockInfo(self.lock_name)

    @property
    def lock_name(self):
        """Resolve the lock name for this invocation.

        Returns:
            ``None`` when locking is off; the task name when no explicit name
            is configured; a configured string as-is; or the result of
            ``lock_name(task_name, *args, **kwargs)`` for a callable.

        Raises:
            ValueError: the configured ``lock_name`` has an unsupported type.
        """
        if not self.info.use_lock:
            return None

        configured = self.info.lock_name
        if configured is None:
            return self.task_name

        if isinstance(configured, str):
            return configured

        if callable(configured):
            return configured(self.task_name, *self.args, **self.kwargs)

        raise ValueError(
            'Unsupported lock_name for task {}: {!r}'.format(
                self.task_name, configured))

    def get_lock_info(self, retrying=False):
        """Return lock information; ``retrying`` is accepted and ignored."""
        return self.lock_info.get()

    def get_recent_logs(self, limit=None):
        """Return ``task_logs`` documents for this task, newest ``started`` first.

        When locking is enabled the query is also filtered by this
        invocation's lock name. A falsy ``limit`` returns all matches.
        """
        mongo_db = get_mongo_database()

        params = {'task_name': self.task_name}
        if self.info.use_lock:
            params['lock_name'] = self.lock_name

        query = mongo_db['task_logs'].find(params).sort([('started', -1)])

        if limit:
            query = query.limit(limit)

        return list(query)

    def get_latest_log(self):
        """Return the most recent log document, or ``None`` if there is none."""
        recent = self.get_recent_logs(limit=1)
        return recent[0] if recent else None

    @staticmethod
    def get_log_by_task_id(id_):
        """Return the ``task_logs`` document whose ``task_id`` equals ``id_``."""
        mongo_db = get_mongo_database()
        return mongo_db['task_logs'].find_one({'task_id': id_})


# Short alias so tasks are declared as ``@task(...)`` (a TaskDecorator
# instance whose ``__call__`` wraps the function).
task = TaskDecorator
