from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, Mapping, MutableMapping, Optional, Tuple, Type
from weakref import WeakValueDictionary

from sendr_aiopg import StorageBase
from sendr_aiopg.action import BaseDBAction
from sendr_core import BaseAction
from sendr_taskqueue.worker.base.entites import BaseTaskParams, BaseTaskType
from sendr_taskqueue.worker.storage import BaseStorageWorker
from sendr_taskqueue.worker.storage.db.entities import Task, TaskState
from sendr_utils import enum_value, json_value


@dataclass()
class ActionTaskParams(BaseTaskParams):
    """Params stored on a Task created by ``BaseAsyncDBAction.run_async``."""

    # Action init kwargs in stored form, as returned by the action's
    # ``serialize_kwargs``; the worker turns them back with ``deserialize_kwargs``.
    action_kwargs: Dict[str, Any]
    # Copied from the action's ``max_retries`` attribute when the task is created.
    max_retries: int


class BaseActionStorageWorker(BaseStorageWorker):
    """
    Storage worker that executes tasks created by ``BaseAsyncDBAction.run_async``.

    Subclasses define ``actions`` (the action classes this worker may run) and
    ``task_type``. The action class for a fetched task is looked up by the task's
    ``action_name`` via ``get_action_class``.
    """

    # Action classes this worker is allowed to execute.
    actions: ClassVar[Tuple[Type[BaseAction], ...]]
    # Only tasks of this type are fetched by this worker.
    task_type: ClassVar[BaseTaskType]
    # action_name -> action class; built from ``actions`` in ``__init__``.
    action_mapping: Mapping[str, Type[BaseAction]]

    def _create_actions_mapping(self) -> Mapping[str, Type[BaseAction]]:
        """
        Build the ``action_name -> action class`` lookup for ``self.actions``.

        Raises:
            RuntimeError: if an action's ``action_name`` is not a string, or if two
                different action classes share the same ``action_name`` (the earlier
                one would otherwise be silently shadowed and never executed).
        """
        mapping: Dict[str, Type[BaseAction]] = {}
        for action in self.actions:
            if not isinstance(action.action_name, str):
                raise RuntimeError('Action class variable `action_name` is not of string type')
            registered = mapping.get(action.action_name)
            # Listing the very same class twice is harmless; two distinct classes
            # under one name would make one of them unreachable.
            if registered is not None and registered is not action:
                raise RuntimeError(
                    f'Action name collision: {action.action_name!r} is used by both '
                    f'{registered.__qualname__} and {action.__qualname__}'
                )
            mapping[action.action_name] = action
        return mapping

    def __init__(self, *args, **kwargs):
        """Initialize the base worker, then build and validate the action lookup."""
        super().__init__(*args, **kwargs)
        self.action_mapping = self._create_actions_mapping()
        assert getattr(self, 'task_type', None), 'task_type must be defined'

    async def fetch_task_for_work(self, storage: StorageBase) -> Task:
        """
        Fetch a PENDING task of ``self.task_type`` whose action this worker handles.

        A ``None`` ``task.details`` is normalized to an empty dict.
        """
        task_mapper = storage[self.mapper_name_task]

        task = await task_mapper.get_for_work(
            task_types=[self.task_type],
            task_states=[TaskState.PENDING],
            # Restrict to actions this worker can actually execute.
            action_names=self.action_mapping.keys(),
        )
        if task.details is None:
            task.details = {}
        return task

    def get_params(self, task: Optional[Task] = None) -> dict:
        """
        Return the init kwargs for the action stored in ``task``.

        Side effect: a raw-dict ``task.params`` is replaced in place by an
        ``ActionTaskParams`` instance.
        """
        assert task is not None
        if isinstance(task.params, dict):
            task.params = ActionTaskParams(**task.params)
        cls: Type[BaseAction] = self.get_action_class(task)
        # Deep copy so deserialization cannot mutate the task's stored params.
        kwargs = cls.deserialize_kwargs(deepcopy(task.params.action_kwargs))
        return kwargs

    def get_action_class(self, task: Task) -> Type[BaseAction]:
        """
        Return the action class registered under ``task.action_name``.

        With the default dict mapping, an unknown name raises ``KeyError``.
        """
        assert isinstance(task.action_name, str)
        return self.action_mapping[task.action_name]

    async def process_action(self, action_cls: Any, params: Any) -> None:
        """Run the action with its ``action_name`` pushed onto the logger context."""
        with self.logger:
            self.logger.context_push(action_name=getattr(action_cls, 'action_name', None))
            await super().process_action(action_cls, params)

    def get_process_task_logger_context(self, task: Task) -> Dict[str, Any]:
        """Return the logger context fields describing ``task``."""
        return {
            'task_id': task.task_id,
            'task_type': enum_value(task.task_type),
            'task_try': task.retries,
            'action_name': task.action_name
        }


