import enum
from typing import Optional, Tuple

from maps_adv.common.helpers import Converter
from maps_adv.stat_controller.server.lib.db import DbTaskStatus

from . import base

__all__ = ["TaskStatus", "TaskManager"]


class TaskStatus(enum.Enum):
    """Statuses the charger reports for a task.

    Each member is mapped to a ``DbTaskStatus`` value by the module-level
    ``converter``.
    """

    accepted = "accepted"
    context_received = "context_received"
    calculation_completed = "calculation_completed"
    billing_notified = "billing_notified"
    charged_data_sent = "charged_data_sent"
    completed = "completed"

    @classmethod
    def in_progress(cls) -> Tuple["TaskStatus", ...]:
        """Return every status except ``completed``, in declaration order."""
        return (
            cls.accepted,
            cls.context_received,
            cls.calculation_completed,
            cls.billing_notified,
            cls.charged_data_sent,
        )


# Two-way mapping between charger-facing TaskStatus members and the
# DbTaskStatus values persisted in the database. ``converter.forward`` maps
# TaskStatus -> DbTaskStatus; ``converter.reversed`` maps back.
# The dict only serves as a readable, ordered table of (TaskStatus, DbTaskStatus)
# pairs; ``.items()`` preserves insertion order.
converter = Converter(
    tuple(
        {
            TaskStatus.accepted: DbTaskStatus.accepted_by_charger,
            TaskStatus.context_received: DbTaskStatus.charger_received_context,
            TaskStatus.calculation_completed: (
                DbTaskStatus.charger_completed_calculation
            ),
            TaskStatus.billing_notified: DbTaskStatus.charger_notified_billing,
            TaskStatus.charged_data_sent: DbTaskStatus.charger_sent_charged_data,
            TaskStatus.completed: DbTaskStatus.charged,
        }.items()
    )
)


class TaskManager(base.BlockingManager):
    """Charger-side persistence of task progress.

    ``update`` appends one ``tasks_log`` row per status change and mirrors it
    into ``tasks.current_log_id`` / ``tasks.status``; the other methods query
    that state.
    """

    # NOTE(review): presumably consumed by base.BlockingManager (not visible here).
    name: str = "charger"

    async def update(
        self, executor_id: str, task_id: int, status: TaskStatus, state=None, con=None
    ):
        """Record a charger status transition for a task.

        A single statement inserts a ``tasks_log`` row and points the task's
        ``current_log_id`` and ``status`` at it.

        :param executor_id: identifier of the executor reporting the status.
        :param task_id: id of the task being updated.
        :param status: new status; stored as its ``DbTaskStatus`` counterpart.
        :param state: optional value for ``tasks_log.execution_state``. It is
            written only when truthy; ``None`` or an empty value leaves the
            column out of the INSERT.
        :param con: optional connection to reuse, passed to ``self.connection``.
        """
        db_status = converter.forward(status)

        columns = "task_id, executor_id, status"
        values = "$1, $2, $3"
        args = [task_id, executor_id, db_status]
        if state:
            columns += ", execution_state"
            values += ", $4"
            args.append(state)

        # $1 (task_id) is reused in the UPDATE's WHERE clause.
        sql = (
            "with tasks_log_rows as ("
            f"INSERT INTO tasks_log ({columns}) "
            f"VALUES ({values}) RETURNING id, status"
            ")"
            "UPDATE tasks "
            "SET current_log_id = tasks_log_rows.id, status = tasks_log_rows.status "
            "FROM tasks_log_rows WHERE tasks.id = $1"
        )

        async with self.connection(con) as connection:
            await connection.execute(sql, *args)

    async def find_oldest_available(self, con=None) -> Optional[dict]:
        """Return the available task with the smallest ``timing_to``, or None.

        A task qualifies when ``tasks.status`` is ``DbTaskStatus.normalized``,
        or when ``tasks.status`` is NULL and its current log entry's status is
        one of ``DbTaskStatus.charger_in_progress()``.

        The result holds ``id``, ``timing_from`` and ``timing_to``. If the
        current log entry's status is not ``normalized``, it also holds
        ``status`` (converted back to ``TaskStatus``) and ``execution_state``.
        """
        # tuple() so len() works whatever iterable the classmethod returns.
        in_progress = tuple(DbTaskStatus.charger_in_progress())
        sql = (
            "SELECT tasks.id, tasks.timing_from, tasks.timing_to, "
            "tasks_log.status, tasks_log.execution_state "
            "FROM tasks JOIN tasks_log ON tasks.current_log_id = tasks_log.id "
            "WHERE tasks.status = $1 "
            "OR (tasks.status IS NULL AND tasks_log.status IN "
            f"({self._placeholders(2, len(in_progress))})) "
            "ORDER BY tasks.timing_to "
            "LIMIT 1"
        )

        async with self.connection(con) as connection:
            row = await connection.fetchrow(sql, DbTaskStatus.normalized, *in_progress)
            if not row:
                return None

            data = dict(row)
            if data["status"] == DbTaskStatus.normalized:
                # Not yet picked up by the charger: no charger status/state to report.
                del data["status"]
                del data["execution_state"]
            else:
                data["status"] = converter.reversed(data["status"])

        return data

    async def is_there_in_progress(self, con=None) -> bool:
        """Tell whether any task's status is a charger in-progress status."""
        in_progress = tuple(DbTaskStatus.charger_in_progress())
        sql = (
            "SELECT EXISTS(SELECT id FROM tasks WHERE status IN "
            f"({self._placeholders(1, len(in_progress))}))"
        )

        async with self.connection(con) as connection:
            return await connection.fetchval(sql, *in_progress)

    async def retrieve_details(self, task_id: int) -> dict:
        """Return task details from ``base.retrieve_task_details``.

        The ``status`` field is converted back to ``TaskStatus``.
        """
        # NOTE(review): acquires from self._db directly rather than via
        # self.connection; assumes retrieve_task_details returns a mutable
        # mapping containing "status".
        async with self._db.acquire() as connection:
            row = await base.retrieve_task_details(task_id, connection)

        row["status"] = converter.reversed(row["status"])
        return row

    @staticmethod
    def _placeholders(first: int, count: int) -> str:
        """Return ``count`` positional SQL placeholders starting at ``$first``.

        For example, ``_placeholders(2, 3)`` returns ``"$2, $3, $4"``. This keeps
        ``IN (...)`` lists in sync with the length of the bound tuple.
        """
        return ", ".join(f"${index}" for index in range(first, first + count))
