from abc import ABC, abstractmethod
from typing import List, Optional

from asyncpg import UniqueViolationError

from maps_adv.manul.lib.db.engine import DB

from .exceptions import ClientExists, ClientNotFound


class BaseClientsDataManager(ABC):
    """Abstract interface for storing and querying clients.

    Concrete subclasses must implement every coroutine below; the stubs
    here only raise ``NotImplementedError``.
    """

    @abstractmethod
    async def create_client(
        self, name: str, account_manager_id: Optional[int] = None
    ) -> dict:
        """Create a client and return its stored representation."""
        raise NotImplementedError()

    @abstractmethod
    async def update_client(self, client_id: int, name: str) -> dict:
        """Rename the client ``client_id`` and return its representation."""
        raise NotImplementedError()

    @abstractmethod
    async def set_account_manager_for_client(
        self, client_id: int, account_manager_id: int
    ) -> None:
        """Assign ``account_manager_id`` to the client ``client_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_client(self, client_id: int) -> dict:
        """Return the representation of the client ``client_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_clients(self) -> List[dict]:
        """Return representations of all clients."""
        raise NotImplementedError()


class ClientsDataManager(BaseClientsDataManager):
    """PostgreSQL (asyncpg) implementation of ``BaseClientsDataManager``.

    Each method acquires a connection via ``self._db.acquire()`` for the
    duration of a single query and releases it before post-processing.
    """

    _db: DB

    def __init__(self, db: DB):
        self._db = db

    async def create_client(
        self, name: str, account_manager_id: Optional[int] = None
    ) -> dict:
        """Insert a new client row.

        Args:
            name: Client name.
            account_manager_id: Id of the assigned account manager, or
                ``None`` to leave it unset (passed to the query as NULL).

        Returns:
            Dict with ``id``, ``name`` and ``account_manager_id`` of the
            inserted row.

        Raises:
            ClientExists: if the insert violates a unique constraint
                (presumably on ``name`` -- confirm against the schema).
        """
        sql = """
            INSERT INTO clients (name, account_manager_id)
            VALUES ($1, $2)
            RETURNING id, name, account_manager_id
        """

        async with self._db.acquire() as con:
            try:
                row = await con.fetchrow(sql, name, account_manager_id)
            except UniqueViolationError as exc:
                # Chain the driver error so the violated constraint stays
                # visible as the direct cause in tracebacks.
                raise ClientExists() from exc

        return dict(row)

    async def update_client(self, client_id: int, name: str) -> dict:
        """Rename a client.

        Args:
            client_id: Id of the client to update.
            name: New client name.

        Returns:
            Dict with ``id``, ``name`` and ``orders_count`` (number of rows
            in ``orders`` referencing the client) after the update.

        Raises:
            ClientNotFound: if no client with ``client_id`` exists.
        """
        # NOTE(review): a unique violation on rename is not translated to
        # ClientExists here (unlike create_client) -- confirm intended.
        # The UPDATE runs inside a CTE so the new name and the order count
        # come back in a single round trip. COUNT(orders.id) yields 0 for a
        # client without orders because the LEFT JOIN null-extends orders.id.
        sql = """
            WITH clients_data AS (
                UPDATE clients
                SET name = $1
                WHERE id = $2
                RETURNING id, name
            )
            SELECT
                clients_data.id AS id,
                clients_data.name AS name,
                COUNT(orders.id) AS orders_count
            FROM clients_data
                LEFT OUTER JOIN orders ON clients_data.id = orders.client_id
            GROUP BY clients_data.id, clients_data.name
        """

        async with self._db.acquire() as con:
            row = await con.fetchrow(sql, name, client_id)

        # No updated row means the CTE was empty, so the SELECT returned
        # nothing.
        if row is None:
            raise ClientNotFound(client_id)

        return dict(row)

    async def set_account_manager_for_client(
        self, client_id: int, account_manager_id: int
    ) -> None:
        """Assign an account manager to an existing client.

        Raises:
            ClientNotFound: if no client with ``client_id`` exists.
        """
        # RETURNING lets us detect a missing client without a second query.
        sql = """
            UPDATE clients
            SET account_manager_id = $1
            WHERE id = $2
            RETURNING id
        """

        async with self._db.acquire() as con:
            row = await con.fetchrow(sql, account_manager_id, client_id)

        if row is None:
            raise ClientNotFound(client_id)

    async def retrieve_client(self, client_id: int) -> dict:
        """Fetch a single client together with its order count.

        Returns:
            Dict with ``id``, ``name``, ``orders_count`` and
            ``account_manager_id``.

        Raises:
            ClientNotFound: if no client with ``client_id`` exists.
        """
        # Selecting account_manager_id without grouping by it assumes
        # clients.id is the primary key (PostgreSQL functional dependency)
        # -- confirm against the schema.
        sql = """
            SELECT
                clients.id AS id,
                clients.name AS name,
                Count(orders.id) AS orders_count,
                clients.account_manager_id AS account_manager_id
            FROM clients LEFT OUTER JOIN orders ON clients.id = orders.client_id
            WHERE clients.id = $1
            GROUP BY clients.id, clients.name
        """

        async with self._db.acquire() as con:
            row = await con.fetchrow(sql, client_id)

        if row is None:
            raise ClientNotFound(client_id)

        return dict(row)

    async def list_clients(self) -> List[dict]:
        """List all clients, most recently created first.

        Returns:
            List of dicts with ``id``, ``name``, ``orders_count`` and
            ``account_manager_id``, ordered by ``created_at`` descending.
        """
        # As in retrieve_client, referencing account_manager_id/created_at
        # outside GROUP BY assumes clients.id is the primary key.
        sql = """
            SELECT
                clients.id AS id,
                clients.name AS name,
                Count(orders.id) AS orders_count,
                clients.account_manager_id AS account_manager_id
            FROM clients LEFT OUTER JOIN orders ON clients.id = orders.client_id
            GROUP BY clients.id, clients.name
            ORDER BY clients.created_at DESC
        """

        async with self._db.acquire() as con:
            rows = await con.fetch(sql)

        return [dict(row) for row in rows]
