from abc import ABC
from datetime import datetime, timezone
from decimal import Decimal

from asyncpg import Connection

# Names exported by ``from <module> import *``: the concrete synchronizers.
# The TableSynchronizer base class is not listed here.
__all__ = [
    "ClientsTableSynchronizer",
    "ClientsProgramsTableSynchronizer",
    "GainedClientBonusesTableSynchronizer",
    "SpentClientBonusesTableSynchronizer",
    "ClientBonusesToActivateTableSynchronizer",
]


class TableSynchronizer(ABC):
    """Base class that loads rows into ``DB_TABLE`` through a staging table.

    Subclasses set ``DB_TABLE`` and ``COLUMNS`` and may override the hooks
    ``clean_existing_data``, ``preload_exiting_data``, ``process_row`` and
    ``is_duplicate``. ``process_data`` runs the whole sequence.

    NOTE(review): the staging table is created ``ON COMMIT DROP``, so
    ``process_data`` presumably has to run inside an open transaction on
    ``con`` -- confirm against the callers.
    """

    # Name of the destination table.
    DB_TABLE: str
    # Destination columns, in the same order as the tuples from process_row().
    COLUMNS: list[str]
    # Number of converted rows buffered before one copy into the staging table.
    WRITE_CHUNK_SIZE: int = 100_000

    def __init__(self, con: Connection):
        self._con = con

    @property
    def _columns_list(self) -> str:
        """``COLUMNS`` rendered as a comma-separated SQL column list."""
        separator = ", "
        return separator.join(self.COLUMNS)

    @property
    def _tmp_db_table(self) -> str:
        """Name of the staging table: ``DB_TABLE`` with a ``_tmp`` suffix."""
        return "{}_tmp".format(self.DB_TABLE)

    def process_row(self, row: tuple) -> tuple:
        """Convert one incoming row; the default returns it unchanged."""
        return row

    async def is_duplicate(self, row: tuple) -> bool:
        """Tell whether a converted row must be skipped; never by default."""
        return False

    async def clean_existing_data(self) -> None:
        """Hook run first in ``process_data``; does nothing by default."""
        return None

    async def preload_exiting_data(self) -> None:
        """Hook run after the clean-up step; does nothing by default."""
        return None

    async def create_tmp_table(self) -> None:
        """Create an empty staging table with the destination's columns."""
        # ``WHERE False`` copies the column definitions without any rows.
        ddl = f"""
            CREATE TEMPORARY TABLE {self._tmp_db_table} ON COMMIT DROP AS (
                SELECT {self._columns_list}
                FROM {self.DB_TABLE}
                WHERE False
            );
            """
        await self._con.execute(ddl)

    async def _write_rows(self, records: list[tuple]) -> None:
        """Copy one chunk of converted rows into the staging table."""
        staging_table = self._tmp_db_table
        await self._con.copy_records_to_table(
            staging_table,
            records=records,
            columns=self.COLUMNS,
        )

    async def write_to_table(self):
        """Insert everything from the staging table into ``DB_TABLE``."""
        transfer_sql = f"""
            INSERT INTO {self.DB_TABLE} ({self._columns_list})
            SELECT {self._columns_list}
            FROM {self._tmp_db_table}
            """
        await self._con.execute(transfer_sql)

    async def process_data(self, rows: list[tuple]) -> None:
        """Run a full synchronization for ``rows``.

        Steps: clean-up hook, preload hook, staging table creation, chunked
        copy of the converted non-duplicate rows, final insert into
        ``DB_TABLE``.
        """
        await self.clean_existing_data()
        await self.preload_exiting_data()
        await self.create_tmp_table()

        buffered: list[tuple] = []
        for raw_row in rows:
            prepared = self.process_row(raw_row)
            if not await self.is_duplicate(prepared):
                buffered.append(prepared)
                if len(buffered) == self.WRITE_CHUNK_SIZE:
                    await self._write_rows(buffered)
                    # Reuse the same list object for the next chunk.
                    buffered.clear()

        # Flush the last, partially filled chunk.
        if len(buffered) > 0:
            await self._write_rows(buffered)

        await self.write_to_table()


class ClientsTableSynchronizer(TableSynchronizer):
    """Synchronizer for the ``clients`` table that skips already stored ids.

    Only ids that were in the table when ``preload_exiting_data`` ran count
    as duplicates; an id repeated inside one incoming batch is not filtered
    by this class.
    """

    DB_TABLE = "clients"
    COLUMNS = ["id", "login", "agency_id", "is_active", "create_date"]

    def __init__(self, con: Connection):
        super().__init__(con)
        # Ids present in the table; filled by preload_exiting_data().
        self._existing_ids: set[int] = set()

    def process_row(self, row: tuple) -> tuple:
        """Convert a raw row to ``(id, login, agency_id, is_active, create_date)``.

        ``row[4]`` must be a ``'%Y-%m-%d %H:%M:%S'`` string. Its wall-clock
        value is kept and tagged as UTC (no shift is applied).

        Raises:
            ValueError: if an id or the timestamp cannot be parsed.
        """
        create_date = datetime.strptime(row[4], "%Y-%m-%d %H:%M:%S").replace(
            tzinfo=timezone.utc
        )
        # NOTE(review): bool() of any non-empty string is True, so a textual
        # flag such as "0" would be stored as active -- confirm that row[3]
        # arrives as a real bool/int.
        return int(row[0]), row[1], int(row[2]), bool(row[3]), create_date

    async def is_duplicate(self, row: tuple) -> bool:
        """Return True when the row's id was already stored at preload time."""
        return row[0] in self._existing_ids

    async def preload_exiting_data(self) -> None:
        """Load all ids currently stored in ``DB_TABLE`` into memory."""
        # coalesce() turns the NULL that array_agg yields for an empty table
        # into an empty array, so set() always receives an iterable.
        # The table name comes from DB_TABLE, as in the sibling synchronizers.
        self._existing_ids = set(
            await self._con.fetchval(
                f"""
                SELECT coalesce(array_agg(id), ARRAY[]::bigint[])
                FROM {self.DB_TABLE}
                """
            )
        )


