from typing import ClassVar, Iterable

from sqlalchemy import func
from sqlalchemy.sql.expression import text

from sendr_aiopg import StorageBase
from sendr_taskqueue import BaseWorkerApplication
from sendr_taskqueue.worker.base.arbiter import BaseArbiterWorker
from sendr_taskqueue.worker.storage.db.entities import TaskState, Worker, WorkerState
from sendr_taskqueue.worker.storage.mixins import StorageMixin

# Per-task-type override for the state a task is reset to when its worker is
# cleaned up while the task is still PROCESSING. Task types that are absent
# from this map fall back to TaskState.PENDING (see cleanup_worker_task).
TASK_STATE_MAP: dict = dict()


class BaseStorageArbiterWorker(StorageMixin, BaseArbiterWorker):
    """Arbiter worker that detects dead workers and recovers their tasks via storage.

    Each ``clean_tasks`` run opens one transactional storage context and:

    1. marks RUNNING workers whose heartbeat is older than
       ``worker_heartbeat_period * worker_heartbeat_period_multiplicator``
       seconds as FAILED (``fail_workers``);
    2. cleans up SHUTDOWN and FAILED workers one at a time: a task still in
       PROCESSING is reset to the state given by ``TASK_STATE_MAP`` for its
       type (``TaskState.PENDING`` by default), and the worker is detached
       from its task and marked CLEANEDUP (``cleanup_workers``);
    3. calls the ``count_tasks`` hook, which is a no-op here.
    """

    # Heartbeat period in seconds; multiplied by the multiplicator below to get
    # the staleness threshold used as an SQL INTERVAL in fail_workers().
    worker_heartbeat_period: ClassVar[int] = 60
    # Number of heartbeat periods a RUNNING worker may miss before it is failed.
    worker_heartbeat_period_multiplicator: ClassVar[int] = 3
    # Upper bound on how many workers a single fail_workers() call selects.
    fail_workers_limit: ClassVar[int] = 100

    async def fail_workers(self, storage: StorageBase) -> None:
        """Mark RUNNING workers with a stale heartbeat as FAILED.

        Workers are selected with ``state=RUNNING`` and
        ``beat_before = now() - INTERVAL '<period * multiplicator> SECONDS'``
        (the exact filter semantics live in the worker mapper's ``find``).
        The rows are locked FOR UPDATE, and at most ``fail_workers_limit``
        are processed per call.

        :param storage: transactional storage holding the worker mapper.
        """
        worker_storage = storage[self.mapper_name_worker]
        interval = self.worker_heartbeat_period * self.worker_heartbeat_period_multiplicator
        # ``interval`` comes from class configuration, not external input, so
        # interpolating it into the raw SQL text is acceptable here.
        beat_before = func.now() - text(f"INTERVAL '{interval} SECONDS'")
        # Collect every selected row first, so iteration over find() has
        # finished before any save() is issued.
        for_fail = [
            x async for x in worker_storage.find(state=WorkerState.RUNNING,
                                                 beat_before=beat_before,
                                                 limit=self.fail_workers_limit,
                                                 for_update=True)
        ]

        for worker in for_fail:
            # Scope the per-worker log context to this iteration only.
            with self.logger:
                self.logger.context_push(
                    worker_id=worker.worker_id,
                    worker_type=worker.worker_type.value,
                    host=worker.host,
                    heartbeat=worker.heartbeat,
                    startup=worker.startup,
                )
                self.logger.info('Failing worker by timeout')
                worker.state = WorkerState.FAILED
                await worker_storage.save(worker)

    async def cleanup_worker_task(self, storage: StorageBase, worker: Worker) -> None:
        """Reset the task a dead worker was processing so it can be picked up again.

        Does nothing when the worker has no task, the task no longer exists,
        or the task is not in ``TaskState.PROCESSING``. Otherwise the task
        state becomes ``TASK_STATE_MAP[task.task_type]``, or
        ``TaskState.PENDING`` when the type is not mapped, and the task is saved.

        :param storage: transactional storage holding the task mapper.
        :param worker: the worker being cleaned up.
        """
        task = await self.get_worker_processing_task(storage, worker)

        if task is None:
            return

        task_mapper = storage[self.mapper_name_task]
        with self.logger:
            # Only tasks still marked as in-flight need recovery; finished or
            # already-reset tasks are left untouched.
            if task.state != TaskState.PROCESSING:
                return

            self.logger.context_push(task_type=task.task_type.value, task_state=task.state.value)
            self.logger.info('Cleanup task')

            task.state = TASK_STATE_MAP.get(task.task_type, TaskState.PENDING)

            await task_mapper.save(task)

    async def clean_one_worker(self, storage: StorageBase, states: Iterable[WorkerState]) -> bool:
        """Clean up at most one worker whose state is in ``states``.

        The selected worker row is locked FOR UPDATE. Its in-flight task is
        recovered via ``cleanup_worker_task``. The worker's ``task_id`` is then
        cleared and its state set to CLEANEDUP.

        Log context is pushed without opening a new logger scope. The caller
        (``cleanup_workers``) wraps each call in ``with self.logger``.

        :param storage: transactional storage holding the worker mapper.
        :param states: worker states eligible for cleanup.
        :returns: True if a worker was cleaned (there may be more), else False.
        """
        iterate = False
        worker_mapper = storage[self.mapper_name_worker]

        async for worker in worker_mapper.find(states=states, limit=1, for_update=True):
            self.logger.context_push(
                worker_id=worker.worker_id,
                worker_type=worker.worker_type.value,
                host=worker.host,
            )
            self.logger.info('Cleaning worker')

            await self.cleanup_worker_task(storage, worker)

            worker.task_id = None
            worker.state = WorkerState.CLEANEDUP
            await worker_mapper.save(worker)

            self.logger.info('Worker cleaned')
            iterate = True
        return iterate

    async def get_worker_processing_task(self, storage: StorageBase, worker: Worker):
        """Load the task referenced by ``worker.task_id``.

        :param storage: storage holding the task mapper.
        :param worker: worker whose current task should be fetched.
        :returns: the task entity, or None when the worker has no task or the
            referenced task does not exist (the latter is logged).
        """
        task_id = worker.task_id

        if task_id is None:
            return None

        task_mapper = storage[self.mapper_name_task]
        with self.logger:
            self.logger.context_push(task_id=task_id)
            try:
                return await task_mapper.get(task_id)
            except task_mapper.model.DoesNotExist:
                self.logger.info('Worker task for cleanup not found')
                return None

    async def cleanup_workers(self, storage: StorageBase, states: Iterable[WorkerState]) -> None:
        """Clean up workers in any of ``states`` one at a time until none remain.

        :param storage: transactional storage holding the worker mapper.
        :param states: worker states eligible for cleanup. Any iterable is
            accepted; it is materialized once because it is reused on every
            iteration.
        """
        # The same ``states`` object is handed to every clean_one_worker()
        # call. A one-shot iterator (e.g. a generator) could be consumed by
        # the first query, leaving later queries with no states and ending
        # cleanup early.
        states = tuple(states)
        iterate = True
        while iterate:
            # Fresh logger scope per worker, so context pushed inside
            # clean_one_worker does not leak into the next iteration.
            with self.logger:
                iterate = await self.clean_one_worker(storage, states)

    async def count_tasks(self, storage: StorageBase) -> None:
        """Hook run at the end of ``clean_tasks``; no-op by default.

        :param storage: the same transactional storage used for cleanup.
        """
        pass

    async def clean_tasks(self, app: BaseWorkerApplication) -> None:
        """Run one arbiter pass: fail stale workers, clean dead ones, count tasks.

        All steps share a single transactional storage context.

        :param app: worker application; not used by this implementation.
        """
        async with self.storage_context(transact=True) as storage:
            await self.fail_workers(storage)
            await self.cleanup_workers(storage, (WorkerState.SHUTDOWN, WorkerState.FAILED))
            await self.count_tasks(storage)
