from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from operator import itemgetter
from typing import AsyncGenerator, Dict, List, Optional

from asyncpg import UniqueViolationError
from smb.common.pgswim import PoolType, SwimEngine

from ..enums import ScenarioName, SubscriptionStatus
from ..exceptions import SubscriptionAlreadyExists, UnknownSubscription
from . import sqls
from .scenarios import (
    DiscountForDisloyalScenario,
    DiscountForLostScenario,
    EngageProspectiveScenario,
    ThankTheLoyalScenario,
    find_scenario_by_name,
)


class BaseDataManager(ABC):
    """Abstract persistence interface for scenarios, subscriptions and messages.

    Every method is abstract; calling an unimplemented one raises
    ``NotImplementedError``.
    """

    @abstractmethod
    async def list_scenarios(self, *, biz_id: int) -> List[dict]:
        """Return the known scenarios, annotated with ``biz_id``'s subscriptions."""
        raise NotImplementedError()

    @abstractmethod
    async def create_subscription(
        self, *, biz_id: int, scenario_name: ScenarioName, coupon_id: Optional[int]
    ) -> dict:
        """Create a subscription of ``biz_id`` to ``scenario_name``; return it."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_subscription(self, *, subscription_id: int, biz_id: int) -> dict:
        """Return a subscription of ``biz_id`` together with its coupon history."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_subscription_current_state(
        self, *, subscription_id: int, biz_id: int
    ) -> dict:
        """Return the current state of a subscription owned by ``biz_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def update_subscription_status(
        self, *, subscription_id: int, biz_id: int, status: SubscriptionStatus
    ) -> None:
        """Set ``status`` on one subscription owned by ``biz_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def update_subscriptions_statuses(
        self, *, subscription_ids: List[int], status: SubscriptionStatus
    ) -> None:
        """Set ``status`` on every subscription in ``subscription_ids``."""
        raise NotImplementedError()

    @abstractmethod
    async def replace_subscription_coupon(
        self,
        *,
        subscription_id: int,
        biz_id: int,
        coupon_id: Optional[int],
        status: SubscriptionStatus,
    ) -> None:
        """Replace the coupon (and status) of a subscription owned by ``biz_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def iter_subscriptions_for_export(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield subscriptions for export in chunks of at most ``chunk_size``."""
        raise NotImplementedError()

    @abstractmethod
    async def copy_messages_from_generator(
        self, generator: AsyncGenerator[List[dict], None]
    ) -> None:
        """Bulk-load message records produced chunk by chunk by ``generator``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_unprocessed_email_messages(self) -> List[dict]:
        """Return email messages that have not been processed yet."""
        raise NotImplementedError()

    @abstractmethod
    async def mark_messages_processed(
        self,
        messages: Dict[int, Optional[str]],
        processed_at: datetime,
        processed_meta: Optional[dict] = None,
    ) -> None:
        """Mark messages processed.

        ``messages`` maps a message id to its error text (``None`` if none).
        """
        raise NotImplementedError()

    @abstractmethod
    async def import_certificate_mailing_stats(
        self, generator: AsyncGenerator[List[dict], None]
    ) -> None:
        """Bulk-load certificate mailing statistics produced by ``generator``."""
        raise NotImplementedError()

    @abstractmethod
    async def iter_subscriptions_versions_for_export(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield subscription versions for export in chunks of at most ``chunk_size``."""
        raise NotImplementedError()


