import asyncio
import logging
from abc import abstractmethod
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from decimal import Decimal
from operator import itemgetter
from typing import Awaitable, Dict, List, Optional
from collections import defaultdict

import janus
import yt.wrapper
from asyncpg.connection import Connection

from maps_adv.billing_proxy.lib.core.balance_client import BalanceClient
from maps_adv.billing_proxy.lib.db.enums import (
    CampaignType,
    CurrencyType,
    DbEnumConverter,
    OrderOperationType,
    PlatformType,
)
from maps_adv.common.geoproduct import GeoproductClient

from . import sqls
from .base import BaseDataManager
from .exceptions import OrdersDoNotExist


class AbstractOrdersDataManager(BaseDataManager):
    """Interface for storing, querying and charging billing orders.

    See ``OrdersDataManager`` in this module for the concrete implementation.
    """

    @abstractmethod
    async def list_inexistent_order_ids(self, order_ids: List[int]) -> List:
        """Return the ids from ``order_ids`` that have no stored order."""
        raise NotImplementedError()

    @abstractmethod
    async def create_order(
        self,
        service_id: int,
        title: str,
        text: str,
        comment: str,
        client_id: int,
        agency_id: Optional[int],
        contract_id: Optional[int],
        product_id: int,
    ) -> dict:
        """Create an order for ``service_id``.

        Returns the stored order data merged with the product's data.
        """
        raise NotImplementedError()

    @abstractmethod
    async def find_order(self, order_id: int) -> Optional[dict]:
        """Return the order with internal id ``order_id``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    async def find_order_by_external_id(self, external_id: int) -> Optional[dict]:
        """Return the order with external id ``external_id``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_order_id_by_external_id(
        self, external_id: int
    ) -> Optional[int]:
        """Return the internal id of the order with ``external_id``, or ``None``.

        NOTE: ``OrdersDataManager`` names this parameter ``external_order_id``;
        pass the value positionally to stay implementation-agnostic.
        """
        raise NotImplementedError()

    @abstractmethod
    async def update_order(
        self, order_id: int, title: str, text: str, comment: str
    ) -> None:
        """Overwrite the title, text and comment of order ``order_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def find_orders(self, order_ids: List[int]) -> List[dict]:
        """Return data for the stored orders among ``order_ids``."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_order_ids_for_account(
        self, account_manager_id: int
    ) -> List[int]:
        """Return ids of orders linked to account manager ``account_manager_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_agency_orders(
        self, agency_id: Optional[int], client_id: Optional[int] = None
    ) -> List[dict]:
        """Return orders of ``agency_id``.

        ``client_id`` is an additional filter (presumably narrowing the result
        to a single client when given).
        """
        raise NotImplementedError()

    @abstractmethod
    async def list_client_orders(self, client_id: int) -> List[dict]:
        """Return orders of ``client_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_orders_stats(self, order_ids: List[int]) -> Dict[int, dict]:
        """Return ``{order_id: {"balance": ...}}`` for the given orders."""
        raise NotImplementedError()

    @abstractmethod
    async def list_positive_balance_orders(
        self, order_ids: Optional[List[int]] = None
    ) -> List[int]:
        """Return sorted ids of orders with a positive balance.

        ``order_ids`` restricts the candidates (presumably ``None`` means all
        orders — confirm against the query).
        """
        raise NotImplementedError()

    @abstractmethod
    async def lock_and_return_orders_balance(self, order_ids: List[int]):
        """Lock ``order_ids`` and expose their balances.

        ``OrdersDataManager`` implements this as an async context manager that
        yields ``({order_id: balance}, connection)`` inside an open transaction,
        so callers use it with ``async with``.
        """
        raise NotImplementedError()

    @abstractmethod
    async def list_orders_debits_for_billed_due_to(
        self, billed_due_to: datetime, con: Optional[Connection] = None
    ) -> Dict[int, Decimal]:
        """Return ``{order_id: amount}`` of debits for ``billed_due_to``."""
        raise NotImplementedError()

    @abstractmethod
    async def charge_orders(
        self, charge_info: Dict[int, Decimal], bill_for_dt: datetime, con: Connection
    ) -> Dict[int, bool]:
        """Debit ``charge_info`` amounts on ``con``; return per-order success flags."""
        raise NotImplementedError()

    @abstractmethod
    async def update_orders_limits(self, orders_updates: Dict[int, Dict]) -> None:
        """Apply ``{order_id: {"new_limit": ..., "tid": ...}}`` limit updates."""
        raise NotImplementedError()

    @abstractmethod
    async def build_reconciliation_report(
        self,
        due_to: datetime,
        from_datetime: Optional[datetime] = datetime.min,
        service_id: Optional[int] = 110,
    ) -> Dict[int, dict]:
        """Return per-order completion/consumption totals for ``service_id``.

        Values are ``{"completion_qty": ..., "consumption_qty": ...}``.
        """
        raise NotImplementedError()

    @abstractmethod
    async def load_geoprod_reconciliation_report(
        self,
        from_datetime: datetime,
        to_datetime: datetime,
        geoprod_service_id: Optional[int] = 37,
    ) -> None:
        """Build the geoproduct reconciliation report for the period and export it."""
        raise NotImplementedError()

    @abstractmethod
    async def list_orders_debits(
        self,
        order_ids: List[int],
        billed_after: datetime,
    ) -> Dict[int, List[dict]]:
        """Return ``{order_id: [{"billed_at": ..., "amount": ...}, ...]}``."""
        raise NotImplementedError()


