import asyncio
import copy
import logging
from typing import Callable, Optional

from smb.common.sensors import MetricGroup

from maps_adv.warden.client.lib.client import ClientFactory
from maps_adv.warden.client.lib.exceptions import (
    Conflict,
    TaskTypeAlreadyAssigned,
    TooEarlyForNewTask,
)

from .context import TaskContext


class PeriodicalTask:
    """Repeatedly acquire and execute a warden-managed task.

    One call (``await task(client_factory)``) performs a single attempt: it
    asks ``client_factory`` for a task of ``task_type`` and, once granted,
    awaits ``func`` with a :class:`TaskContext` under the client's time limit.
    Failures are logged and counted in ``sensors`` (when provided) instead of
    being propagated, so :meth:`run` can keep looping.
    """

    __slots__ = "_task_type", "_func", "_params", "_logger", "_sensors"

    # Pause (seconds) before returning after a conflict or an unexpected error.
    # Not applied after a timeout.
    _relaunch_interval_after_exception: float = 60

    _task_type: str
    _params: dict
    _func: Callable
    _logger: logging.Logger
    _sensors: Optional[MetricGroup]

    def __init__(
        self,
        task_type: str,
        func: Callable,
        *,
        sensors: Optional[MetricGroup] = None,
        params: Optional[dict] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """
        :param task_type: task type requested from the warden client factory;
            also used as the ``task_name`` sensor label and in log messages.
        :param func: coroutine function awaited with a ``TaskContext``.
        :param sensors: optional metric group; ``take(task_status=...,
            task_name=...).inc()`` is called at each lifecycle step.
        :param params: parameters handed to ``func`` through the context;
            defaults to an empty dict.
        :param logger: logger to use; defaults to this module's logger.
        """
        self._task_type = task_type
        self._func = func
        self._params = params if params is not None else {}
        self._logger = logger if logger is not None else logging.getLogger(__name__)
        self._sensors = sensors

    def _report(self, task_status: str) -> None:
        """Increment the ``task_status`` counter for this task, if sensors are set."""
        if self._sensors:
            self._sensors.take(task_status=task_status, task_name=self._task_type).inc()

    async def __call__(self, client_factory: ClientFactory):
        """Perform a single attempt to acquire and run the task.

        Outcomes:
        - success: counted as ``"completed"``;
        - ``asyncio.TimeoutError``: logged as a warning, counted as
          ``"failed (timeout)"``, returns immediately;
        - ``Conflict`` / ``TaskTypeAlreadyAssigned`` / ``TooEarlyForNewTask``:
          logged at info level, counted as ``"failed (conflict)"``, then
          sleeps ``_relaunch_interval_after_exception`` seconds;
        - any other ``Exception``: logged with traceback, counted as
          ``"failed"``, then sleeps ``_relaunch_interval_after_exception``.

        ``asyncio.CancelledError`` is always propagated.
        """
        try:
            self._report("requested")
            async with client_factory.context(self._task_type) as client:
                self._report("accepted")
                # Every run gets its own deep copy of params so that mutations
                # made by func never leak into subsequent runs.
                context = TaskContext(
                    client=client,
                    status=client.status,
                    metadata=client.metadata,
                    params=copy.deepcopy(self._params),
                )
                await asyncio.wait_for(self._func(context), timeout=client.time_limit)
                self._report("completed")

        except asyncio.CancelledError:
            # On Python < 3.8 CancelledError is an Exception subclass and would
            # otherwise be swallowed by the generic handler below, making the
            # run loop impossible to cancel. Always let it propagate.
            raise

        except asyncio.TimeoutError:
            self._logger.warning(f'Task "{self._task_type}" failed by timeout')
            self._report("failed (timeout)")

        except (Conflict, TaskTypeAlreadyAssigned, TooEarlyForNewTask):
            self._logger.info(f'Conflict creation new task type of "{self._task_type}"')
            self._report("failed (conflict)")
            await asyncio.sleep(self._relaunch_interval_after_exception)

        except Exception as e:
            self._logger.exception(str(e))
            self._report("failed")
            await asyncio.sleep(self._relaunch_interval_after_exception)

    async def run(self, client_factory: ClientFactory):
        """Call this task in an endless loop.

        ``asyncio.CancelledError`` stops the loop and is logged, not re-raised.
        Any other exception escaping a single attempt is logged with its
        traceback and also ends the loop.
        """
        try:
            while True:
                await self(client_factory)

        except asyncio.CancelledError:
            self._logger.info(f'Canceled main loop for "{self._task_type}"')

        except Exception:
            self._logger.exception(f'Crash main loop for "{self._task_type}"')