class AsyncDBActionMeta(type):
    """
    Metaclass that keeps a process-wide registry of named actions.

    Each class created with this metaclass whose ``action_name`` resolves to a
    string is recorded in ``_actions``. Creating another class with an
    already-registered name raises ``RuntimeError``, because Tasks identify
    actions by name. Classes without a string ``action_name`` (e.g. actions never
    meant to be scheduled as Tasks) are simply left unregistered.
    """
    # Weak values: a class drops out of the registry once nothing else holds it.
    _actions: ClassVar[MutableMapping[str, Type[BaseAction]]] = WeakValueDictionary()

    @classmethod
    def _register_action(mcs, action_cls: Type[BaseAction]) -> bool:
        """
        Record ``action_cls`` under its ``action_name``.

        Returns True when the class was registered, False when it has no string
        ``action_name``. Raises RuntimeError when the name is already taken.

        NOTE: ``getattr`` also sees inherited attributes, so a subclass that does
        not override its parent's ``action_name`` collides with the parent.
        """
        action_name = getattr(action_cls, 'action_name', None)
        if isinstance(action_name, str):
            if action_name in mcs._actions:
                raise RuntimeError('Action name collision')
            mcs._actions[action_name] = action_cls
            return True
        return False

    def __new__(mcs, *args, **kwargs):
        """Create the class as ``type`` would, then try to register it."""
        new_cls = super().__new__(mcs, *args, **kwargs)
        mcs._register_action(new_cls)
        return new_cls


class BaseAsyncDBAction(BaseDBAction, metaclass=AsyncDBActionMeta):
    """
    DB action that can be scheduled as a Task instead of being run inline.

    ``run_async`` writes the action's name and serialized init kwargs to the task
    table. A ``BaseActionStorageWorker`` listing this class in its ``actions`` can
    later rebuild the kwargs with ``deserialize_kwargs`` and execute the action.
    """
    # Copied into the created task's params (``ActionTaskParams.max_retries``).
    max_retries: int = 10
    # Name identifying the action in tasks; a string value also registers the class
    # with AsyncDBActionMeta. Only annotated here, so this base is not registered.
    action_name: ClassVar[str]
    # Key of the task mapper in ``self.storage``.
    mapper_name_task: ClassVar[str] = 'task'
    # Task type written to created tasks; must be defined before ``run_async``.
    task_type: ClassVar[BaseTaskType]

    @classmethod
    def deserialize_kwargs(cls: Type[BaseAction], init_kwargs: dict) -> dict:
        """
        Rebuild action init kwargs from their stored form.

        Identity by default; override together with ``serialize_kwargs`` when the
        kwargs need custom encoding.
        """
        return init_kwargs

    @classmethod
    def serialize_kwargs(
        cls: Type[BaseAction], init_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert action init kwargs with ``json_value`` for storage in task params."""
        return json_value(init_kwargs)

    @contextmanager
    def _turn_replica_read_off(self):
        """Temporarily set ``allow_replica_read`` to False, restoring it on exit (even on error)."""
        prev_value = self.allow_replica_read
        self.allow_replica_read = False
        try:
            yield
        finally:
            self.allow_replica_read = prev_value

    async def run_async(self, **kwargs: Any) -> Task:
        """
        Create a Task for this action instead of running it.

        The action must define ``action_name`` and ``task_type``, and be listed in
        some worker's ``actions`` to actually be executed.

        Args:
            **kwargs: extra/overriding fields passed to the task mapper's ``create``
                (they take precedence over ``action_name``, ``task_type``, ``params``).

        Returns:
            The created Task.

        Raises:
            AssertionError: if the action was constructed with positional args, or
                ``action_name`` / ``mapper_name_task`` / ``task_type`` is undefined.
        """
        assert not self._init_args, 'Only kwargs are allowed'
        assert getattr(self, 'action_name', None), 'action_name must be defined'
        assert self.mapper_name_task, 'mapper_name_task must be defined'
        assert getattr(self, 'task_type', None), 'task_type must be defined'

        # The action itself may be read-only, but writing to the tasks table
        # requires a writable DB connection.
        with self._turn_replica_read_off():
            async with self.storage_setter(transact=self.transact, reuse_connection=True):
                # Serialized inside the storage context in case serialize_kwargs()
                # needs storage access.
                action_kwargs = dict(
                    action_name=self.action_name,
                    task_type=self.task_type,
                    params=ActionTaskParams(
                        action_kwargs=self.serialize_kwargs(self._init_kwargs),
                        max_retries=self.max_retries,
                    ),
                )
                action_kwargs.update(kwargs)

                task_mapper = self.storage[self.mapper_name_task]
                task = await task_mapper.create(**action_kwargs)

                with self.logger:
                    self.logger.context_push(task_id=task.task_id)
                    self.logger.info('Task created')

                return task
