from typing import Any, Dict, Optional

from sendr_aiopg import StorageBase
from sendr_taskqueue import BaseStorageWorker, BaseStorageWorkerApplication
from sendr_utils import copy_context

from mail.ipa.ipa.conf import settings
from mail.ipa.ipa.core.actions.base import BaseAction
from mail.ipa.ipa.core.entities.task import Task
from mail.ipa.ipa.storage import Storage, StorageContext
from mail.ipa.ipa.utils.stats import queue_tasks_counter, queue_tasks_time


class BaseWorker(BaseStorageWorker):
    """Common base class for IPA task-queue workers.

    On top of ``BaseStorageWorker`` it:

    * seeds ``BaseAction.context`` with this worker's logger, request id and
      DB engine before the worker loop starts;
    * copies the fetched task's ``meta_task_id`` into the action context and
      adds ``entity_id`` / ``meta_task_id`` to the per-task log context;
    * deserializes task params through the task's action class;
    * records queue metrics: action execution time and done/fail counters.
    """

    storage_context_cls = StorageContext
    app: BaseStorageWorkerApplication

    # Heartbeat and retry tuning come straight from service settings.
    heartbeat_period = settings.TASKQ_WORKER_HEARTBEAT_PERIOD
    max_retries = settings.TASKQ_MAX_RETRIES
    retry_initial_delay = settings.TASKQ_RETRY_INITIAL_DELAY
    retry_delay_multiplier = settings.TASKQ_RETRY_DELAY_MULTIPLIER

    @copy_context
    async def _run(self):
        """Populate the shared action context, then run the base worker loop.

        NOTE(review): ``@copy_context`` presumably runs this coroutine in a
        copied ``contextvars`` context so these assignments stay local to this
        worker -- confirm against ``sendr_utils.copy_context``.
        """
        BaseAction.context.logger = self.logger
        BaseAction.context.request_id = self.request_id
        BaseAction.context.db_engine = self.app.db_engine
        # No storage is bound at worker start; clear whatever may be set.
        BaseAction.context.storage = None

        return await super()._run()

    async def task_retry(self, task: Task, exception: Exception, storage: StorageBase) -> bool:
        """Delegate retry handling to the base worker unchanged."""
        return await super().task_retry(task, exception, storage)

    @copy_context
    async def process_task(self) -> bool:
        """Process one task via the base worker, under ``@copy_context``.

        NOTE(review): the decorator presumably isolates per-task context
        changes (e.g. ``meta_task_id`` set in ``get_task``) from the worker's
        outer context -- confirm against ``sendr_utils.copy_context``.
        """
        return await super().process_task()

    async def get_task(self, storage: StorageBase) -> Task:
        """Fetch the next task and expose its ``meta_task_id`` to actions.

        :param storage: storage used by the base worker to fetch the task.
        :returns: the fetched task.
        """
        task = await super().get_task(storage=storage)

        # Read the same attribute that get_process_task_logger_context uses,
        # so the action context and the log context agree.
        BaseAction.context.meta_task_id = task.meta_task_id
        return task

    def get_process_task_logger_context(self, task: Task) -> Dict[str, Any]:
        """Extend the base log context with ``entity_id`` and, if set, ``meta_task_id``."""
        log_context = super().get_process_task_logger_context(task)
        log_context['entity_id'] = task.entity_id
        if task.meta_task_id:
            log_context['meta_task_id'] = task.meta_task_id
        return log_context

    def get_params(self, task: Task) -> Dict[str, Any]:
        """Deserialize ``task.params`` using the action class resolved for the task."""
        action_cls = self.get_action_class(task)

        return action_cls.deserialize_kwargs(task.params)

    async def process_action(self, action_cls: Any, params: Any) -> None:
        """Run the action via the base worker while timing it per worker type."""
        # ``time`` must be *called* to obtain the timer context manager; the
        # bound method itself does not implement ``__enter__``/``__exit__``.
        with queue_tasks_time.labels(self.worker_type.value).time():
            self.logger.info('Starting task')
            await super().process_action(action_cls, params)

    async def task_fail(self, reason: Optional[str], task: Task, storage: Storage) -> bool:
        """Count the failure in ``queue_tasks_counter`` and delegate to the base worker."""
        queue_tasks_counter.labels('fail').inc()
        return await super().task_fail(reason, task, storage)

    async def task_done(self, task: Task, storage: Storage) -> bool:
        """Count the success in ``queue_tasks_counter`` and delegate to the base worker."""
        queue_tasks_counter.labels('done').inc()
        return await super().task_done(task, storage)
