#!/usr/bin/env python3
from typing import Any, ClassVar, Optional, Type

from aiohttp import web

from sendr_aiopg.types import EngineUnion
from sendr_qlog.http.aiohttp import get_middleware_logging_adapter
from sendr_taskqueue import (
    BaseActionStorageWorker, BaseStorageArbiterWorker, BaseStorageWorker, BaseStorageWorkerApplication
)
from sendr_utils import copy_context

from mail.beagle.beagle.api.app import BeagleApplication
from mail.beagle.beagle.api.routes.utility import UTILITY_ROUTES
from mail.beagle.beagle.conf import settings
from mail.beagle.beagle.core.actions.base import BaseAction
from mail.beagle.beagle.core.actions.smtp_cache import GenerateSMTPCacheAction
from mail.beagle.beagle.core.actions.sync.sync_organization import SyncCurrentOrganizationAction
from mail.beagle.beagle.core.actions.transact_email import TransactEmailAction
from mail.beagle.beagle.core.entities.enums import TaskType, WorkerType
from mail.beagle.beagle.core.entities.task import Task
from mail.beagle.beagle.interactions.base import BaseInteractionClient, create_connector
from mail.beagle.beagle.storage import Storage, StorageContext
from mail.beagle.beagle.utils.db import create_configured_engine
from mail.beagle.beagle.utils.stats import queue_size_gauge, queue_tasks_counter, queue_tasks_gauge, queue_tasks_time


class BaseWorker(BaseStorageWorker):
    """Shared base for Beagle's storage-backed task workers.

    Wraps action processing in the ``queue_tasks_time`` metric and counts
    task outcomes in ``queue_tasks_counter``. Before the parent worker loop
    runs, it seeds ``BaseAction.context``.
    """

    storage_context_cls = StorageContext
    app: BeagleApplication

    async def process_action(self, action_cls: Any, params: Any) -> None:
        """Delegate to the parent implementation, timed under this worker's type label."""
        timer = queue_tasks_time.labels(self.worker_type.value).time
        # NOTE(review): ``time`` is entered as a context manager without being
        # called. That only works if the metric from ``utils.stats`` exposes
        # ``time`` that way; a prometheus-style ``Histogram.time`` would need
        # ``()``. Confirm against the metric implementation.
        with timer:
            await super().process_action(action_cls, params)

    async def task_fail(self, reason: Optional[str], task: Task, storage: Storage) -> bool:
        """Count a ``'fail'`` outcome, then return the parent's result."""
        fail_counter = queue_tasks_counter.labels('fail')
        fail_counter.inc()
        return await super().task_fail(reason, task, storage)

    async def task_done(self, task: Task, storage: Storage) -> bool:
        """Count a ``'done'`` outcome, then return the parent's result."""
        done_counter = queue_tasks_counter.labels('done')
        done_counter.inc()
        return await super().task_done(task, storage)

    @copy_context
    async def _run(self):
        """Fill ``BaseAction.context`` from this worker, then run the parent loop.

        The ``copy_context`` decorator presumably isolates these assignments to
        a copied context per run — TODO confirm against sendr_utils.
        """
        action_context = BaseAction.context
        action_context.logger = self.logger
        action_context.request_id = self.request_id
        action_context.db_engine = self.app.db_engine
        # No storage is bound up front; it starts out as None here.
        action_context.storage = None

        return await super()._run()


class ArbiterWorker(BaseStorageArbiterWorker):
    """Arbiter that checks worker liveness and publishes task-queue gauges."""

    CHECK_WORKERS_ACTIVE = True
    KILL_ON_CLEANUP = True

    storage_context_cls = StorageContext
    worker_heartbeat_period = settings.WORKER_HEARTBEAT_PERIOD

    async def count_tasks(self, storage: Storage) -> None:
        """Report pending tasks per ``TaskType`` and the overall queue size.

        Every per-type gauge is first reset to 0. Types that the
        ``count_pending_by_type`` iteration does not yield (presumably those
        with no pending tasks) therefore read as zero.
        """
        gauge_for = queue_tasks_gauge.labels

        for task_type in TaskType:
            gauge_for(task_type.value).observe(0)

        async for task_type, pending in storage.task.count_pending_by_type():
            gauge_for(task_type.value).observe(pending)

        total = await storage.task.get_size()
        queue_size_gauge.observe(total)


class ActionWorker(BaseWorker, BaseActionStorageWorker):
    """Worker for ``RUN_ACTION`` tasks, limited to the actions listed below.

    ``retry_exceptions`` is deliberately left empty.
    """

    task_type = TaskType.RUN_ACTION
    worker_type = WorkerType.RUN_ACTION
    actions = (
        GenerateSMTPCacheAction,
        TransactEmailAction,
    )
    retry_exceptions = tuple()


class SyncOrganizationWorker(BaseWorker):
    """Worker that maps ``SYNC_ORGANIZATION`` tasks to ``SyncCurrentOrganizationAction``."""

    worker_type = WorkerType.SYNC_ORGANIZATION
    task_action_mapping = {TaskType.SYNC_ORGANIZATION: SyncCurrentOrganizationAction}
    max_retries: ClassVar[int] = settings.SYNC_ORG_MAX_RETRIES

    def should_retry_exception(self, action_cls: Type[BaseAction], action_exception: Exception) -> bool:
        """Return the parent class's retry decision unchanged."""
        return super().should_retry_exception(action_cls, action_exception)

    async def fetch_task_for_work(self, storage: Storage) -> Task:
        """Fetch the next task through the task mapper, then drop its duplicates.

        Uses the mapper's ``get_for_work_by_org``, filtered to this worker's
        ``task_types``. It then calls ``delete_duplicates_by_org`` with the
        fetched task before returning it. Duplicates are presumably defined
        per organization — TODO confirm in the task mapper.
        """
        mapper = storage[self.mapper_name_task]

        claimed = await mapper.get_for_work_by_org(task_types=self.task_types)
        await mapper.delete_duplicates_by_org(claimed)

        return claimed


class BeagleWorkerApplication(BaseStorageWorkerApplication):
    """Worker application that hosts Beagle's arbiter and task workers.

    Serves the utility routes behind the qlog logging middleware. It creates
    the shared ``BaseInteractionClient`` connector during setup and closes it
    on cleanup.
    """

    debug = settings.DEBUG
    arbiter_cls = ArbiterWorker
    sentry_dsn = settings.SENTRY_DSN
    middlewares = (get_middleware_logging_adapter(),)
    routes = UTILITY_ROUTES
    # (worker class, number of instances) pairs, both taken from settings.
    workers = [
        (ActionWorker, settings.ACTION_WORKERS),
        (SyncOrganizationWorker, settings.SYNC_ORGANIZATION_WORKERS),
    ]

    async def create_connector(self, _: Any) -> None:
        """Store a freshly created connector on ``BaseInteractionClient.CONNECTOR``.

        The ignored positional argument is the application object, which
        ``setup`` passes in.
        """
        connector = create_connector()
        BaseInteractionClient.CONNECTOR = connector

    async def close_connector(self, _: Any) -> None:
        """Close the shared connector; ``setup`` registers this in ``on_cleanup``."""
        await BaseInteractionClient.close_connector()

    async def setup(self, app: web.Application) -> None:
        """Run the parent setup, create the connector, and schedule its closing."""
        await super().setup(app)
        await self.create_connector(app)
        self.on_cleanup.append(self.close_connector)

    async def open_engine(self) -> EngineUnion:
        """Return the engine produced by ``create_configured_engine``."""
        engine = create_configured_engine()
        return engine
