from abc import ABC, abstractmethod
from typing import List

import asyncpg

from maps_adv.manul.lib.db.engine import DB
from maps_adv.manul.lib.db.enums import CurrencyType, RateType

from .exceptions import ClientNotFound, OrderNotFound


class BaseOrdersDataManager(ABC):
    """Abstract interface of the data layer that works with orders.

    All four operations are abstract coroutines, so a subclass must
    override every one of them before it can be instantiated.
    """

    @abstractmethod
    async def create_order(
        self,
        title: str,
        client_id: int,
        product_id: int,
        currency: CurrencyType,
        comment: str,
        rate: RateType,
    ) -> dict:
        """Create an order from the given fields and return it as a dict."""
        raise NotImplementedError

    @abstractmethod
    async def retrieve_order(self, order_id: int) -> dict:
        """Return the order identified by ``order_id`` as a dict."""
        raise NotImplementedError

    @abstractmethod
    async def list_orders(self, ids: List[int]) -> List[dict]:
        """Return orders as a list of dicts, selected by ``ids``."""
        raise NotImplementedError

    @abstractmethod
    async def retrieve_order_ids_for_account(
        self, account_manager_id: int
    ) -> List[int]:
        """Return the ids of orders tied to ``account_manager_id``."""
        raise NotImplementedError


class OrdersDataManager(BaseOrdersDataManager):
    """Orders data manager that runs its queries through a ``DB`` engine.

    Every method acquires a connection from the engine for a single query
    and converts the resulting rows into plain ``dict`` / ``list`` values.
    """

    # NOTE(review): BaseOrdersDataManager declares no ``__slots__``, so
    # instances still get a ``__dict__``; this only reserves the ``_db`` slot.
    __slots__ = ("_db",)

    _db: DB

    def __init__(self, db: DB):
        """Remember the database engine used to acquire connections."""
        self._db = db

    async def create_order(
        self,
        title: str,
        client_id: int,
        product_id: int,
        currency: CurrencyType,
        comment: str,
        rate: RateType,
    ) -> dict:
        """Insert a new row into ``orders`` and return it.

        Returns:
            A dict with the keys ``id``, ``title``, ``client_id``,
            ``product_id``, ``currency``, ``comment``, ``rate`` and
            ``created_at``, as reported by the ``RETURNING`` clause.

        Raises:
            ClientNotFound: the insert violated a foreign-key constraint.
                The original ``asyncpg.ForeignKeyViolationError`` is kept
                as ``__cause__``.
        """
        sql = """
            INSERT INTO orders (title, client_id, product_id, currency, comment, rate)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id, title, client_id, product_id, currency,
                comment, rate, created_at
        """

        async with self._db.acquire() as con:
            try:
                row = await con.fetchrow(
                    sql, title, client_id, product_id, currency, comment, rate
                )
            except asyncpg.ForeignKeyViolationError as exc:
                # NOTE(review): every foreign-key violation is reported as a
                # missing client; confirm no other FK on ``orders`` (e.g. on
                # product_id) can end up here.
                raise ClientNotFound() from exc

        # Convert after the connection is released, like the other methods.
        return dict(row)

    async def retrieve_order(self, order_id: int) -> dict:
        """Fetch a single order by its id.

        Returns:
            A dict with the keys ``id``, ``title``, ``client_id``,
            ``product_id``, ``currency``, ``comment``, ``rate`` and
            ``created_at``.

        Raises:
            OrderNotFound: no row in ``orders`` has ``id`` equal to
                ``order_id``.
        """
        sql = """
            SELECT id, title, client_id, product_id, currency, comment, rate, created_at
            FROM orders
            WHERE id=$1
        """

        async with self._db.acquire() as con:
            row = await con.fetchrow(sql, order_id)

        if not row:
            raise OrderNotFound()

        return dict(row)

    async def list_orders(self, ids: List[int]) -> List[dict]:
        """List orders, newest first (``ORDER BY created_at DESC``).

        Args:
            ids: order ids to select. When it is empty (falsy), the filter
                is skipped and ALL orders are returned.

        Returns:
            A list of dicts with the keys ``id``, ``title``, ``client_id``,
            ``product_id``, ``currency``, ``comment``, ``rate`` and
            ``created_at``.
        """
        sql_by_ids = """
            SELECT id, title, client_id, product_id, currency, comment, rate, created_at
            FROM orders
            WHERE id = ANY($1::int[])
            ORDER BY created_at DESC
        """
        sql_all = """
            SELECT id, title, client_id, product_id, currency, comment, rate, created_at
            FROM orders
            ORDER BY created_at DESC
        """

        async with self._db.acquire() as con:
            if ids:
                rows = await con.fetch(sql_by_ids, ids)
            else:
                # An empty id list means "no filter", not "no results".
                rows = await con.fetch(sql_all)

        return [dict(row) for row in rows]

    async def retrieve_order_ids_for_account(
        self, account_manager_id: int
    ) -> List[int]:
        """Return ids of orders whose client has the given account manager.

        Orders are joined to ``clients`` on ``orders.client_id`` and filtered
        by ``clients.account_manager_id``. No ordering is requested from the
        database. The list is empty when nothing matches.
        """
        sql = """
            SELECT orders.id as order_id
            FROM orders
            INNER JOIN clients ON orders.client_id = clients.id
            WHERE clients.account_manager_id=$1
        """

        async with self._db.acquire() as con:
            rows = await con.fetch(sql, account_manager_id)

        return [row["order_id"] for row in rows]
