from contextlib import asynccontextmanager, nullcontext
from datetime import datetime
from typing import AsyncIterator, Optional

from asyncpg import Connection
from asyncpg.exceptions import LockNotAvailableError

from maps_adv.warden.server.lib.db import DB
from smb.common.pgswim.lib.engine import PoolType

# Public names of this module: the data-manager classes first, then the
# exceptions they raise.
__all__ = [
    "AbstractDataManager",
    "DataManager",
    "ConflictOperation",
    "TaskNotFound",
    "UnknownTaskType",
]


class UnknownTaskType(Exception):
    """Raised when no row in ``task_types`` matches the requested type name."""


class TaskNotFound(Exception):
    """Raised when a task lookup by id returns no row."""


class ConflictOperation(Exception):
    """Raised when a task-type row lock cannot be obtained without waiting."""


class AbstractDataManager:
    """Interface of the task storage layer.

    Every method raises ``NotImplementedError``; subclasses are expected to
    override them. Methods taking ``con`` accept an already-open connection to
    run on; ``None`` leaves the choice of connection to the implementation.
    """

    async def create_task(
        self,
        executor_id: str,
        type_name: str,
        scheduled_time: datetime,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> dict:
        """Create a new task and return a dict describing it.

        NOTE(review): ``DataManager.create_task`` names the second parameter
        ``type_id`` (an ``int``), not ``type_name``. The name is kept here for
        backward compatibility; confirm which one callers rely on.
        """
        raise NotImplementedError()

    async def update_task(
        self,
        executor_id: str,
        task_id: int,
        status: str,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> None:
        """Record a new status (and optional metadata) for an existing task."""
        raise NotImplementedError()

    async def restore_task(
        self,
        executor_id: str,
        task_id: int,
        status: str,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> None:
        """Record a new status for a task that is being restored."""
        raise NotImplementedError()

    async def retrieve_task_details(
        self, task_id: int, con: Optional[Connection] = None
    ) -> dict:
        """Return the details of the task with the given id."""
        raise NotImplementedError()

    async def retrieve_task_type_details(
        self, type_name: str, con: Optional[Connection] = None
    ) -> dict:
        """Return the details of the task type with the given name."""
        raise NotImplementedError()

    async def find_last_task_of_type(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return the last task of the given type, or ``None`` if there is none."""
        raise NotImplementedError()

    async def retrieve_active_task_details(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return the active task of the given type, or ``None`` if there is none."""
        raise NotImplementedError()

    async def mark_task_as_failed(
        self, task_id: int, con: Optional[Connection] = None
    ) -> None:
        """Mark a single task as failed."""
        raise NotImplementedError()

    async def mark_tasks_as_failed(self, extra_time: int) -> None:
        """Mark overdue tasks as failed, allowing ``extra_time`` of slack."""
        raise NotImplementedError()

    async def retrieve_last_failed_task_details(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return details of the last failed task of the type, or ``None``."""
        raise NotImplementedError()

    async def is_executor_id_exists(
        self, executor_id: str, con: Optional[Connection] = None
    ) -> bool:
        """Tell whether the given executor id is already known."""
        raise NotImplementedError()

    def lock(self, type_name: str):
        """Return a lock for the given task type.

        This is a plain (non-``async``) method; ``DataManager.lock`` implements
        it as an asynchronous context manager.
        """
        raise NotImplementedError()


class DataManager(AbstractDataManager):
    """Data manager that stores tasks, task logs and task types in PostgreSQL.

    Methods that accept ``con`` run on that connection when it is given;
    otherwise a connection is acquired from the database handle for the
    duration of the call (see ``_con``).
    """

    __slots__ = ("_db",)

    _db: DB

    def __init__(self, db: DB):
        """Remember the database handle used to acquire connections."""
        self._db = db

    async def create_task(
        self,
        executor_id: str,
        type_id: int,
        scheduled_time: datetime,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> dict:
        """Insert a task in status ``accepted`` together with its first log row.

        The task row and the log row are written inside one transaction, and
        the new log row becomes the task's ``current_log_id``.

        Returns:
            ``{"task_id": <new task id>, "status": "accepted"}``.
        """
        create_task_sql = """
            INSERT INTO tasks (type_id, status, scheduled_time)
            VALUES ($1, 'accepted', $2)
            RETURNING id
        """
        create_log_sql = """
            WITH new_log AS (
                INSERT INTO tasks_log (task_id, status, executor_id, metadata)
                SELECT $1, 'accepted', $2, $3
                RETURNING id
            )
            UPDATE tasks
            SET current_log_id = new_log.id
            FROM new_log
            WHERE tasks.id = $1
        """

        async with self._con(con) as con:
            async with con.transaction():
                task_id = await con.fetchval(create_task_sql, type_id, scheduled_time)
                await con.execute(create_log_sql, task_id, executor_id, metadata)

        return dict(task_id=task_id, status="accepted")

    async def update_task(
        self,
        executor_id: str,
        task_id: int,
        status: str,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> None:
        """Append a log row for the task and point the task at it.

        A single statement inserts the ``tasks_log`` row and sets the task's
        ``status`` and ``current_log_id`` accordingly.
        """
        sql = """
            WITH new_log AS (
                INSERT INTO tasks_log (task_id, status, executor_id, metadata)
                SELECT $1, $2, $3, $4
                RETURNING id
            )
            UPDATE tasks
            SET status = $2, current_log_id = new_log.id
            FROM new_log
            WHERE tasks.id = $1
        """

        async with self._con(con) as con:
            await con.execute(sql, task_id, status, executor_id, metadata)

    async def restore_task(
        self,
        executor_id: str,
        task_id: int,
        status: str,
        metadata: Optional[dict] = None,
        con: Optional[Connection] = None,
    ) -> None:
        """Same as ``update_task`` but also resets ``intake_time`` to ``now()``."""
        sql = """
            WITH new_log AS (
                INSERT INTO tasks_log (task_id, status, executor_id, metadata)
                SELECT $1, $2, $3, $4
                RETURNING id
            )
            UPDATE tasks
            SET status = $2, current_log_id = new_log.id, intake_time = now()
            FROM new_log
            WHERE tasks.id = $1
        """

        async with self._con(con) as con:
            await con.execute(sql, task_id, status, executor_id, metadata)

    async def retrieve_task_details(
        self, task_id: int, con: Optional[Connection] = None
    ) -> dict:
        """Return status, scheduled time and current log data of a task.

        Returns:
            Dict with keys ``status``, ``executor_id``, ``metadata`` and
            ``scheduled_time``; executor and metadata come from the task's
            current log row.

        Raises:
            TaskNotFound: no row was returned for ``task_id``.
        """
        sql = """
            SELECT
                tasks.status,
                tasks_log.executor_id,
                tasks_log.metadata,
                tasks.scheduled_time
            FROM tasks
            JOIN tasks_log ON tasks.current_log_id = tasks_log.id
            WHERE tasks.id = $1
        """

        async with self._con(con) as con:
            row = await con.fetchrow(sql, task_id)

        if not row:
            raise TaskNotFound()

        return dict(row)

    async def retrieve_task_type_details(
        self, type_name: str, con: Optional[Connection] = None
    ) -> dict:
        """Return ``id``, ``time_limit``, ``schedule`` and ``restorable`` of a type.

        Raises:
            UnknownTaskType: no task type is named ``type_name``.
        """
        sql = """
            SELECT id, time_limit, schedule, restorable
            FROM task_types WHERE name = $1
        """

        async with self._con(con) as con:
            row = await con.fetchrow(sql, type_name)

        if not row:
            raise UnknownTaskType(type_name)

        return dict(row)

    async def find_last_task_of_type(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return the most recently created task of the type.

        Returns:
            Dict with keys ``id``, ``created`` and ``scheduled_time``, or
            ``None`` when the type has no tasks.
        """
        sql = """
            SELECT id, created, scheduled_time
            FROM tasks
            WHERE type_id = $1
            ORDER BY created DESC
            LIMIT 1
        """

        async with self._con(con) as con:
            row = await con.fetchrow(sql, type_id)

        return dict(row) if row else None

    async def retrieve_active_task_details(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return the newest task of the type that is neither completed nor failed.

        Returns:
            Dict with keys ``id``, ``status``, ``intake_time`` and ``metadata``
            (metadata is taken from the task's current log row), or ``None``
            when there is no such task.
        """
        sql = """
            WITH active_task AS (
                SELECT DISTINCT ON (type_id)
                    id, status, intake_time, current_log_id
                FROM tasks
                WHERE status NOT IN ('completed', 'failed') and type_id = $1
                ORDER BY type_id, created DESC
            )
            SELECT
                active_task.id,
                active_task.status,
                active_task.intake_time,
                tasks_log.metadata
            FROM tasks_log
            INNER JOIN active_task ON active_task.current_log_id = tasks_log.id
        """

        async with self._con(con) as con:
            row = await con.fetchrow(sql, type_id)

        return dict(row) if row else None

    async def mark_task_as_failed(
        self, task_id: int, con: Optional[Connection] = None
    ) -> None:
        """Append a ``failed`` log row for the task and set its status to ``failed``.

        The log row is only inserted when the task exists and is not already
        ``failed``; otherwise the statement changes nothing.
        """
        sql = """
            WITH failed_log AS (
                INSERT INTO tasks_log (task_id, status)
                SELECT id, 'failed'
                FROM tasks
                WHERE id = $1 AND status != 'failed'
                RETURNING id
            )
            UPDATE tasks
            SET status = 'failed', current_log_id = failed_log.id
            FROM failed_log
            WHERE tasks.id = $1
        """

        async with self._con(con) as con:
            await con.execute(sql, task_id)

    async def mark_tasks_as_failed(self, extra_time: int) -> None:
        """Fail every unfinished task that has outlived its time limit.

        A task qualifies when its status is neither ``completed`` nor
        ``failed`` and ``intake_time + time_limit + extra_time`` (both in
        seconds) lies in the past. Each such task gets a ``failed`` log row
        which becomes its current log.

        Unlike the other methods, this one always acquires its own connection.
        """
        sql = """
            WITH failed_logs AS (
                INSERT INTO tasks_log (task_id, status)
                SELECT tasks.id, 'failed'
                FROM tasks
                JOIN task_types ON tasks.type_id = task_types.id
                WHERE tasks.status NOT IN ('completed', 'failed')
                    AND now() > (tasks.intake_time
                        + concat(task_types.time_limit, ' seconds')::interval
                        + concat($1::int, ' seconds')::interval)
                RETURNING id, task_id
            )
            UPDATE tasks
            SET status = 'failed', current_log_id = failed_logs.id
            FROM failed_logs
            WHERE tasks.id = failed_logs.task_id
        """

        async with self._db.acquire() as con:
            await con.execute(sql, extra_time)

    async def retrieve_last_failed_task_details(
        self, type_id: int, con: Optional[Connection] = None
    ) -> Optional[dict]:
        """Return the last non-failed log entry of the newest failed task.

        The newest (by ``created``) task of the type with status ``failed`` is
        selected, then its most recent log row whose status is not ``failed``.

        Returns:
            Dict with keys ``id`` (the task id), ``status`` and ``metadata``,
            or ``None`` when no such log row exists.
        """
        sql = """
            SELECT task_id as id, status, metadata
            FROM tasks_log
            WHERE status != 'failed' AND task_id = (
                SELECT id
                FROM tasks
                WHERE status = 'failed' and type_id = $1
                ORDER BY created DESC
                LIMIT 1
            )
            ORDER BY created DESC
            LIMIT 1
        """

        async with self._con(con) as con:
            row = await con.fetchrow(sql, type_id)

        return dict(row) if row else None

    async def is_executor_id_exists(
        self, executor_id: str, con: Optional[Connection] = None
    ) -> bool:
        """Tell whether any ``tasks_log`` row carries the given executor id."""
        sql = """
            SELECT EXISTS(
                SELECT 1
                FROM tasks_log
                WHERE executor_id = $1
            )
        """

        async with self._con(con) as con:
            return await con.fetchval(sql, executor_id)

    @asynccontextmanager
    async def lock(self, type_name: str) -> AsyncIterator[Connection]:
        """Hold a row lock on the named task type for the duration of the block.

        A fresh connection is acquired and a transaction is opened on it; the
        ``task_types`` row is selected ``FOR UPDATE NOWAIT`` and the connection
        is yielded so the caller can work inside the same transaction. If no
        row matches ``type_name`` nothing is locked and the block still runs.

        Raises:
            ConflictOperation: the row is already locked by someone else.
        """
        async with self._db.acquire() as con, con.transaction():
            # NOTE: the ``try`` deliberately keeps covering the ``yield`` as it
            # did before, so a LockNotAvailableError raised by the caller's own
            # block is reported as ConflictOperation too.
            try:
                sql = """
                    SELECT id
                    FROM task_types
                    WHERE name = $1
                    FOR UPDATE NOWAIT
                """
                await con.execute(sql, type_name)
                yield con

            except LockNotAvailableError as exc:
                # Chain explicitly so the original database error stays visible.
                raise ConflictOperation() from exc

    @asynccontextmanager
    async def _con(
        self, con: Optional[Connection] = None, pool_type: PoolType = PoolType.master
    ) -> AsyncIterator[Connection]:
        """Yield ``con`` if given, otherwise a connection acquired from the DB.

        A caller-supplied connection is passed through untouched (and is not
        released here); an acquired one is released when the block exits.
        """
        manager = nullcontext(con) if con is not None else self._db.acquire(pool_type)
        async with manager as active_con:
            yield active_con
