import typing
import asyncio
import logging
import celery
import boto3
import ylock
from ylock.backends.yt import Manager as YTManager
from functools import cache, wraps
from celery import Celery
from asgiref.sync import async_to_sync, sync_to_async
from ylog.context import log_context
from crm.agency_cabinet.common.request_id_utils import REQUEST_ID_VAR
from .scheduler import LockedPersistentScheduler


class BotoSessionBoundedMixin(object):
    """Mixin that gives each instance a boto3 session plus an S3 client and resource.

    Credentials and the endpoint are class-level settings installed with
    :meth:`init_mds_params` and read when an instance is constructed. If either
    the access key or the secret key is missing/falsy, ``boto3_session``,
    ``s3_client`` and ``s3_resource`` are all set to ``None``.
    """

    _MDS_ACCESS_KEY_ID = None  # access key id passed to boto3.Session
    _MDS_SECRET_ACCESS_KEY = None  # secret access key passed to boto3.Session
    _MDS_ENDPOINT_URL = None  # endpoint_url for the S3 client/resource

    def __init__(self):
        super().__init__()
        # Only the key pair gates creation. The endpoint is forwarded unchanged;
        # with endpoint_url=None boto3 falls back to its default endpoint.
        # NOTE(review): an MDS deployment presumably always needs an explicit
        # endpoint -- confirm whether a missing endpoint should also disable S3.
        if self._MDS_ACCESS_KEY_ID and self._MDS_SECRET_ACCESS_KEY:
            self.boto3_session = boto3.Session(
                aws_access_key_id=self._MDS_ACCESS_KEY_ID,
                aws_secret_access_key=self._MDS_SECRET_ACCESS_KEY,
            )
            self.s3_client = self.boto3_session.client('s3', endpoint_url=self._MDS_ENDPOINT_URL)
            self.s3_resource = self.boto3_session.resource('s3', endpoint_url=self._MDS_ENDPOINT_URL)
        else:
            self.boto3_session = None
            self.s3_client = None
            self.s3_resource = None

    @classmethod
    def init_mds_params(cls, access_key, secret_key, endpoint_url):
        """Store MDS credentials and endpoint on ``cls``.

        The values are read by ``__init__``, so they affect instances created
        after this call, not existing ones.
        """
        cls._MDS_ACCESS_KEY_ID = access_key
        cls._MDS_SECRET_ACCESS_KEY = secret_key
        cls._MDS_ENDPOINT_URL = endpoint_url


