from datetime import datetime

from maps_adv.stat_controller.server.lib.db import DB, DbTaskStatus

__all__ = ["UnexpectedTransactionMode", "SystemOpManager"]


class UnexpectedTransactionMode(Exception):
    """Raised when a connection reports a transaction mode it should not have.

    In this module it signals that the default connection answered ``"on"``
    to ``SHOW transaction_read_only``.
    """


class SystemOpManager:
    """System-level maintenance operations executed through a ``DB`` handle.

    Provides a connectivity/transaction-mode check and a bulk update that
    resets tasks stuck in an in-progress status.
    """

    __slots__ = ("_db",)

    _db: DB

    def __init__(self, db: DB):
        """Remember the database handle used by every operation.

        Args:
            db: Object whose ``acquire()`` call is used as an async context
                manager yielding a connection.
        """
        self._db = db

    async def validate_db_connection(self):
        """Check that both connection kinds are usable.

        The default connection must not report ``transaction_read_only`` as
        ``"on"``; a connection acquired with ``"ro"`` must be able to run a
        trivial ``SELECT 1``.

        Raises:
            UnexpectedTransactionMode: If the default connection is in
                read-only transaction mode.
        """
        async with self._db.acquire() as con:
            rw_transaction_read_only = await con.fetchval(
                query="SHOW transaction_read_only", column="transaction_read_only"
            )

        # The check happens after the connection is released so that the
        # exception is not raised while still holding it.
        if rw_transaction_read_only == "on":
            raise UnexpectedTransactionMode("Unexpected read-only mode")

        # NOTE(review): "ro" presumably selects a read-only replica/pool —
        # confirm against DB.acquire.
        async with self._db.acquire("ro") as con:
            await con.execute("SELECT 1")

    @staticmethod
    def _expired_tasks_sql(status_count: int) -> str:
        """Build the expiry UPDATE statement for ``status_count`` statuses.

        Placeholders ``$1..$status_count`` are used for the statuses and
        ``$(status_count + 1)`` for the expiration timestamp, so the statement
        always matches the number of arguments passed alongside it.
        """
        placeholders = ", ".join(f"${i}" for i in range(1, status_count + 1))
        return (
            "UPDATE tasks SET status = NULL "
            "FROM tasks_log "
            "WHERE tasks.current_log_id = tasks_log.id"
            f"   AND tasks.status IN ({placeholders})"
            f"   AND tasks_log.created <= ${status_count + 1}"
        )

    async def mark_expired_tasks_as_failed(self, expired_at: datetime):
        """Reset the status of in-progress tasks whose current log is too old.

        Sets ``tasks.status`` to ``NULL`` for every task whose status is one
        of ``DbTaskStatus.in_progress()`` and whose current ``tasks_log`` row
        was created at or before ``expired_at``.

        NOTE(review): the method name says "failed" while the statement
        writes ``NULL`` — presumably ``NULL`` is how such tasks are
        represented; confirm against the schema.

        Args:
            expired_at: Upper bound (inclusive) for ``tasks_log.created``.
        """
        statuses = tuple(DbTaskStatus.in_progress())
        if not statuses:
            # No in-progress statuses means no row can match, and an empty
            # "IN ()" list would not be valid SQL.
            return

        # The placeholder list is derived from the actual number of statuses
        # instead of being hard-coded, so the query keeps working when the
        # in-progress status set changes.
        sql = self._expired_tasks_sql(len(statuses))

        async with self._db.acquire() as con:
            await con.execute(sql, *statuses, expired_at)