class DataManager(BaseDataManager):
    """``BaseDataManager`` implementation on top of a ``SwimEngine`` pool.

    Bulk exports and the unprocessed-email listing read from the replica pool.
    Bulk loads and ``mark_messages_processed`` write through the master pool.
    Everything else calls ``acquire()`` without an explicit pool type.
    """

    __slots__ = ["_db"]

    _db: SwimEngine

    def __init__(self, db: SwimEngine):
        self._db = db

    async def list_scenarios(self, *, biz_id: int) -> List[dict]:
        """Return every scenario, attaching ``biz_id``'s subscription if any.

        Each scenario is serialized with ``to_dict()``. When the business is
        subscribed to it, the subscription row is added under the
        ``"subscription"`` key.
        """
        async with self._db.acquire() as con:
            rows = await con.fetch(sqls.list_business_subscriptions, biz_id)

        subscriptions = {row.get("scenario_name"): dict(row) for row in rows}

        scenarios = []
        for scenario_cls in (
            DiscountForLostScenario,
            EngageProspectiveScenario,
            ThankTheLoyalScenario,
            DiscountForDisloyalScenario,
        ):
            scenario = scenario_cls.to_dict()
            if scenario_cls.name in subscriptions:
                scenario["subscription"] = subscriptions[scenario_cls.name]

            scenarios.append(scenario)

        return scenarios

    async def create_subscription(
        self, *, biz_id: int, scenario_name: ScenarioName, coupon_id: Optional[int]
    ) -> dict:
        """Insert a subscription with ``SubscriptionStatus.ACTIVE`` and return it.

        Raises:
            SubscriptionAlreadyExists: the insert hit a unique-constraint violation.
        """
        async with self._db.acquire() as con:
            try:
                row = await con.fetchrow(
                    sqls.create_subscription,
                    scenario_name,
                    biz_id,
                    SubscriptionStatus.ACTIVE,
                    coupon_id,
                )
            except UniqueViolationError as exc:
                raise SubscriptionAlreadyExists() from exc

        return dict(row)

    async def retrieve_subscription(self, *, subscription_id: int, biz_id: int) -> dict:
        """Return ``{"subscription": ..., "coupons_history": [...]}``.

        Raises:
            UnknownSubscription: no subscription matched the id/business pair.
        """
        async with self._db.acquire() as con:
            sub_row = await con.fetchrow(
                sqls.retrieve_subscription, subscription_id, biz_id
            )
            if not sub_row:
                raise UnknownSubscription(
                    f"subscription_id={subscription_id}, biz_id={biz_id}"
                )

            # Reuse the connection instead of releasing it to the pool and
            # immediately acquiring another one for the follow-up query.
            version_rows = await con.fetch(
                sqls.retrieve_subscription_coupons_history, subscription_id
            )

        return dict(
            subscription=dict(sub_row),
            coupons_history=[dict(row) for row in version_rows],
        )

    async def retrieve_subscription_current_state(
        self, *, subscription_id: int, biz_id: int
    ) -> dict:
        """Return the current state row of a subscription owned by ``biz_id``.

        Raises:
            UnknownSubscription: no row matched the id/business pair.
        """
        async with self._db.acquire() as con:
            row = await con.fetchrow(
                sqls.fetch_subscription_state, subscription_id, biz_id
            )

        if not row:
            raise UnknownSubscription(
                f"subscription_id={subscription_id}, biz_id={biz_id}"
            )

        return dict(row)

    async def update_subscription_status(
        self, *, subscription_id: int, biz_id: int, status: SubscriptionStatus
    ) -> None:
        """Set ``status`` on one subscription owned by ``biz_id``."""
        async with self._db.acquire() as con:
            await con.execute(
                sqls.update_subscription_status, subscription_id, biz_id, status
            )

    async def update_subscriptions_statuses(
        self, *, subscription_ids: List[int], status: SubscriptionStatus
    ) -> None:
        """Set ``status`` on every subscription in ``subscription_ids``."""
        async with self._db.acquire() as con:
            await con.execute(
                sqls.update_subscriptions_statuses, subscription_ids, status
            )

    async def replace_subscription_coupon(
        self,
        *,
        subscription_id: int,
        biz_id: int,
        coupon_id: Optional[int],
        status: SubscriptionStatus,
    ) -> None:
        """Run ``sqls.replace_subscription_coupon``; its result is discarded."""
        async with self._db.acquire() as con:
            await con.fetchval(
                sqls.replace_subscription_coupon,
                subscription_id,
                biz_id,
                coupon_id,
                status,
            )

    async def iter_subscriptions_for_export(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield lists of at most ``chunk_size`` export dicts.

        A replica connection and its transaction stay held until the
        generator is exhausted or closed. The cursor runs inside that
        transaction, as asyncpg cursors require.
        """
        async with self._db.acquire(PoolType.replica) as con:
            async with con.transaction():
                cur = await con.cursor(sqls.list_subscriptions_for_export)

                while True:
                    rows = await cur.fetch(chunk_size)
                    if not rows:
                        break

                    yield [
                        dict(
                            subscription_id=row["subscription_id"],
                            scenario_code=row["scenario_name"].name,
                            segments=[
                                el.name
                                for el in find_scenario_by_name(
                                    row["scenario_name"]
                                ).segments
                            ],
                            biz_id=row["biz_id"],
                            coupon_id=row["coupon_id"],
                        )
                        for row in rows
                    ]

    async def copy_messages_from_generator(
        self, generator: AsyncGenerator[List[dict], None]
    ) -> None:
        """Stage message records in ``messages_tmp`` and then copy them over.

        Everything runs in one master-pool transaction. The incoming record
        key ``scenario_codes`` is written into the ``scenario_names`` column.
        """
        _getter = itemgetter(
            "time_to_send",
            "message_anchor",
            "message_type",
            "message_meta",
            "doorman_ids",
            "promoter_ids",
            "scenario_codes",
        )

        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                await con.execute(sqls.create_messages_tmp_table)

                async for records in generator:
                    values = list(map(_getter, records))
                    await con.copy_records_to_table(
                        "messages_tmp",
                        records=values,
                        columns=[
                            "time_to_send",
                            "message_anchor",
                            "message_type",
                            "message_meta",
                            "doorman_ids",
                            "promoter_ids",
                            "scenario_names",
                        ],
                    )

                await con.execute(sqls.copy_messages_from_tmp)

    async def list_unprocessed_email_messages(self) -> List[dict]:
        """Return unprocessed email messages, read from the replica pool."""
        async with self._db.acquire(PoolType.replica) as con:
            results = await con.fetch(sqls.list_unprocessed_email_messages)

        return list(map(dict, results))

    async def mark_messages_processed(
        self,
        messages: Dict[int, Optional[str]],
        processed_at: datetime,
        processed_meta: Optional[dict] = None,
    ) -> None:
        """Mark messages processed, issuing one UPDATE per distinct error value.

        ``messages`` maps message id to its error text (``None`` when there is
        no error). All updates run in a single master-pool transaction.
        """
        if not messages:
            # Nothing to mark: skip acquiring a master connection altogether.
            return

        message_ids_by_error = defaultdict(list)
        for message_id, error in messages.items():
            message_ids_by_error[error].append(message_id)

        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                for error, message_ids in message_ids_by_error.items():
                    await con.execute(
                        sqls.mark_messages_processed,
                        message_ids,
                        processed_at,
                        error,
                        processed_meta,
                    )

    async def import_certificate_mailing_stats(
        self, generator: AsyncGenerator[List[dict], None]
    ) -> None:
        """Stage mailing stats in ``certificate_mailing_stats_tmp`` and copy them over.

        Incoming keys map to columns positionally. Note that ``scenario`` goes
        to ``scenario_name`` and ``dt`` goes to ``sent_date``.
        """
        _getter = itemgetter(
            "biz_id", "coupon_id", "scenario", "clicked", "opened", "sent", "dt"
        )

        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                await con.execute(sqls.create_certificate_mailing_stats_tmp_table)

                async for records in generator:
                    values = list(map(_getter, records))
                    await con.copy_records_to_table(
                        "certificate_mailing_stats_tmp",
                        records=values,
                        columns=[
                            "biz_id",
                            "coupon_id",
                            "scenario_name",
                            "clicked",
                            "opened",
                            "sent",
                            "sent_date",
                        ],
                    )

                await con.execute(sqls.copy_certificate_mailing_stats_from_tmp)

    async def iter_subscriptions_versions_for_export(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield lists of at most ``chunk_size`` subscription-version dicts.

        ``created_at`` is encoded as integer microseconds since the epoch.
        """
        async with self._db.acquire(PoolType.replica) as con:
            async with con.transaction():
                cur = await con.cursor(sqls.list_subscriptions_versions_for_export)

                while True:
                    rows = await cur.fetch(chunk_size)
                    if not rows:
                        break

                    yield [
                        dict(
                            version_id=row["version_id"],
                            subscription_id=row["subscription_id"],
                            biz_id=row["biz_id"],
                            scenario_code=row["scenario_name"].name,
                            coupon_id=row["coupon_id"],
                            status=row["status"].name,
                            created_at=self._encode_dt_to_yt_timestamp(
                                row["created_at"]
                            ),
                        )
                        for row in rows
                    ]

    @staticmethod
    def _encode_dt_to_yt_timestamp(dt: datetime) -> int:
        """Return ``dt`` as integer microseconds since the Unix epoch.

        ``dt.timestamp()`` is a float, so the scaled product can land a hair
        below (or, for pre-epoch values, above) the exact integer.
        ``int()`` would truncate that to an off-by-one microsecond.
        ``round()`` recovers the exact value.
        """
        return round(dt.timestamp() * 1_000_000)
