import uuid
from asyncio import CancelledError
from datetime import timedelta
from typing import Any, ClassVar, Dict, Iterable, Mapping, Optional, Tuple, Type, Union

from aiohttp import web

from sendr_aiopg import StorageBase
from sendr_aiopg.storage.exceptions import StorageNotFound
from sendr_aiopg.storage.lock import LockAlreadyAcquired
from sendr_core import BaseAction
from sendr_taskqueue.worker.base import BaseWorker
from sendr_taskqueue.worker.base.entites import BaseTaskType, BaseWorkerType
from sendr_taskqueue.worker.storage.db.entities import Task, TaskState, Worker, WorkerState
from sendr_taskqueue.worker.storage.db.mappers import get_task_mapper, get_worker_mapper  # NOQA
from sendr_taskqueue.worker.storage.exceptions import BaseWorkerError, WorkerShutdownError
from sendr_taskqueue.worker.storage.mixins import StorageMixin
from sendr_utils import utcnow


class BaseStorageWorker(StorageMixin, BaseWorker):
    """Worker that claims tasks from storage, runs the mapped action and persists the outcome.

    The worker row (via ``mapper_name_worker``) records which task this worker is
    currently processing; the task row (via ``mapper_name_task``) carries the task
    state, retry counter and details. Each ``process_task`` call claims one PENDING
    task, marks it PROCESSING, runs the action class registered for its ``task_type``
    in ``task_action_mapping`` and then records it as FINISHED, FAILED, or PENDING
    again with a backoff delay (retry).

    Subclasses must set ``worker_type`` (checked in ``register_worker``) and should
    fill ``task_action_mapping``.
    """

    # Retry delay in seconds used for a task that has not been retried yet.
    retry_initial_delay: ClassVar[float] = 1
    # Factor applied to the delay per previous retry (exponential backoff).
    retry_delay_multiplier: ClassVar[float] = 2
    # Upper bound for the retry delay in seconds; a negative value disables the cap.
    retry_max_delay: ClassVar[float] = -1
    # Worker type stored on the worker row; must be defined by subclasses.
    worker_type: ClassVar[Type[BaseWorkerType]]
    # True / False: retry on any exception / never retry.
    # Tuple: retry only when the action exception is an instance of one of the entries.
    retry_exceptions: ClassVar[Union[Tuple[Union[Type[Exception], Exception], ...], bool]] = True
    # Task type -> action class that executes tasks of that type.
    task_action_mapping: ClassVar[Mapping[BaseTaskType, Type[BaseAction]]] = dict()
    # Default retry limit; a task may override it with ``max_retries`` in its params.
    # A negative limit means retries are never exhausted.
    max_retries: ClassVar[int] = 10
    # Action exceptions of these types are logged with ``logger.warning`` instead of
    # ``logger.exception``; they do not affect the retry/fail decision.
    suppress_exceptions: ClassVar[Tuple[Type[Exception], ...]] = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set by register_worker() to the created worker row.
        self.worker: Optional[Worker] = None
        # Generated once per instance: worker id plus a random hex suffix.
        self._request_id: str = f'{self.worker_id}_{uuid.uuid4().hex}'

    @property
    def request_id(self) -> str:
        """Per-instance id built in ``__init__`` from ``worker_id`` and a random UUID."""
        return self._request_id

    async def commit_task(self, task: Task, storage: StorageBase) -> None:
        """At the end of task processing save resulting state.

        In one transaction, saves ``task`` and clears ``task_id`` on this worker's row;
        afterwards logs 'Task processed' with the task state in the logger context.

        :raises AssertionError: if the worker row is not attached to ``task``.
        """
        async with storage.conn.begin():
            task_mapper = storage[self.mapper_name_task]
            worker_mapper = storage[self.mapper_name_worker]

            worker = await worker_mapper.get(self.worker_id)
            assert worker.task_id == task.task_id, (worker.task_id, task.task_id)
            worker.task_id = None

            await task_mapper.save(task)
            await worker_mapper.save(worker)

        with self.logger:
            self.logger.context_push(task_state=task.state.value)
            self.logger.info('Task processed')

    @property
    def task_types(self) -> Iterable[str]:
        """Task types this worker handles: the keys of ``task_action_mapping``."""
        return self.task_action_mapping.keys()

    async def fetch_task_for_work(self, storage: StorageBase) -> Task:
        """Return a PENDING task of one of ``task_types``.

        Selection (and any locking) is delegated to the task mapper's ``get_for_work``.
        """
        task_mapper = storage[self.mapper_name_task]
        return await task_mapper.get_for_work(task_types=self.task_types, task_states=[TaskState.PENDING])

    async def get_task(self, storage: StorageBase) -> Task:
        """Claim the next task for this worker and mark it PROCESSING.

        First transaction: check that the worker row is still RUNNING, and if it still
        points at a task left in PROCESSING state, mark that task FAILED with reason
        'Bad state after processing'. Second transaction: fetch a pending task, attach
        it to the worker row and save it as PROCESSING.

        :raises WorkerShutdownError: if the worker row is no longer RUNNING.
        :raises: whatever ``fetch_task_for_work`` raises; ``process_task`` treats
            ``StorageNotFound`` / ``LockAlreadyAcquired`` as "nothing to do now".
        """
        async with storage.conn.begin():
            worker_mapper = storage[self.mapper_name_worker]
            task_mapper = storage[self.mapper_name_task]

            worker = await worker_mapper.get(self.worker_id)
            if worker.state != WorkerState.RUNNING:
                raise WorkerShutdownError()
            if worker.task_id:
                prev_task = await task_mapper.get(worker.task_id)
                # A task still PROCESSING here was never committed by this worker
                # (presumably an interrupted run), so close it out as failed.
                if prev_task.state == TaskState.PROCESSING:
                    prev_task.state = TaskState.FAILED
                    if prev_task.details is None:
                        prev_task.details = {}
                    prev_task.details['reason'] = 'Bad state after processing'
                    await self.commit_task(prev_task, storage=storage)

        async with storage.conn.begin():
            task = await self.fetch_task_for_work(storage=storage)

            # ``worker`` was loaded in the first transaction; task_id is overwritten here.
            worker.task_id = task.task_id
            await worker_mapper.save(worker)

            task.state = TaskState.PROCESSING
            await task_mapper.save(task)

        return task

    def get_action_class(self, task: Task) -> Type[BaseAction]:
        """Return the action class for ``task.task_type``.

        :raises KeyError: if the task type is not in ``task_action_mapping``.
        """
        return self.task_action_mapping[task.task_type]

    def get_params(self, task):
        """Return the keyword arguments for the action; by default ``task.params``."""
        return task.params

    def should_fail_task(self,
                         task: Task,
                         action_cls: Type[BaseAction],
                         action_exception: Exception) -> Tuple[bool, Optional[str]]:
        """Decide whether a failed action ends the task.

        :returns: ``(should_fail, reason)``. Exhausted retries fail the task first;
            otherwise the task is retried if ``should_retry_exception`` allows it.
            ``reason`` is None when the task will be retried.
        """
        if self.task_retries_exhausted(task):
            self.logger.info('Task retries exhausted')
            return True, 'Max retries exceeded'
        if self.should_retry_exception(action_cls, action_exception):
            return False, None
        return True, f'Action failed because of exception {action_exception.__class__.__name__}'

    async def task_done(self, task: Task, storage: StorageBase) -> bool:
        """Mark ``task`` FINISHED, persist it and return ``PROCESS_TASK_WITH_NO_PAUSE``."""
        task.state = TaskState.FINISHED
        await self.commit_task(task, storage=storage)
        return self.PROCESS_TASK_WITH_NO_PAUSE

    def get_process_task_logger_context(self, task: Task) -> Dict[str, Any]:
        """Logger context pushed for the duration of processing ``task``."""
        return {
            'task_id': task.task_id,
            'task_try': task.retries
        }

    def task_retries_exhausted(self, task: Task) -> bool:
        """Return True when ``task.retries`` reached the retry limit.

        The limit is ``max_retries`` from ``task.params`` (dict key or attribute),
        falling back to the class-level ``max_retries``. A negative limit never exhausts.
        """
        if isinstance(task.params, dict):
            max_retries = task.params.get('max_retries', self.max_retries)
        else:
            max_retries = getattr(task.params, 'max_retries', self.max_retries)
        return task.retries >= max_retries if max_retries >= 0 else False

    async def task_fail(self, reason: Optional[str], task: Task, storage: StorageBase) -> bool:
        """Mark ``task`` FAILED with ``reason`` and persist it.

        ``task.details`` is replaced by ``{'reason': reason}`` (earlier keys such as
        ``retry_reason`` are dropped). Returns ``PROCESS_TASK_WITH_NO_PAUSE``.
        """
        task.state = TaskState.FAILED
        task.details = {'reason': reason}
        await self.commit_task(task, storage=storage)
        return self.PROCESS_TASK_WITH_NO_PAUSE

    def should_retry_exception(self, action_cls: Type[BaseAction], action_exception: Exception) -> bool:
        """Apply the ``retry_exceptions`` policy to ``action_exception``.

        :raises Exception: if ``retry_exceptions`` is neither a bool nor a tuple.
        """
        if isinstance(self.retry_exceptions, bool):
            return self.retry_exceptions
        elif isinstance(self.retry_exceptions, tuple):
            return isinstance(action_exception, self.retry_exceptions)  # type: ignore
        else:
            raise Exception(f'Invalid type of retry_exceptions: {type(self.retry_exceptions)}')

    def get_task_retry_delay(self, task: Task) -> timedelta:
        """Exponential backoff delay for the next retry of ``task``.

        ``retry_initial_delay * retry_delay_multiplier ** task.retries`` seconds, capped
        at ``retry_max_delay`` when that is non-negative, truncated to whole seconds.
        """
        total_seconds = self.retry_initial_delay * (self.retry_delay_multiplier ** task.retries)
        if self.retry_max_delay >= 0:
            total_seconds = min(total_seconds, self.retry_max_delay)
        return timedelta(seconds=int(total_seconds))

    async def task_retry(self, task: Task, exception: Exception, storage: StorageBase) -> bool:
        """Mark task pending again, increment retry counter and save state.

        Schedules ``run_at`` after the backoff delay and stores a readable description
        of ``exception`` (plus up to 5 chained ``__cause__`` exceptions) in
        ``task.details['retry_reason']``. Returns ``PROCESS_TASK_WITH_NO_PAUSE``.
        """
        task.state = TaskState.PENDING
        # The delay is computed from the retry count *before* this retry is counted.
        task.run_at = utcnow() + self.get_task_retry_delay(task)
        task.retries += 1
        # ``details`` may be None (get_task handles the same case); start a fresh dict
        # instead of aborting the retry with an assertion error.
        if task.details is None:
            task.details = {}
        retry_reason = exception.__class__.__name__
        if str(exception):
            retry_reason += f': {str(exception)}'

        cause_count = 0
        cause = exception
        while cause := getattr(cause, '__cause__', None):
            cause_cls = cause.__class__.__name__
            retry_reason += f' caused by: {cause_cls}: {str(cause)}'
            cause_count += 1
            # Bound the message length (and guard against pathological cause chains).
            if cause_count >= 5:
                break

        task.details['retry_reason'] = retry_reason
        await self.commit_task(task, storage=storage)

        return self.PROCESS_TASK_WITH_NO_PAUSE

    # Abstract method implementation begin

    async def register_worker(self, app: web.Application) -> None:
        """Create this worker's row in RUNNING state and add worker info to the logger context.

        :raises AssertionError: if ``worker_type`` is not defined on the class.
        """
        async with self.storage_context() as storage:
            assert getattr(self, 'worker_type', None), 'worker_type must be defined'
            worker_mapper = storage[self.mapper_name_worker]

            self.worker = await worker_mapper.create(
                worker_mapper.model(
                    worker_id=self.worker_id,
                    worker_type=self.worker_type,
                    host=self.host,
                    state=WorkerState.RUNNING,
                )
            )
            self.logger.context_push(worker_id=self.worker_id, worker_type=self.worker_type.value)
            self.logger.info('Worker created')

    async def unregister_worker(self, app: web.Application) -> None:
        """Set this worker's row to SHUTDOWN; ``get_task`` then raises ``WorkerShutdownError``."""
        async with self.storage_context() as storage:
            worker_mapper = storage[self.mapper_name_worker]

            worker = await worker_mapper.get(self.worker_id)
            worker.state = WorkerState.SHUTDOWN
            await worker_mapper.save(worker)

            self.logger.info('Worker shutdown')

    async def heartbeat(self) -> None:
        """Delegate a heartbeat for this worker to the worker mapper."""
        async with self.storage_context() as storage:
            worker_mapper = storage[self.mapper_name_worker]
            await worker_mapper.heartbeat(self.worker_id)

    async def process_action(self, action_cls: Any, params: Any) -> None:
        """Instantiate ``action_cls`` with ``params`` as keyword arguments and run it.

        ``params`` must therefore be a mapping of constructor arguments.
        """
        await action_cls(**params).run()

    async def process_task(self) -> bool:
        """Claim one task, run its action and record the outcome.

        :returns: ``PROCESS_TASK_WITH_PAUSE`` when no task could be claimed, otherwise
            the result of ``task_done`` / ``task_fail`` / ``task_retry``.
        :raises CancelledError, BaseWorkerError: propagated unchanged.
        """
        with self.logger:
            async with self.storage_context() as storage:
                try:
                    task = await self.get_task(storage=storage)
                except (StorageNotFound, LockAlreadyAcquired):
                    # Presumably no pending task, or it is locked by another worker.
                    return self.PROCESS_TASK_WITH_PAUSE
                except (CancelledError, BaseWorkerError):
                    # Re-raised explicitly: on older Pythons CancelledError subclasses
                    # Exception and would otherwise be swallowed below.
                    raise
                except Exception:
                    self.logger.exception('Failed to get task')
                    return self.PROCESS_TASK_WITH_PAUSE

                self.logger.context_push(**self.get_process_task_logger_context(task))

            try:
                action_cls = self.get_action_class(task)
            except KeyError:
                self.logger.exception('Failed to get action for task')
                # The storage context used to claim the task has already been exited,
                # so open a fresh one to record the failure (as the branches below do).
                async with self.storage_context() as storage:
                    return await self.task_fail(f'Failed to get action {task.action_name} for task', task, storage)

            params = self.get_params(task)

            try:
                await self.process_action(action_cls, params)
            except CancelledError:
                raise
            except Exception as action_exception:
                async with self.storage_context() as storage:
                    should_fail, reason = self.should_fail_task(task, action_cls, action_exception)

                    if isinstance(action_exception, self.suppress_exceptions):
                        self.logger.warning('Failed to run action for task', exc_info=action_exception)
                    else:
                        self.logger.exception('Failed to run action for task')

                    if should_fail:
                        return await self.task_fail(reason, task, storage)
                    else:
                        self.logger.info('Will retry task')
                        return await self.task_retry(task, action_exception, storage)
            else:
                async with self.storage_context() as storage:
                    self.logger.info('Task done')
                    return await self.task_done(task, storage=storage)

    async def _on_exception(self, exc: Exception) -> bool:
        """Handle an exception escaping the worker loop.

        Returns True for ``BaseWorkerError`` (per the original intent: worker errors are
        propagated); every other exception is delegated to ``BaseWorker._on_exception``.
        """
        # propagating worker exceptions
        if isinstance(exc, BaseWorkerError):
            return True

        return await super()._on_exception(exc)

    # Abstract method implementation end