class OrdersDataManager(AbstractOrdersDataManager):
    """asyncpg-backed implementation of ``AbstractOrdersDataManager``.

    Order creation is delegated to the Balance client (service 110) or the
    Geoproduct client (service 37). Charges are pushed through the Balance
    client. The geoproduct reconciliation report is exported to a YT table.
    """

    # service_id -> name of the method creating an order for that service.
    _service_create_order_methods = {
        37: "_create_order_geoproduct",
        110: "_create_order_balance",
    }
    # Timeout in seconds for a single put/get on the YT export queue.
    QUEUE_TIMEOUT = 60
    # Maximum number of report rows buffered between producer and YT writer.
    QUEUE_SIZE = 10000
    # Positional parameter order shared by ``sqls.upsert_order_consumed`` and
    # ``sqls.upsert_order_limit_and_tid``.
    _ORDER_ROW_FIELDS = (
        "id",
        "external_id",
        "service_id",
        "created_at",
        "tid",
        "title",
        "act_text",
        "text",
        "comment",
        "client_id",
        "agency_id",
        "contract_id",
        "product_id",
        "limit",
        "consumed",
    )

    def __init__(
        self,
        *args,
        balance_client: BalanceClient,
        geoproduct_client: GeoproductClient,
        geoproduct_operator_id: int,
        skip_balance_api_call_on_orders_charge: bool,
        use_recalculate_statistic_mode: bool,
        yt_cluster: str,
        yt_token: str,
        reconciliation_report_dir: str,
        **kwargs,
    ):
        """Store service clients and settings.

        Remaining positional/keyword arguments go to ``BaseDataManager``.
        """
        super().__init__(*args, **kwargs)
        self._balance_client = balance_client
        self._geoproduct_client = geoproduct_client
        # NOTE: despite its name, this attribute holds the operator *id*.
        self._geoproduct_operator_client = geoproduct_operator_id
        self._skip_balance_api_call_on_orders_charge = (
            skip_balance_api_call_on_orders_charge
        )
        self._use_recalculate_statistic_mode = use_recalculate_statistic_mode
        self.yt_cluster = yt_cluster
        self.yt_token = yt_token
        self.reconciliation_report_dir = reconciliation_report_dir

        self._logger = logging.getLogger("billing_proxy.OrdersDataManager")

    async def list_inexistent_order_ids(self, order_ids: List[int]) -> List:
        """Return ids from ``order_ids`` that have no stored order.

        The result order is unspecified when some of the ids exist.
        """
        async with self.connection() as con:
            existing_ids = await con.fetchval(sqls.list_existing_order_ids, order_ids)

        # The query yields a falsy value (presumably NULL) when nothing matches.
        if existing_ids:
            return list(set(order_ids) - set(existing_ids))
        return order_ids

    async def create_order(
        self,
        service_id: int,
        title: str,
        text: str,
        comment: str,
        client_id: int,
        agency_id: Optional[int],
        contract_id: Optional[int],
        product_id: int,
    ) -> dict:
        """Create an order via the service-specific creator.

        Returns the stored order data merged with the product's data (the
        product's ``oracle_id`` is consumed and not returned).

        Raises:
            KeyError: if ``service_id`` has no registered creation method.
        """
        # Resolve the creator first so unknown services fail before any DB work.
        create_order_method = getattr(
            self, self._service_create_order_methods[service_id]
        )

        async with self.connection() as con:
            async with con.transaction():
                product_data = await self._retrieve_product_data(product_id, con=con)
                oracle_id = product_data.pop("oracle_id")

                order_data = await create_order_method(
                    service_id,
                    title,
                    text,
                    comment,
                    client_id,
                    agency_id,
                    contract_id,
                    product_id,
                    oracle_id,
                    con=con,
                )

        return {**order_data, **product_data}

    async def _create_order_balance(
        self,
        service_id: int,
        title: str,
        text: str,
        comment: str,
        client_id: int,
        agency_id: Optional[int],
        contract_id: Optional[int],
        product_id: int,
        oracle_id: int,
        con: Connection,
    ) -> dict:
        """Insert the order locally, then register it in Balance.

        Runs inside the transaction opened by ``create_order``, so a failing
        Balance call rolls the local insert back.
        """
        order_data = await con.fetchrow(
            sqls.insert_order,
            None,  # external_id
            service_id,
            title,
            "",  # presumably act_text (matches _ORDER_ROW_FIELDS order)
            text,
            comment,
            client_id,
            agency_id,
            contract_id,
            product_id,
        )

        await self._balance_client.create_order(
            order_id=order_data["id"],
            client_id=client_id,
            agency_id=agency_id,
            oracle_product_id=oracle_id,
            contract_id=contract_id,
            text=text,
        )

        return dict(order_data)

    async def _create_order_geoproduct(
        self,
        service_id: int,
        title: str,
        text: str,
        comment: str,
        client_id: int,
        agency_id: Optional[int],
        contract_id: Optional[int],
        product_id: int,
        oracle_id: int,
        con: Connection,
    ) -> dict:
        """Create the order in Geoproduct, then store it with that external id.

        NOTE(review): the remote order is created before the local insert; if
        the insert fails, the DB rollback cannot undo the remote call.
        """
        geoproduct_order_id = (
            await self._geoproduct_client.create_order_for_media_platform(
                operator_id=self._geoproduct_operator_client,
                client_id=client_id,
                product_id=oracle_id,
            )
        )

        order_data = await con.fetchrow(
            sqls.insert_order,
            geoproduct_order_id,
            service_id,
            title,
            "",  # presumably act_text (matches _ORDER_ROW_FIELDS order)
            text,
            comment,
            client_id,
            agency_id,
            contract_id,
            product_id,
        )

        return dict(order_data)

    async def _retrieve_product_data(
        self, product_id: int, con: Optional[Connection] = None
    ) -> dict:
        """Return product params for ``product_id`` with enum columns converted."""
        async with self.connection(con) as con:
            product_data = await con.fetchrow(sqls.retrieve_product_params, product_id)

        return self._clean_enums(dict(product_data))

    @staticmethod
    def _clean_enums(data: dict):
        """Convert DB enum columns of ``data`` to Python enums in place; return it."""
        data["campaign_type"] = DbEnumConverter.to_enum(
            CampaignType, data["campaign_type"]
        )
        data["currency"] = DbEnumConverter.to_enum(CurrencyType, data["currency"])
        data["platforms"] = DbEnumConverter.to_enum(PlatformType, data["platforms"])

        return data

    @classmethod
    def _clean_rows(cls, rows) -> List[dict]:
        """Convert DB rows to dicts with enum columns converted."""
        return [cls._clean_enums(dict(row)) for row in rows]

    @classmethod
    def _order_rows_for_upsert(cls, orders: List[dict]) -> List[tuple]:
        """Turn order dicts into positional parameter tuples for upsert queries."""
        to_row = itemgetter(*cls._ORDER_ROW_FIELDS)
        return [to_row(order) for order in orders]

    async def find_order(self, order_id: int) -> Optional[dict]:
        """Return the order with internal id ``order_id``, or ``None``."""
        async with self.connection() as con:
            order_data = await con.fetchrow(sqls.find_order_by_id, order_id)

        if order_data:
            order_data = self._clean_enums(dict(order_data))

        return order_data

    async def find_order_by_external_id(self, external_id: int) -> Optional[dict]:
        """Return the order with external id ``external_id``, or ``None``."""
        async with self.connection() as con:
            order_data = await con.fetchrow(sqls.find_order_by_external_id, external_id)

        if order_data:
            order_data = self._clean_enums(dict(order_data))

        return order_data

    async def retrieve_order_id_by_external_id(
        self, external_order_id: int
    ) -> Optional[int]:
        """Return the internal id of the order with ``external_order_id``, or ``None``."""
        async with self.connection() as con:
            return await con.fetchval(
                sqls.retrieve_order_id_by_external_id, external_order_id
            )

    async def update_order(
        self, order_id: int, title: str, text: str, comment: str
    ) -> None:
        """Overwrite the title, text and comment of order ``order_id``."""
        async with self.connection() as con:
            await con.execute(sqls.update_order, order_id, title, text, comment)

    async def find_orders(self, order_ids: List[int]) -> List[dict]:
        """Return data for the stored orders among ``order_ids``."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_orders_by_ids, order_ids)

        return self._clean_rows(results)

    async def retrieve_order_ids_for_account(
        self, account_manager_id: int
    ) -> List[int]:
        """Return ids of orders linked to account manager ``account_manager_id``."""
        async with self.connection() as con:
            result = await con.fetch(
                sqls.retrieve_order_ids_for_account, account_manager_id
            )

        return [row["order_id"] for row in result]

    async def list_agency_orders(
        self, agency_id: Optional[int], client_id: Optional[int] = None
    ) -> List[dict]:
        """Return orders of ``agency_id``, filtered by ``client_id`` in the query."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_agency_orders, agency_id, client_id)

        return self._clean_rows(results)

    async def list_client_orders(self, client_id: int) -> List[dict]:
        """Return orders of ``client_id``."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_client_orders, client_id)

        return self._clean_rows(results)

    async def list_orders_stats(self, order_ids: List[int]) -> Dict[int, dict]:
        """Return ``{order_id: {"balance": ...}}`` for the given orders."""
        async with self.connection() as con:
            orders_data = await con.fetch(sqls.list_orders_stats, order_ids)

        return {order["id"]: {"balance": order["balance"]} for order in orders_data}

    async def list_positive_balance_orders(
        self, order_ids: Optional[List[int]] = None
    ) -> List[int]:
        """Return sorted ids of orders with a positive balance (``[]`` if none)."""
        async with self.connection() as con:
            result = await con.fetchval(sqls.list_positive_balance_order_ids, order_ids)

        # An aggregate over zero rows may come back as NULL (the same case
        # list_inexistent_order_ids guards against); treat it as "no orders".
        return sorted(result or [])

    async def list_orders_debits_for_billed_due_to(
        self, billed_due_to: datetime, con: Optional[Connection] = None
    ) -> Dict[int, Decimal]:
        """Return ``{order_id: amount}`` of debits for ``billed_due_to``."""
        async with self.connection(con) as con:
            order_debits = await con.fetch(
                sqls.list_orders_debits_for_billed_due_to, billed_due_to
            )

        return {
            order_debit["order_id"]: order_debit["amount"]
            for order_debit in order_debits
        }

    async def charge_orders(
        self, charge_info: Dict[int, Decimal], bill_for_dt: datetime, con: Connection
    ) -> Dict[int, bool]:
        """Debit ``charge_info`` amounts from orders and log DEBIT operations.

        Args:
            charge_info: amount to debit per internal order id.
            bill_for_dt: timestamp stored in the order log (and sent to Balance
                unless recalculate-statistic mode substitutes "now").
            con: connection to run on; no transaction is opened here, so
                presumably the caller holds one.

        Returns:
            ``{order_id: success}``. Only successful orders are updated locally.
            Ids with no stored order are reported as ``False`` when Balance is
            not called.
        """
        balance_bill_for_dt = (
            datetime.now(tz=timezone.utc)
            if self._use_recalculate_statistic_mode
            else bill_for_dt
        )

        orders_data_records = await con.fetch(
            sqls.list_order_rows_by_ids, list(charge_info.keys())
        )
        orders_data = {order["id"]: dict(order) for order in orders_data_records}

        if orders_data and not self._skip_balance_api_call_on_orders_charge:
            balance_results = await self._update_orders_in_balance(
                orders_data, charge_info, balance_bill_for_dt
            )
        else:
            # Balance is not consulted: every stored order counts as charged.
            balance_results = {
                order_id: order_id in orders_data for order_id in charge_info
            }

        # Update local data for orders that were successfully updated in balance
        # and create order_logs entries
        orders_upsert_data, order_logs_insert_data = [], []
        for order_id, succeeded in balance_results.items():
            if not succeeded:
                continue

            new_order_data = orders_data[order_id].copy()
            new_order_data["consumed"] += charge_info[order_id]
            orders_upsert_data.append(new_order_data)

            order_logs_insert_data.append(
                (
                    order_id,
                    DbEnumConverter.from_enum(OrderOperationType.DEBIT),
                    charge_info[order_id],
                    new_order_data["consumed"],
                    new_order_data["limit"],
                    bill_for_dt,
                )
            )

        if orders_upsert_data:
            # We use insert to perform bulk update
            await con.executemany(
                sqls.upsert_order_consumed,
                self._order_rows_for_upsert(orders_upsert_data),
            )
            await con.executemany(sqls.insert_order_log, order_logs_insert_data)

        return balance_results

    async def _update_orders_in_balance(
        self,
        orders_data: Dict[int, dict],
        charge_info: Dict[int, Decimal],
        bill_for_dt: datetime,
    ) -> Dict[int, bool]:
        """Send new consumed totals to Balance; return results by internal id."""
        # external id = order id for our orders and it's geoprod id for theirs
        update_info = {
            order["external_id"]: {
                "consumed": order["consumed"] + charge_info[order_id],
                "service_id": order["service_id"],
            }
            for order_id, order in orders_data.items()
        }
        update_results = await self._balance_client.update_orders(
            update_info, bill_for_dt
        )

        # Copy so the client's response object is not mutated below.
        balance_results = dict(update_results.get(110) or {})
        geoprod_results = update_results.get(37)
        if geoprod_results is not None:
            # Geoproduct results are keyed by external id; map back to ours.
            external_to_internal = {
                order["external_id"]: order_id
                for order_id, order in orders_data.items()
            }
            balance_results.update(
                {
                    external_to_internal[external_id]: result
                    for external_id, result in geoprod_results.items()
                }
            )
        return balance_results

    @asynccontextmanager
    async def lock_and_return_orders_balance(self, order_ids: List[int]):
        """Yield ``({order_id: balance}, con)`` inside an open transaction.

        The orders are locked by the query (per its name) until the
        ``async with`` body exits and the transaction ends.
        """
        async with self.connection() as con:
            async with con.transaction():
                orders_balance = await con.fetch(
                    sqls.lock_and_return_orders_balance_by_ids, order_ids
                )

                yield {ob["id"]: ob["balance"] for ob in orders_balance}, con

    async def update_orders_limits(self, orders_updates: Dict[int, Dict]) -> None:
        """Apply limit/tid updates and log CREDIT operations.

        Args:
            orders_updates: ``{order_id: {"new_limit": ..., "tid": ...}}``.
                An update applies only when its ``tid`` is strictly greater
                than the stored one; stale or repeated updates are skipped.

        Raises:
            OrdersDoNotExist: if any requested order is not stored; nothing is
                written in that case.
        """
        async with self.connection() as con:
            async with con.transaction():
                orders_data_records = await con.fetch(
                    sqls.lock_and_return_orders_rows_by_ids, list(orders_updates)
                )

                missing_order_ids = set(orders_updates) - {
                    record["id"] for record in orders_data_records
                }
                if missing_order_ids:
                    raise OrdersDoNotExist(order_ids=sorted(missing_order_ids))

                orders_upsert_data, order_logs_insert_data = [], []
                for record in orders_data_records:
                    order_data = dict(record)
                    order_id = order_data["id"]
                    new_limit = orders_updates[order_id]["new_limit"]
                    new_tid = orders_updates[order_id]["tid"]

                    if new_tid <= order_data["tid"]:
                        continue

                    orders_upsert_data.append(
                        {**order_data, "limit": new_limit, "tid": new_tid}
                    )
                    order_logs_insert_data.append(
                        (
                            order_id,
                            DbEnumConverter.from_enum(OrderOperationType.CREDIT),
                            new_limit - order_data["limit"],
                            order_data["consumed"],
                            new_limit,
                            None,
                        )
                    )

                if orders_upsert_data:
                    # We use insert to perform bulk update
                    await con.executemany(
                        sqls.upsert_order_limit_and_tid,
                        self._order_rows_for_upsert(orders_upsert_data),
                    )
                    await con.executemany(sqls.insert_order_log, order_logs_insert_data)

    async def build_reconciliation_report(
        self,
        due_to: datetime,
        from_datetime: Optional[datetime] = datetime.min,
        service_id: Optional[int] = 110,
    ) -> Dict[int, dict]:
        """Return ``{order_id: {"completion_qty", "consumption_qty"}}`` for the period."""
        async with self.connection() as con:
            results = await con.fetch(
                sqls.list_orders_operations_summary,
                due_to,
                from_datetime,
                service_id,
            )

        return {
            order_row["order_id"]: {
                "completion_qty": order_row["completion"],
                "consumption_qty": order_row["consumption"],
            }
            for order_row in results
        }

    async def load_geoprod_reconciliation_report(
        self,
        from_datetime: datetime,
        to_datetime: datetime,
        geoprod_service_id: Optional[int] = 37,
    ) -> None:
        """Export the geoproduct reconciliation report to YT.

        The table is ``<reconciliation_report_dir>/<YYYY-MM-DD of from_datetime>``
        and holds one row per external order id. The queue is always closed,
        even when producing or writing fails.
        """
        async with self.connection() as con:
            results = await con.fetch(
                sqls.list_orders_operations_summary,
                to_datetime,
                from_datetime,
                geoprod_service_id,
            )

        report = {
            order_row["external_order_id"]: {
                "completion_qty": order_row["completion"],
                "consumption_qty": order_row["consumption"],
            }
            for order_row in results
        }
        self._logger.info(
            "Geoprod reconciliation report",
            extra={
                "fields": {
                    "from": str(from_datetime),
                    "to": str(to_datetime),
                    "report": str(report),
                }
            },
        )
        table_name = from_datetime.strftime("%Y-%m-%d")
        queue = janus.Queue(maxsize=self.QUEUE_SIZE)
        writer_fut = self._start_yt_writer(
            self.reconciliation_report_dir + "/" + table_name, queue
        )

        try:
            for order_id, values in report.items():
                await asyncio.wait_for(
                    queue.async_q.put(
                        {
                            "order_id": order_id,
                            "paid_delta": values["consumption_qty"],
                            "spent_delta": values["completion_qty"],
                        }
                    ),
                    timeout=self.QUEUE_TIMEOUT,
                )

            # None is the end-of-stream marker for the writer thread.
            await asyncio.wait_for(
                queue.async_q.put(None), timeout=self.QUEUE_TIMEOUT
            )
            await writer_fut
        finally:
            if not writer_fut.done():
                # Stop awaiting the writer on early exit; a thread that is
                # already running exits on its own get() timeout.
                writer_fut.cancel()
            queue.close()
            await queue.wait_closed()

    def _start_yt_writer(self, table: str, queue: janus.Queue) -> Awaitable:
        """Start a thread that streams queue items into YT ``table``.

        Items are read from ``queue.sync_q`` until ``None`` is received.
        Datetimes become timestamps and Decimals become strings.
        """

        def convert(item):
            if isinstance(item, list):
                return [convert(value) for value in item]
            elif isinstance(item, dict):
                return {key: convert(value) for key, value in item.items()}
            elif isinstance(item, datetime):
                return item.timestamp()
            elif isinstance(item, Decimal):
                return str(item)
            else:
                return item

        def dump_to_yt(queue):
            def records():
                while True:
                    # Raises queue.Empty if the producer stalls for QUEUE_TIMEOUT.
                    item = queue.sync_q.get(timeout=self.QUEUE_TIMEOUT)
                    if item is None:
                        queue.sync_q.task_done()
                        break
                    self._logger.info(
                        "Writing to YT",
                        extra={"fields": {"item": str(item)}},
                    )
                    yield convert(item)
                    queue.sync_q.task_done()

            yt.wrapper.YtClient(self.yt_cluster, token=self.yt_token).write_table(
                yt.wrapper.TablePath(table), records()
            )

        return asyncio.get_running_loop().run_in_executor(None, dump_to_yt, queue)

    async def list_orders_debits(
        self,
        order_ids: List[int],
        billed_after: datetime,
    ) -> Dict[int, List[dict]]:
        """Return ``{order_id: [{"billed_at", "amount"}, ...]}`` after ``billed_after``.

        Orders without debits are absent from the result.
        """
        async with self.connection() as con:
            debits = await con.fetch(sqls.list_orders_debits, order_ids, billed_after)

        orders_debits = defaultdict(list)
        for debit in debits:
            orders_debits[debit["order_id"]].append(
                {"billed_at": debit["billed_due_to"], "amount": debit["amount"]}
            )

        return dict(orders_debits)
