from abc import ABC, abstractmethod
from asyncpg import Connection
from typing import Union

# Public API of this module: star-imports expose only the concrete
# synchronizer; the TableSynchronizer base class is not listed here.
__all__ = [
    "CashbackProgramsSynchronizer"
]


class TableSynchronizer(ABC):
    """Bulk-synchronize source rows into ``DB_TABLE`` through a staging table.

    Subclasses set ``DB_TABLE`` and ``COLUMNS`` and implement
    :meth:`process_row`. They may also override the ``clean_existing_data``,
    ``is_duplicate`` and ``update_row`` hooks, which are no-ops by default.
    Rows that are not duplicates are COPY-ed in chunks into
    ``<DB_TABLE>_tmp`` and then moved into ``DB_TABLE`` with a single
    ``INSERT ... SELECT``.
    """

    # Target table. It is interpolated directly into SQL, so it must be a
    # trusted class constant, never external input.
    DB_TABLE: str
    # Columns written, in the same order as the tuples process_row returns.
    # Also interpolated into SQL, so it must be trusted.
    COLUMNS: list[str]
    # Staged-row count at which the buffer is COPY-ed to the staging table.
    WRITE_CHUNK_SIZE: int = 100_000

    def __init__(self, con: Connection):
        self._con = con

    async def process_data(self, rows: list[Union[tuple, dict]]) -> None:
        """Synchronize ``rows`` into ``DB_TABLE``.

        The staging table is created with ``ON COMMIT DROP``, so it only
        lives until the end of the current transaction. If the connection is
        already inside a transaction, that transaction is used as-is.
        Otherwise one is opened here, so the staging table survives until the
        final ``INSERT ... SELECT``. Outside a transaction it would be dropped
        immediately after ``CREATE``.

        :param rows: source rows; each one is passed to :meth:`process_row`.
        """
        if self._con.is_in_transaction():
            await self._process_rows(rows)
        else:
            async with self._con.transaction():
                await self._process_rows(rows)

    async def _process_rows(self, rows: list[Union[tuple, dict]]) -> None:
        """Run the clean -> stage -> update-or-insert pipeline inside a transaction."""
        await self.clean_existing_data()
        await self.create_tmp_table()
        chunk: list[tuple] = []
        for row in rows:
            rows_to_add = self.process_row(row)
            if not rows_to_add:
                # Nothing to write for this source row.
                continue

            # The first produced row is the primary record. If it already
            # exists, it is updated in place and the whole group (including
            # any extra rows) is skipped.
            # NOTE(review): duplicates are only checked against DB_TABLE, not
            # against rows already staged in this call, so a key repeated
            # within one batch is inserted twice. Confirm the input is unique.
            primary = rows_to_add[0]
            if await self.is_duplicate(primary):
                await self.update_row(primary)
                continue

            chunk.extend(rows_to_add)
            # Use ``>=`` rather than ``==``: process_row may return several
            # rows, so the buffer can step over the exact threshold.
            if len(chunk) >= self.WRITE_CHUNK_SIZE:
                await self._write_rows(chunk)
                chunk = []

        if chunk:
            await self._write_rows(chunk)

        await self.write_to_table()

    async def _write_rows(self, records: list[tuple]) -> None:
        """COPY ``records`` (tuples in ``COLUMNS`` order) into the staging table."""
        await self._con.copy_records_to_table(
            self._tmp_db_table,
            records=records,
            columns=self.COLUMNS,
        )

    async def clean_existing_data(self) -> None:
        """Hook run before staging starts; does nothing by default."""
        pass

    async def update_row(self, row: tuple) -> None:
        """Hook that updates an existing record; does nothing by default.

        :param row: the first tuple returned by :meth:`process_row`.
        """
        pass

    async def is_duplicate(self, row: tuple) -> bool:
        """Hook deciding whether ``row`` already exists; ``False`` by default.

        :param row: the first tuple returned by :meth:`process_row`.
        :returns: ``True`` to route the row to :meth:`update_row` instead of
            inserting it.
        """
        return False

    async def create_tmp_table(self) -> None:
        """Create an empty staging table with ``DB_TABLE``'s ``COLUMNS``.

        ``WHERE False`` copies the column layout without any rows, and
        ``ON COMMIT DROP`` removes the table when the transaction ends.
        """
        await self._con.execute(
            f"""
            CREATE TEMPORARY TABLE {self._tmp_db_table} ON COMMIT DROP AS (
                SELECT {self._columns_list}
                FROM {self.DB_TABLE}
                WHERE False
            );
            """
        )

    async def write_to_table(self) -> None:
        """Move every staged row into ``DB_TABLE`` with one ``INSERT ... SELECT``."""
        await self._con.execute(
            f"""
            INSERT INTO {self.DB_TABLE} ({self._columns_list})
            SELECT {self._columns_list}
            FROM {self._tmp_db_table}
            """
        )

    @abstractmethod
    def process_row(self, row: Union[tuple, dict]) -> list[tuple]:
        """Convert one source row into tuples ordered like ``COLUMNS``.

        The first tuple is treated as the primary record for duplicate
        detection. Returning an empty list skips the source row.
        """
        pass

    @property
    def _tmp_db_table(self) -> str:
        """Name of the staging table: ``<DB_TABLE>_tmp``."""
        return f"{self.DB_TABLE}_tmp"

    @property
    def _columns_list(self) -> str:
        """``COLUMNS`` joined as a comma-separated SQL column list."""
        return ", ".join(self.COLUMNS)