class ClientsProgramsTableSynchronizer(TableSynchronizer):
    """Synchronizer for ``clients_programs`` that empties the table first."""

    DB_TABLE = "clients_programs"
    COLUMNS = ["client_id", "program_id"]

    async def clean_existing_data(self) -> None:
        """Delete every row of the target table."""
        wipe_sql = f"""
            DELETE FROM {self.DB_TABLE}
            """
        await self._con.execute(wipe_sql)

    def process_row(self, row: tuple) -> tuple:
        """Return ``(client_id, program_id)`` with the client id cast to int."""
        client_id = int(row[0])
        program_id = row[1]
        return client_id, program_id


class GainedClientBonusesTableSynchronizer(TableSynchronizer):
    """Synchronizer for ``gained_client_bonuses`` rows of one ``gained_at``.

    Before loading, rows with the same ``gained_at`` are deleted. When
    ``program_ids`` is given and non-empty, the deletion is limited to those
    programs; ``None`` or an empty list removes every row for ``gained_at``.
    """

    DB_TABLE = "gained_client_bonuses"
    COLUMNS = ["client_id", "program_id", "amount", "currency", "gained_at"]

    def __init__(self, con: Connection, gained_at: datetime, program_ids: list[int] = None):
        """Store the connection, the ``gained_at`` stamp and the program filter.

        Raises:
            ValueError: if an element of ``program_ids`` is not an integer
                value.
        """
        super().__init__(con)
        self._gained_at = gained_at
        # int() makes sure only integers are ever sent as ids; the list is
        # passed as a bound parameter, never spliced into the SQL text.
        self._program_ids = [int(i) for i in program_ids] if program_ids else None

    async def clean_existing_data(self) -> None:
        """Delete the rows that this run is about to replace."""
        query = f"DELETE FROM {self.DB_TABLE} WHERE gained_at = $1"
        args = [self._gained_at]
        if self._program_ids:
            # NOTE(review): the bigint[] cast assumes program_id is an
            # integer column, as the int-typed ids suggest -- confirm.
            query += " AND program_id = ANY($2::bigint[])"
            args.append(self._program_ids)
        await self._con.execute(query, *args)

    def process_row(self, row: tuple) -> tuple:
        """Return ``(client_id, program_id, amount, currency, gained_at)``.

        The client id is cast to int and the amount to ``Decimal``;
        ``gained_at`` is the value given to the constructor, not taken from
        the row.
        """
        return int(row[0]), row[1], Decimal(row[2]), row[3], self._gained_at


class SpentClientBonusesTableSynchronizer(TableSynchronizer):
    """Synchronizer for ``spent_client_bonuses`` rows of one ``spent_at``.

    Rows already stored with the same ``spent_at`` are deleted before the
    new ones are loaded.
    """

    DB_TABLE = "spent_client_bonuses"
    COLUMNS = ["client_id", "amount", "currency", "spent_at"]

    def __init__(self, con: Connection, spent_at: datetime):
        super().__init__(con)
        # Stamp written to every loaded row and used to select rows to delete.
        self._spent_at = spent_at

    def process_row(self, row: tuple) -> tuple:
        """Return ``(client_id, amount, currency, spent_at)`` for one raw row."""
        client_id, amount, currency = int(row[0]), Decimal(row[1]), row[2]
        return client_id, amount, currency, self._spent_at

    async def clean_existing_data(self):
        """Delete the rows whose ``spent_at`` equals this run's stamp."""
        purge_sql = f"""
                DELETE FROM {self.DB_TABLE}
                WHERE spent_at = $1
                """
        await self._con.execute(purge_sql, self._spent_at)


class ClientBonusesToActivateTableSynchronizer(TableSynchronizer):
    """Synchronizer for ``client_bonuses_to_activate`` that empties the table first."""

    DB_TABLE = "client_bonuses_to_activate"
    COLUMNS = ["client_id", "amount"]

    def process_row(self, row: tuple) -> tuple:
        """Return ``(client_id, amount)`` as ``(int, Decimal)``."""
        client_id = int(row[0])
        amount = Decimal(row[1])
        return client_id, amount

    async def clean_existing_data(self) -> None:
        """Delete every row of the target table."""
        wipe_sql = f"""
            DELETE FROM {self.DB_TABLE}
            """
        await self._con.execute(wipe_sql)