class BaseTask(celery.Task):
    """Base Celery task that logs failures and resets the request-id context variable.

    ``LOCK_MANAGER`` is a class-level attribute. In this file it is assigned by
    ``CeleryAppFactory.post_create`` and used by ``locked_task`` to create locks.
    """

    LOCK_MANAGER: YTManager = None

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Celery failure hook: log ``exc`` with its traceback, then reset the request id.

        Args:
            exc: the exception raised by the task.
            task_id, args, kwargs, einfo: supplied by Celery; unused here.
        """
        logger = logging.getLogger('celery.app.trace')
        # Pass the exception explicitly instead of using logger.exception(),
        # which only finds a traceback when an exception is currently being
        # handled -- something this hook does not itself guarantee.
        logger.error('Exception occurred during task execution: %s', exc, exc_info=exc)
        REQUEST_ID_VAR.set('UNSET')

class ContextTask(BaseTask):
    """Celery task that runs its synchronous ``run`` inside an async DB-bound context.

    ``DB`` and ``DB_CONFIG`` are class-level attributes. In this file they are
    assigned by ``CeleryAppFactoryContextTask.post_create``.
    """

    DB = None  # object exposing ``_bind``, ``set_bind(...)`` and ``pop_bind()``
    DB_CONFIG = None  # dict read for the 'dsn', 'pool_class' and 'ssl' keys

    @classmethod
    def build_engine_context(cls):
        """Bind ``cls.DB`` using ``cls.DB_CONFIG``.

        Returns whatever ``DB.set_bind`` returns; ``_run`` awaits it.
        """
        return cls.DB.set_bind(
            bind=cls.DB_CONFIG['dsn'],
            pool_class=cls.DB_CONFIG['pool_class'],
            ssl=cls.DB_CONFIG['ssl'],
        )

    async def _run(self, *args, **kwargs):
        """Ensure a DB bind exists, run ``self.run`` via ``sync_to_async``, then drop the bind.

        The bind is popped and closed in ``finally`` regardless of whether this
        call created it or whether ``run`` raised; exceptions propagate.
        """
        try:
            if self.DB._bind is None:
                await self.build_engine_context()
            return await sync_to_async(self.run)(*args, **kwargs)
        finally:
            if self.DB._bind is not None:
                await asyncio.sleep(0)  # yield to the loop once before closing
                await self.DB.pop_bind().close()
                # Give the SSL connection time to finish closing. Only needed
                # when a connection was actually closed above.
                await asyncio.sleep(0.5)

    def __call__(self, *args, **kwargs):
        """Run the task synchronously with ``request_id`` set in the log context."""
        with log_context(request_id=self.request.id):
            return async_to_sync(self._run)(*args, **kwargs)


class CeleryAppFactory:
    """Builds a configured ``Celery`` app together with a YT-backed ylock manager.

    ``create`` is memoized with ``functools.cache``: repeated calls on the same
    factory with the same (hashable) arguments return the same app object.
    """

    def __init__(
            self,
            name: str,
            scheduler_lock_settings: dict,
            celery_settings: dict,
            task_class: typing.Type[BaseTask] = BaseTask):
        """
        Args:
            name: main name passed to ``Celery``.
            scheduler_lock_settings: reads 'proxy' and 'token' (required) and
                'prefix' and 'lock_params' (optional).
            celery_settings: passed to ``app.config_from_object``.
            task_class: installed as ``app.Task`` unless ``None``.
        """
        self.name = name
        self.scheduler_lock_settings = scheduler_lock_settings
        self.celery_settings = celery_settings
        self.task_class = task_class

    def post_create(self, app, manager, *args, **kwargs):
        """Hook run at the end of ``create``; publishes ``manager`` on the task class."""
        task_cls = self.task_class
        if task_cls is None:
            return
        task_cls.LOCK_MANAGER = manager

    @cache
    def create(self, *args, **kwargs) -> Celery:
        """Create the lock manager, configure the scheduler and build the app.

        Extra ``args``/``kwargs`` are forwarded to :meth:`post_create`.
        """
        lock_settings = self.scheduler_lock_settings

        manager = ylock.create_manager(
            backend="yt",
            proxy=lock_settings["proxy"],
            token=lock_settings["token"],
            prefix=lock_settings.get("prefix"),
        )
        lock_params = lock_settings.get("lock_params", {})
        LockedPersistentScheduler.configure(manager, **lock_params)

        app = Celery(self.name)
        app.config_from_object(self.celery_settings)
        if self.task_class is not None:
            app.Task = self.task_class

        self.post_create(app, manager, *args, **kwargs)
        return app


class CeleryAppFactoryContextTask(CeleryAppFactory):
    """``CeleryAppFactory`` that also installs DB settings on a ``ContextTask`` class."""

    def __init__(
            self,
            name: str,
            scheduler_lock_settings: dict,
            celery_settings: dict,
            db_config: dict,
            db,
            task_class: typing.Type[ContextTask] = ContextTask,
    ):
        """
        Args:
            db_config: stored as ``task_class.DB_CONFIG`` in :meth:`post_create`.
            db: stored as ``task_class.DB`` in :meth:`post_create`.
            Remaining arguments are forwarded to ``CeleryAppFactory``.
        """
        super().__init__(name, scheduler_lock_settings, celery_settings, task_class)
        self.db = db
        self.db_config = db_config

    def post_create(self, app, manager, *args, **kwargs):
        """Run the base hook, then publish the DB object and config on the task class."""
        super().post_create(app, manager, *args, **kwargs)
        task_cls: typing.Type[ContextTask] = self.task_class
        if task_cls is None:
            return
        task_cls.DB = self.db
        task_cls.DB_CONFIG = self.db_config


def locked_task(function=None,
                lock_path: str = "",
                key: typing.Union[str, int] = None,
                block: bool = False,
                block_timeout: int = None,
                timeout: int = None):
    """Make the decorated callable run only while holding a lock from ``LOCK_MANAGER``.

    Usable bare (``@locked_task``) or with arguments (``@locked_task(lock_path=...)``).
    The lock manager comes from the first positional argument when it is a
    ``ContextTask`` instance (e.g. a bound task's ``self``); otherwise it comes
    from the ``ContextTask`` class.

    Args:
        function: the decorated callable when used without parentheses.
        lock_path: base lock path.
        key: optional discriminator. When given, the path becomes
            ``<lock_path>/<key>``. A ``str`` key names a keyword argument, and an
            ``int`` key indexes the positional arguments (``self`` included,
            when present). If that argument exists and is not ``None``,
            ``-<value>`` is appended as well.
        block, block_timeout, timeout: forwarded to ``LOCK_MANAGER.lock``;
            ``timeout`` is also passed positionally to ``lock.acquire``.

    Returns:
        The wrapped callable's result, or ``None`` when the lock was not acquired.
    """

    def _resolve_path(call_args, call_kwargs):
        # Build the lock path for one call from lock_path, key and the call arguments.
        if key is None:
            return lock_path
        value = None
        if isinstance(key, str) and key in call_kwargs:
            value = call_kwargs[key]
        elif isinstance(key, int) and len(call_args) > key:
            value = call_args[key]
        tail = '' if value is None else f'-{value}'
        return f'{lock_path}/{key}{tail}'

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            owner: ContextTask = args[0] if args and isinstance(args[0], ContextTask) else ContextTask
            path = _resolve_path(args, kwargs)
            lock = owner.LOCK_MANAGER.lock(path, block=block, block_timeout=block_timeout, timeout=timeout)
            acquired = False
            try:
                # NOTE(review): `timeout` is passed positionally to acquire();
                # confirm that matches ylock's acquire() signature.
                acquired = lock.acquire(timeout)
                if not acquired:
                    return None
                return func(*args, **kwargs)
            finally:
                if acquired:
                    lock.release()

        return wrapper

    if function is None:
        return decorator
    return decorator(function)
