import enum
from datetime import datetime
from typing import Optional

from asyncpg import Connection

from maps_adv.common.helpers import Converter
from maps_adv.stat_controller.server.lib.db import DbTaskStatus

from . import base

__all__ = ["TaskStatus", "TaskManager", "UnexpectedNaiveDateTime"]


class UnexpectedNaiveDateTime(Exception):
    """Raised when a timezone-aware datetime is required but a naive one is given."""


class TaskStatus(enum.Enum):
    """Public status of a normalizer task; each value equals its member name."""

    # Task taken by an executor (stored in the DB as accepted_by_normalizer).
    accepted = "accepted"
    # Task finished (stored in the DB as normalized).
    completed = "completed"


# (public status, DB status) pairs for the normalizer stage; ``converter``
# translates TaskStatus -> DbTaskStatus via ``forward`` and back via ``reversed``.
_STATUS_PAIRS = (
    (TaskStatus.accepted, DbTaskStatus.accepted_by_normalizer),
    (TaskStatus.completed, DbTaskStatus.normalized),
)

converter = Converter(_STATUS_PAIRS)


class TaskManager(base.BlockingManager):
    """Persists and queries tasks for the normalizer stage.

    Public ``TaskStatus`` values are translated to and from database statuses
    with the module-level ``converter``. Every status change is appended to
    ``tasks_log``, and the task row is then pointed at that new entry.
    """

    # Stage name of this manager; presumably consumed by base.BlockingManager — confirm.
    name: str = "normalizer"

    async def create(
        self, executor_id: str, timing_from: datetime, timing_to: datetime, con=None
    ) -> int:
        """Create a task for the given timing window and mark it accepted.

        The task row and its initial ``accepted`` log entry are written inside
        one transaction.

        Args:
            executor_id: identifier of the executor taking the task; stored in
                the log entry.
            timing_from: start of the window; must be timezone-aware.
            timing_to: end of the window; must be timezone-aware.
            con: optional connection, passed to ``self.connection`` (which
                presumably acquires one when None — see base).

        Returns:
            The id of the newly inserted task.

        Raises:
            UnexpectedNaiveDateTime: if either datetime is naive.
        """
        # A datetime is aware only if utcoffset() is not None. Checking that
        # ``tzinfo`` is merely set would accept a tzinfo whose utcoffset()
        # returns None, which the datetime docs define as naive.
        if timing_from.utcoffset() is None or timing_to.utcoffset() is None:
            raise UnexpectedNaiveDateTime()

        sql = "INSERT INTO tasks (timing_from, timing_to) VALUES ($1, $2) RETURNING id"

        async with self.connection(con) as con:
            async with con.transaction():
                task_id = await con.fetchval(sql, timing_from, timing_to)
                await self._update(executor_id, task_id, TaskStatus.accepted, con)

        return task_id

    async def update(self, executor_id: str, task_id: int, status: TaskStatus):
        """Record a status change for ``task_id`` on behalf of ``executor_id``.

        Runs on a newly acquired connection inside its own transaction.
        """
        async with self._db.acquire() as con:
            async with con.transaction():
                await self._update(executor_id, task_id, status, con)

    async def retrieve_details(self, task_id: int) -> dict:
        """Return the details of ``task_id`` with ``status`` as a ``TaskStatus``.

        NOTE(review): assumes ``base.retrieve_task_details`` returns a mutable
        mapping whose ``"status"`` holds a DB status known to ``converter``;
        confirm against base.
        """
        async with self._db.acquire() as con:
            row = await base.retrieve_task_details(task_id, con)

        row["status"] = converter.reversed(row["status"])
        return row

    async def find_last(self, con=None) -> Optional[dict]:
        """Return the task with the latest ``timing_to``, or None if there are none.

        The result has the keys ``id``, ``timing_from`` and ``timing_to``.
        """
        sql = (
            "SELECT id, timing_from, timing_to "
            "FROM tasks "
            "ORDER BY timing_to DESC "
            "LIMIT 1"
        )

        async with self.connection(con) as con:
            row = await con.fetchrow(sql)

        return dict(row) if row is not None else None

    async def find_failed_task(self, con=None) -> Optional[dict]:
        """Return one task whose own status is NULL but whose current log entry
        is ``accepted_by_normalizer``, or None if there is no such task.

        NOTE(review): this presumably identifies tasks whose in-progress status
        was cleared after the executor failed; confirm where ``tasks.status`` is
        reset. No ordering is applied, so the choice among several matches is
        unspecified.

        The result has the keys ``id``, ``timing_from`` and ``timing_to``.
        """
        sql = (
            "SELECT tasks.id, tasks.timing_from, tasks.timing_to "
            "FROM tasks JOIN tasks_log ON tasks.current_log_id = tasks_log.id "
            "WHERE tasks.status IS NULL AND tasks_log.status = $1 "
            "LIMIT 1"
        )

        async with self.connection(con) as con:
            row = await con.fetchrow(sql, DbTaskStatus.accepted_by_normalizer)

        return dict(row) if row is not None else None

    async def is_there_in_progress(self, con=None) -> bool:
        """Return whether any task currently has status ``accepted_by_normalizer``."""
        sql = "SELECT EXISTS(SELECT id FROM tasks WHERE status = $1)"

        async with self.connection(con) as con:
            return await con.fetchval(sql, DbTaskStatus.accepted_by_normalizer)

    async def _update(
        self, executor_id: str, task_id: int, status: TaskStatus, con: Connection
    ):
        """Append a ``tasks_log`` entry for ``task_id`` and point the task at it.

        A single statement inserts the log row through a data-modifying CTE. It
        then copies that row's id and status onto ``tasks.current_log_id`` and
        ``tasks.status``. The caller supplies the connection and any enclosing
        transaction.
        """
        db_status = converter.forward(status)

        sql = (
            "WITH tasks_log_rows AS ("
            "INSERT INTO tasks_log (task_id, executor_id, status) "
            "VALUES ($1, $2, $3) RETURNING id, status"
            ") "
            "UPDATE tasks "
            "SET current_log_id = tasks_log_rows.id, status = tasks_log_rows.status "
            "FROM tasks_log_rows WHERE tasks.id = $1"
        )
        await con.execute(sql, task_id, executor_id, db_status)