class CashbackProgramsSynchronizer(TableSynchronizer):
    """Synchronize the ``cashback_programs`` table.

    Source rows are indexed positionally (``row[0]`` .. ``row[7]``) in the
    order of :attr:`COLUMNS`. A row whose ``id`` already exists is updated in
    place; the other rows are bulk-inserted.
    """

    DB_TABLE = "cashback_programs"
    COLUMNS = [
        "id",
        "category_id",
        "is_general",
        "is_enabled",
        "name_ru",
        "name_en",
        "description_ru",
        "description_en",
    ]

    @staticmethod
    def _normalize(row: Union[tuple, dict]) -> tuple:
        """Coerce one source row into a tuple ordered like ``COLUMNS``.

        Text columns keep ``None`` as ``None`` (SQL NULL). The previous
        behavior used ``str(None)``, which stored the literal string
        ``"None"``. The coercion is idempotent, so an already-normalized
        tuple passes through unchanged.
        """

        def text(value):
            return None if value is None else str(value)

        return (
            int(row[0]),
            int(row[1]),
            # NOTE(review): bool() treats any non-empty string (e.g. "false")
            # as True. Confirm the source delivers real booleans/ints here.
            bool(row[2]),
            bool(row[3]),
            text(row[4]),
            text(row[5]),
            text(row[6]),
            text(row[7]),
        )

    async def is_duplicate(self, row: tuple) -> bool:
        """Return ``True`` if a program with id ``row[0]`` already exists."""
        exists = await self._con.fetchval(
            f"""
            SELECT EXISTS (SELECT 1 FROM {self.DB_TABLE} WHERE id = $1)
            """,
            int(row[0]),
        )
        return bool(exists)

    async def update_row(self, row: tuple) -> None:
        """Overwrite every non-id column of the program whose id is ``row[0]``."""
        await self._con.execute(
            f"""
            UPDATE {self.DB_TABLE}
            SET category_id = $2, is_general = $3, is_enabled= $4, name_ru = $5, name_en = $6,
            description_ru = $7, description_en = $8
            WHERE id = $1
            """,
            *self._normalize(row),
        )

    def process_row(self, row: Union[tuple, dict]) -> list[tuple]:
        """Return the single normalized record for ``row`` (see ``_normalize``)."""
        return [self._normalize(row)]
