from collections import namedtuple
from contextlib import AsyncExitStack, asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import AsyncGenerator, Iterable, List, Optional

from asyncpg import Record
from asyncpg.pool import PoolConnectionProxy
from smb.common.pgswim import PoolType, SwimEngine

from maps_adv.geosmb.booking_yang.server.lib.enums import OrderStatus

from . import sqls

# Named (field name, value) pair; fields are ``db_field`` and ``db_value``.
DbParam = namedtuple("DbParam", ["db_field", "db_value"])

# Unique sentinel object: compared by identity, it cannot collide with any
# real value (including ``None``).
_notset = object()


class UnknownDbFields(Exception):
    """Raised when a requested update names fields outside the updatable set."""


class NaiveDateTime(Exception):
    """Raised when a datetime that must carry a timezone is naive."""


class OrdersDataManager:
    """Data access for booking orders on top of a ``SwimEngine`` connection pool.

    Mutations and row locking go to the master pool; most reads go to the
    replica pool. Query texts come from the sibling ``sqls`` module.
    """

    __slots__ = ("_db",)

    _db: SwimEngine
    # Columns ``update_orders`` may SET. Keyword names are interpolated into
    # the SQL text, so this whitelist is also what keeps those names safe.
    _updatable_fields = {
        "permalink",
        "reservation_datetime",
        "reservation_timezone",
        "person_count",
        "customer_name",
        "customer_phone",
        "customer_passport_uid",
        "comment",
        "call_agreement_accepted",
        "yang_suite_id",
        "client_id",
        "biz_id",
        "yang_task_created_at",
        "task_created_at",
        "task_result_got_at",
        "sms_sent_at",
        "sent_result_event_at",
        "exported_as_created_at",
        "exported_as_processed_at",
        "exported_as_notified_at",
        "booking_verdict",
        "booking_meta",
    }

    # Unix epoch and one-microsecond unit, used by ``_to_microseconds`` to turn
    # datetimes into exact integer timestamps.
    _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
    _MICROSECOND = timedelta(microseconds=1)

    def __init__(self, db: SwimEngine):
        self._db = db

    async def create_order(
        self,
        permalink: int,
        reservation_datetime: datetime,
        reservation_timezone: str,
        person_count: int,
        customer_name: str,
        customer_phone: str,
        customer_passport_uid: Optional[int],
        comment: str,
        call_agreement_accepted: bool,
        biz_id: int,
        time_to_call: Optional[datetime],
        created_at: datetime,
    ) -> int:
        """Insert an order via ``sqls.create_order`` and return its scalar result.

        The returned int is presumably the new order's id (TODO confirm
        against ``sqls.create_order``).

        Raises:
            NaiveDateTime: if ``reservation_datetime`` is not timezone-aware.
        """
        # Per the datetime docs, a datetime is aware only if it has a tzinfo
        # AND that tzinfo yields a UTC offset. Checking tzinfo alone would let
        # a tzinfo whose utcoffset() returns None slip through.
        if (
            reservation_datetime.tzinfo is None
            or reservation_datetime.utcoffset() is None
        ):
            raise NaiveDateTime(
                f"reservation_datetime must be aware "
                f"with timezone: {reservation_datetime}."
            )

        async with self._db.acquire() as con:
            return await con.fetchval(
                sqls.create_order,
                permalink,
                reservation_datetime,
                reservation_timezone,
                person_count,
                customer_name,
                customer_phone,
                customer_passport_uid,
                comment,
                call_agreement_accepted,
                biz_id,
                time_to_call,
                created_at,
            )

    async def retrieve_order_by_suite(self, yang_suite_id: str) -> Optional[dict]:
        """Return the order row matched by ``yang_suite_id`` as a dict, or ``None``."""
        async with self._db.acquire(PoolType.replica) as con:
            row = await con.fetchrow(sqls.retrieve_order_by_suite, yang_suite_id)
            return dict(row) if row is not None else None

    async def retrieve_earliest_unprocessed_order_time(self) -> Optional[datetime]:
        """Return the scalar result of the earliest-unprocessed-order query.

        ``fetchval`` yields ``None`` when the query returns no row or a NULL.
        """
        async with self._db.acquire(PoolType.replica) as con:
            return await con.fetchval(sqls.retrieve_earliest_unprocessed_order_time)

    async def list_pending_orders(self) -> List[dict]:
        """Return pending orders as dicts, passing the current UTC time to the query.

        NOTE(review): reads from the master pool rather than the replica,
        presumably to see the freshest state. Confirm before changing.
        """
        async with self._db.acquire(PoolType.master) as con:
            rows = await con.fetch(
                sqls.list_pending_orders, datetime.now(tz=timezone.utc)
            )

        return [dict(row) for row in rows]

    async def list_orders_without_client(self, order_min_age_sec: int) -> List[dict]:
        """Return rows of ``sqls.list_orders_without_client`` as dicts.

        The query receives the UTC cutoff ``now - order_min_age_sec`` seconds;
        presumably only orders older than that are returned (confirm in SQL).
        """
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_orders_without_client,
                datetime.now(tz=timezone.utc) - timedelta(seconds=order_min_age_sec),
            )

        return [dict(row) for row in rows]

    async def list_orders_for_sending_result_event(self) -> List[dict]:
        """Return rows of ``sqls.list_orders_for_sending_result_event`` as dicts."""
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_orders_for_sending_result_event,
            )

        return [dict(row) for row in rows]

    async def update_orders(
        self,
        *,
        con: Optional[PoolConnectionProxy] = None,
        order_ids: List[int],
        **kwargs,
    ):
        """Set the given columns to the given values for the orders in ``order_ids``.

        Args:
            con: Connection to run on (e.g. one holding a transaction). When
                omitted, a master-pool connection is acquired for this call.
            order_ids: Bound as ``$1`` of ``sqls.update_orders``.
            **kwargs: Column name -> new value. Every name must be in
                ``_updatable_fields``. A call without kwargs is a no-op.

        Raises:
            UnknownDbFields: if any keyword names a non-updatable column.
        """
        if not kwargs:
            return

        # Validation must happen before the names are spliced into SQL text.
        self._validate_db_fields(kwargs.keys())

        # $1 is reserved for order_ids; SET values follow as $2, $3, ...
        set_pairs = []
        query_params = []
        for param_pos, (field, value) in enumerate(kwargs.items(), start=2):
            set_pairs.append(f"{field}=${param_pos}")
            query_params.append(value)

        execute_params = (
            sqls.update_orders.format(set_pairs_str=", ".join(set_pairs)),
            order_ids,
            *query_params,
        )

        # The exit stack only owns a connection when we acquired it ourselves;
        # a caller-supplied ``con`` is left for the caller to release.
        async with AsyncExitStack() as exit_stack:
            if con is None:
                con = await exit_stack.enter_async_context(
                    self._db.acquire(PoolType.master)
                )

            await con.execute(*execute_params)

    async def retrieve_orders_for_sending_sms(self) -> List[dict]:
        """Return rows of ``sqls.retrieve_orders_for_sending_sms`` as dicts."""
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(sqls.retrieve_orders_for_sending_sms)

            return [dict(row) for row in rows]

    async def iter_created_orders(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield created orders in chunks; datetime fields become epoch microseconds."""
        async for orders_chunk in self._base_iter_orders(
            chunk_size=chunk_size,
            sql=sqls.iter_created_orders,
            dt_fields=["created_at", "processing_time"],
        ):
            yield orders_chunk

    async def iter_orders_processed_by_yang(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield rows of ``sqls.iter_orders_submitted_to_yang`` in chunks.

        The Yang/task timestamp fields are converted to epoch microseconds.
        """
        async for orders_chunk in self._base_iter_orders(
            chunk_size=chunk_size,
            sql=sqls.iter_orders_submitted_to_yang,
            dt_fields=["yang_task_created_at", "task_created_at", "task_result_got_at"],
        ):
            yield orders_chunk

    async def iter_orders_notified_to_users(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield notified orders in chunks; ``sms_sent_at`` becomes epoch microseconds."""
        async for orders_chunk in self._base_iter_orders(
            chunk_size=chunk_size,
            sql=sqls.iter_orders_notified_to_users,
            dt_fields=["sms_sent_at"],
        ):
            yield orders_chunk

    async def list_client_orders(
        self,
        client_id: int,
        biz_id: int,
        datetime_from: Optional[datetime] = None,
        datetime_to: Optional[datetime] = None,
    ) -> dict:
        """Return a client's orders for a business, optionally bounded in time.

        Bounds are inclusive on ``reservation_datetime``. Both bounds are always
        bound as ``$3``/``$4`` (possibly ``None``). NOTE(review): this assumes
        ``sqls.list_client_orders`` references them elsewhere (e.g. for the
        ``events_before``/``events_after`` counters); confirm in SQL.

        Returns:
            ``dict(events_before=..., events_after=..., orders=[...])``; see
            ``_normalize_client_orders``.
        """
        period_conditions = []

        if datetime_from is not None:
            period_conditions.append(" AND reservation_datetime >= $3")

        if datetime_to is not None:
            period_conditions.append(" AND reservation_datetime <= $4")

        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_client_orders.format(
                    period_conditions="".join(period_conditions)
                ),
                client_id,
                biz_id,
                datetime_from,
                datetime_to,
            )

            return self._normalize_client_orders(rows)

    async def fetch_pending_yang_tasks_count(self) -> int:
        """Return the scalar result of ``sqls.fetch_pending_yang_tasks_count``."""
        async with self._db.acquire(PoolType.replica) as con:
            return await con.fetchval(sqls.fetch_pending_yang_tasks_count)

    async def clear_orders_by_passport(self, passport_uid: int) -> List[int]:
        """Run ``sqls.clear_orders_by_passport`` on the master.

        Returns the ``order_id`` of every row the query returns.
        """
        async with self._db.acquire(PoolType.master) as con:
            rows = await con.fetch(sqls.clear_orders_by_passport, passport_uid)

            return [row["order_id"] for row in rows]

    async def check_orders_existence_by_passport(self, passport_uid: int) -> bool:
        """Return the scalar result of the existence check for ``passport_uid``."""
        async with self._db.acquire(PoolType.replica) as con:
            return await con.fetchval(
                sqls.check_orders_existence_by_passport, passport_uid
            )

    @asynccontextmanager
    async def acquire_con_with_tx(
        self,
    ) -> AsyncGenerator[PoolConnectionProxy, None]:
        """Yield a master-pool connection with an open transaction.

        The transaction commits when the ``async with`` body exits normally and
        rolls back if it raises.
        """
        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                yield con

    def _normalize_client_orders(self, rows: List[Record]) -> dict:
        """Split ``list_client_orders`` rows into event counters and order dicts.

        The counters are read from the first row, so ``rows`` must be non-empty
        (an empty list raises ``IndexError``). Rows whose ``order_id`` is NULL
        contribute no order; presumably they come from an outer join that
        always yields the counter row (TODO confirm against the SQL).
        """
        events_before = rows[0]["events_before"]
        events_after = rows[0]["events_after"]

        orders = [
            dict(
                order_id=row["order_id"],
                created_at=row["created_at"],
                reservation_datetime=row["reservation_datetime"],
                reservation_timezone=row["reservation_timezone"],
                person_count=row["person_count"],
                status=OrderStatus[row["status"]],
            )
            for row in rows
            if row["order_id"] is not None
        ]

        return dict(
            events_before=events_before, events_after=events_after, orders=orders
        )

    async def _base_iter_orders(self, chunk_size: int, sql: str, dt_fields: List[str]):
        """Stream rows of ``sql`` from the replica as lists of up to ``chunk_size`` dicts.

        Each field named in ``dt_fields`` is replaced with an integer count of
        microseconds since the Unix epoch (``None`` stays ``None``). The
        server-side cursor is opened inside a transaction because asyncpg
        cursors must run within one.
        """
        async with self._db.acquire(PoolType.replica) as con:
            async with con.transaction():
                cur = await con.cursor(sql)

                while True:
                    rows = await cur.fetch(chunk_size)
                    if not rows:
                        break

                    chunk = [dict(row) for row in rows]
                    for row in chunk:
                        for dt_field in dt_fields:
                            row[dt_field] = self._to_microseconds(row[dt_field])

                    yield chunk

    @classmethod
    def _to_microseconds(cls, value: Optional[datetime]) -> Optional[int]:
        """Return ``value`` as integer microseconds since the Unix epoch, or ``None``.

        Integer timedelta arithmetic is exact. The previous
        ``int(value.timestamp() * 1_000_000)`` went through a float and could
        truncate to one microsecond below the true value. Naive datetimes are
        interpreted as local time, as ``datetime.timestamp()`` does.
        """
        if value is None:
            return None
        return (value.astimezone(timezone.utc) - cls._EPOCH) // cls._MICROSECOND

    async def fetch_actual_orders(self, actual_on: datetime) -> List[dict]:
        """Return rows of ``sqls.list_actual_orders_for_personal_poi`` for ``actual_on``."""
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(sqls.list_actual_orders_for_personal_poi, actual_on)

        return [dict(row) for row in rows]

    @asynccontextmanager
    async def lock_order(
        self, order_id: int
    ) -> AsyncGenerator[Optional[PoolConnectionProxy], None]:
        """Lock an order on the master and yield the connection holding the lock.

        Yields ``None`` instead of a connection when
        ``sqls.select_order_for_update`` returns no row for ``order_id``.

        The query is presumably ``SELECT ... FOR UPDATE`` (judging by its name;
        confirm). PostgreSQL keeps such a row lock only until the end of the
        current transaction, so in autocommit mode it would be released before
        the caller received the connection. The select and the caller's body
        therefore share one transaction. It commits when the body exits
        normally and rolls back if the body raises. Nested ``con.transaction()``
        blocks inside the body become savepoints.
        """
        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                success = await con.fetchrow(sqls.select_order_for_update, order_id)
                yield con if success is not None else None

    def _validate_db_fields(self, db_fields: Iterable):
        """Raise ``UnknownDbFields`` if any name is not in ``_updatable_fields``."""
        updated_fields = set(db_fields)
        unknown_fields = updated_fields - self._updatable_fields
        if unknown_fields:
            raise UnknownDbFields(f"Unknown DB fields: {unknown_fields}")
