from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from smb.common.pgswim import PoolType, SwimEngine

from . import sqls
from .enums import SegmentType

# Names exported by ``from <module> import *``.
__all__ = [
    "BaseDataManager",
    "DataManager",
]


class BaseDataManager(ABC):
    """Abstract interface for the business/segment storage layer.

    Concrete subclasses must implement every coroutine below; calling the
    base implementations directly raises ``NotImplementedError``.
    """

    __slots__ = ()

    @abstractmethod
    async def add_business(
        self,
        biz_id: int,
        permalink: int,
        counter_id: int,
    ) -> None:
        """Register a business identified by ``biz_id``.

        Args:
            biz_id: business identifier.
            permalink: business permalink.
            counter_id: counter identifier associated with the business.
        """
        raise NotImplementedError()

    @abstractmethod
    async def list_business_segments_data(self, biz_id: int) -> Optional[dict]:
        """Return segment data for ``biz_id`` as a dict, or ``None`` if absent."""
        raise NotImplementedError()

    @abstractmethod
    async def add_business_segment(
        self,
        biz_id: int,
        name: str,
        cdp_id: int,
        cdp_size: int,
        type_: SegmentType,
    ) -> None:
        """Attach a segment to the business ``biz_id``.

        Args:
            biz_id: owning business identifier.
            name: segment name.
            cdp_id: segment identifier on the CDP side.
            cdp_size: segment size reported by the CDP.
            type_: kind of segment.
        """
        raise NotImplementedError()

    @abstractmethod
    async def list_biz_ids(self) -> List[int]:
        """Return the list of business identifiers."""
        raise NotImplementedError()

    @abstractmethod
    async def update_segments_sizes(self, sizes: Dict[int, int]) -> None:
        """Update segment sizes.

        Args:
            sizes: mapping of ``cdp_id`` to its new ``cdp_size``.
        """
        raise NotImplementedError()


class DataManager(BaseDataManager):
    """``BaseDataManager`` implementation backed by a ``SwimEngine``.

    Mutating operations acquire a connection from the master pool; read-only
    operations acquire one from the replica pool.
    """

    __slots__ = ["_db"]

    _db: SwimEngine

    def __init__(self, db: SwimEngine):
        self._db = db

    async def add_business(self, biz_id: int, permalink: int, counter_id: int) -> None:
        """Execute ``sqls.add_business`` with the given values on the master pool."""
        async with self._db.acquire(PoolType.master) as con:
            await con.execute(
                sqls.add_business,
                biz_id,
                permalink,
                counter_id,
            )

    async def list_business_segments_data(self, biz_id: int) -> Optional[dict]:
        """Fetch the segments row for ``biz_id`` from a replica.

        Returns:
            The fetched row converted to a plain ``dict``, or ``None`` when the
            query returned no row.
        """
        async with self._db.acquire(PoolType.replica) as con:
            row = await con.fetchrow(sqls.fetch_business_segments_data, biz_id)

        # fetchrow signals "no row" with None; test for that sentinel explicitly
        # rather than relying on the record's truthiness.
        return dict(row) if row is not None else None

    async def add_business_segment(
        self, biz_id: int, name: str, cdp_id: int, cdp_size: int, type_: SegmentType
    ) -> None:
        """Execute ``sqls.add_business_segment`` on the master pool.

        ``type_`` is passed to the query as its ``.value``.
        """
        async with self._db.acquire(PoolType.master) as con:
            await con.execute(
                sqls.add_business_segment, biz_id, name, cdp_id, cdp_size, type_.value
            )

    async def list_biz_ids(self) -> List[int]:
        """Return the business ids produced by ``sqls.list_biz_ids`` (replica read).

        Returns:
            The list of ids; an empty list when the query yields no value.
        """
        async with self._db.acquire(PoolType.replica) as con:
            biz_ids = await con.fetchval(sqls.list_biz_ids)

        # NOTE(review): fetchval returns a single value, so sqls.list_biz_ids
        # presumably aggregates ids into one array (e.g. array_agg) — confirm.
        # Such an aggregate over zero rows (or a zero-row result) comes back as
        # None, which would violate the List[int] contract; normalise it to [].
        return biz_ids if biz_ids is not None else []

    async def update_segments_sizes(self, sizes: Dict[int, int]) -> None:
        """Bulk-update segment sizes inside a single master transaction.

        Args:
            sizes: mapping of ``cdp_id`` to its new ``cdp_size``.
        """
        if not sizes:
            # Nothing to apply: skip acquiring a master connection and the
            # temp-table round trips entirely.
            return

        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                # Stage the (cdp_id, cdp_size) pairs in the "cdp_sizes" table and
                # then apply them with update_segments_sizes_from_temp_table.
                # NOTE(review): assumes create_temp_sizes_table creates a
                # temporary "cdp_sizes" table scoped to this transaction (e.g.
                # ON COMMIT DROP) so repeated calls on a pooled connection do
                # not collide — confirm in sqls.
                await con.execute(sqls.create_temp_sizes_table)
                await con.copy_records_to_table(
                    "cdp_sizes",
                    records=sizes.items(),
                    columns=["cdp_id", "cdp_size"],
                )
                await con.execute(sqls.update_segments_sizes_from_temp_table)
