import celery
import contextlib
import functools
import logging
import datetime

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from watcher.celery_app import app
from watcher.config import settings
from watcher.crud.schedule import save_schedule_shifts_boundary_recalculation_error
from watcher.logic import locks
from watcher.logic.exceptions import BaseWatcherException
from watcher.logic.logger.celery import CtxAwareMixin
from watcher.logic.timezone import now
from watcher.db import (
    Event,
    TaskMetric,
)
from watcher import enums


logger = logging.getLogger(__name__)


class WatcherCeleryTask(CtxAwareMixin, celery.Task):
    def apply_async(self, *args, **kwargs):
        if 'countdown' not in kwargs:
            kwargs['countdown'] = settings.TASKS_DEFAULT_COUNTDOWN
        task_name = self.name.split('.')[-1]
        force_delay = kwargs.pop(
            settings.FORCE_TASK_DELAY,
            args[1].pop(settings.FORCE_TASK_DELAY, None) if len(args) > 1 else None
        )

        if (
            settings.ENV_TYPE != 'development' and
            settings.ENABLE_TASK_QUEUE and
            task_name in settings.TASKS_FOR_QUEUE and
            not force_delay
        ):
            self.create_task_event(*args, **kwargs)
        else:
            return super().apply_async(*args, **kwargs)

    @staticmethod
    def _serialize(args, kwargs) -> None:
        for container in (args[1] if len(args) > 1 else {}, kwargs):
            for key, value in container.items():
                if isinstance(value, datetime.datetime):
                    container[key] = value.isoformat()

    def create_task_event(self, *args, **kwargs) -> None:
        kwargs.pop('countdown', None)
        session = kwargs.pop('session', None)
        if not session:
            engine = create_engine(settings.database_url)
            session = Session(bind=engine)

        self._serialize(args=args, kwargs=kwargs)
        event = Event(
            object_data={
                'args': args,
                'kwargs': kwargs,
                'name': self.name,
            },
            source=enums.EventSource.internal,
            type=enums.EventType.task,
        )
        session.add(event)
        session.commit()


@contextlib.contextmanager
def task_metric(metric_name, send_to_unistat=False, save_metrics=False):
    start = now()
    yield
    if save_metrics:
        engine = create_engine(settings.database_url)
        new_session = Session(bind=engine)

        task_metric_record = new_session.query(TaskMetric).filter(TaskMetric.task_name == metric_name).first()
        if task_metric_record:
            task_metric_record.last_success_start = start
            task_metric_record.last_success_end = now()
            task_metric_record.send_to_unistat = send_to_unistat
        else:
            task_metric_record = TaskMetric(
                task_name=metric_name,
                last_success_start=start,
                last_success_end=now(),
                send_to_unistat=send_to_unistat,
            )
            new_session.add(task_metric_record)
        new_session.commit()


def lock_task(f=None, min_lock_time=settings.TASKS_MIN_LOCK_TIME, **kwargs):
    """
    Обертка над декоратором celery-таска.

    Если явно не был задан параметр lock=False, задача будет заблокирована
    через yt с ключом lock_key (или полным именем таска по умолчанию).

    Также, если явно не задан параметр ignore_result=False, в оригинальный
    декоратор будет передан параметр ignore_result=True.

    min_lock_time - время в секундах, на которое гарантировано будет взят лок,
    использутся, если таска отрабатывает слишком быстро

    Остальные параметры передаются в оригинальный декоратор без изменений.
    """
    def get_wrapper(f, options):
        lock_key = options.pop('lock_key', None)
        save_metrics = options.pop('save_metrics', False)
        send_to_unistat = options.pop('send_to_unistat', False)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            lock = kwargs.pop('_lock', True)
            func = f
            if lock:
                func = locks.lock(lock_key, min_lock_time)(func)
            with task_metric(func.__name__, send_to_unistat=send_to_unistat, save_metrics=save_metrics):
                task_result = func(*args, **kwargs)
            return task_result

        result = wrapper

        task_options = {
            'ignore_result': True,
        }
        task_options.update(options)

        result = app.task(**task_options)(result)
        result.original_func = f

        return result

    if f and callable(f):
        return get_wrapper(f, kwargs)
    else:
        def decorator(f):
            return get_wrapper(f, kwargs)
        return decorator


def save_schedule_error(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except BaseWatcherException as exc:
            schedule_id = kwargs.get('schedule_id', None)
            if not schedule_id:
                schedule_id = args[0]

            engine = create_engine(settings.database_url)
            new_session = Session(bind=engine)
            with new_session.begin(nested=True):
                save_schedule_shifts_boundary_recalculation_error(
                    db=new_session,
                    schedule_id=schedule_id,
                    message=exc.to_json(),
                )
            new_session.commit()
            raise
    return wrapper
