import re
from abc import abstractmethod
from datetime import datetime
from operator import itemgetter
from typing import Dict, List, Optional
from asyncpg.exceptions import ForeignKeyViolationError

from maps_adv.billing_proxy.lib.db.enums import (
    BillingType,
    CampaignType,
    CurrencyType,
    DbEnumConverter,
    PlatformType,
)

from . import sqls
from .base import BaseDataManager
from .exceptions import (
    ClientDoesNotExist,
    ProductDoesNotExist,
    ContractDoesNotExist,
    ProductClientMismatch,
    ConflictingProducts,
    ClientsAlreadyBoundToOverlappingProducts,
)


class AbstractProductsDataManager(BaseDataManager):
    """Interface of the data-access layer for billing products.

    Declares the operations a products data manager must provide. Signatures
    here are kept in line with ``ProductsDataManager``, the implementation in
    this module.
    """

    @abstractmethod
    async def find_product(self, product_id: int) -> Optional[Dict]:
        """Return the product with id ``product_id``, or None if it is missing."""
        raise NotImplementedError()

    @abstractmethod
    async def list_products(
        self, service_ids: Optional[List[int]] = None
    ) -> List[dict]:
        """List products of ``service_ids`` (implementation-defined default)."""
        raise NotImplementedError()

    @abstractmethod
    async def list_by_params(
        self,
        platforms: List[PlatformType],
        campaign_type: CampaignType,
        currency: CurrencyType,
        dt: datetime,
        client_id: Optional[int] = None,
        service_id: Optional[int] = None,
    ) -> List[Dict]:
        """List products matching the given parameters at moment ``dt``.

        Parameter order matches ``ProductsDataManager.list_by_params``: the
        optional ``client_id`` and ``service_id`` filters come last.
        """
        raise NotImplementedError()

    @abstractmethod
    async def find_product_active_version(
        self, product_id: int, dt: datetime
    ) -> Optional[dict]:
        """Return the version of the product active at ``dt``, or None."""
        raise NotImplementedError()

    @abstractmethod
    async def list_clients_bound_to_product(self, product_id):
        """List the client bindings of the product."""
        raise NotImplementedError()

    @abstractmethod
    async def bind_client_to_product(self, product_id, clients):
        """Bind ``clients`` (dicts with ``client_id`` and optional
        ``contract_id``) to the product."""
        raise NotImplementedError()

    @abstractmethod
    async def unbind_client_from_product(self, product_id, clients):
        """Remove bindings of ``clients`` (dicts with ``client_id`` and optional
        ``contract_id``) from the product."""
        raise NotImplementedError()

    @abstractmethod
    async def create_product(self, **kw_args):
        """Create a product together with its first version."""
        raise NotImplementedError()

    @abstractmethod
    async def update_product(self, **kw_args):
        """Update a product and replace its versions."""
        raise NotImplementedError()


class ProductsDataManager(AbstractProductsDataManager):
    """Database implementation of ``AbstractProductsDataManager``.

    Enum-valued columns are converted to/from the ``lib.db.enums`` types with
    ``DbEnumConverter``; rows are returned as plain dicts.
    """

    # Services listed by ``list_products`` when the caller passes none.
    # NOTE(review): the meaning of these service ids is not visible here.
    _DEFAULT_SERVICE_IDS = (37, 110)

    # FK constraint name -> (column named in the violation detail, exception
    # raised with that column's id as keyword argument).
    _FK_VIOLATION_HANDLERS = {
        "fk_product_client_restrictions_client_id_clients": (
            "client_id",
            ClientDoesNotExist,
        ),
        "fk_product_client_restrictions_product_id_products": (
            "product_id",
            ProductDoesNotExist,
        ),
        "fk_product_client_restrictions_contract_id_contracts": (
            "contract_id",
            ContractDoesNotExist,
        ),
    }

    @staticmethod
    def _process_data(data):
        """Convert a product row into a dict with enum-typed columns."""
        if not isinstance(data, dict):
            data = dict(data)

        data["billing_type"] = DbEnumConverter.to_enum(
            BillingType, data["billing_type"]
        )
        data["campaign_type"] = DbEnumConverter.to_enum(
            CampaignType, data["campaign_type"]
        )
        data["platforms"] = DbEnumConverter.to_enum(PlatformType, data["platforms"])
        data["currency"] = DbEnumConverter.to_enum(CurrencyType, data["currency"])

        return data

    @staticmethod
    def _process_version(data) -> dict:
        """Convert a product-version row into a dict with an enum billing type."""
        version = dict(data)
        version["billing_type"] = DbEnumConverter.to_enum(
            BillingType, version["billing_type"]
        )
        return version

    @staticmethod
    def _sql_int_or_null(value) -> str:
        """Render an optional integer as an SQL literal (``NULL`` for None).

        ``int()`` guarantees that only an integer literal ever reaches the
        SQL text built with string formatting.
        """
        return "NULL" if value is None else str(int(value))

    @classmethod
    def _raise_for_fk_violation(cls, exc) -> None:
        """Translate a known FK violation into the matching domain exception.

        Returns without raising when the constraint is not known or the
        offending id cannot be parsed from ``exc.detail``; the caller is then
        expected to re-raise the original error.
        """
        handler = cls._FK_VIOLATION_HANDLERS.get(exc.constraint_name)
        if handler is None:
            return
        column, exc_cls = handler
        match = re.search(rf"\({column}\)=\((\d+)\)", exc.detail or "")
        if match is None:
            return
        raise exc_cls(**{column: int(match.group(1))}) from exc

    async def find_product(self, product_id: int) -> Optional[Dict]:
        """Return the product with its versions, or None if it does not exist.

        In ``dedicated_client_ids`` we get the list of clients allowed to use
        this product, or None if the product has no restrictions. Versions are
        stored under ``versions``; the key is absent when there are none.
        """
        async with self.connection() as con:
            product_data = await con.fetchrow(sqls.find_product_by_id, product_id)
            versions_data = await con.fetch(
                sqls.list_all_product_versions_by_product, product_id
            )

        if not product_data:
            return None

        product = self._process_data(product_data)
        for version in versions_data:
            product.setdefault("versions", []).append(self._process_version(version))
        return product

    async def list_products(
        self, service_ids: Optional[List[int]] = None
    ) -> List[dict]:
        """List products of ``service_ids`` with versions under ``versions``.

        Defaults to ``_DEFAULT_SERVICE_IDS``. Version rows whose product is not
        in the products result are ignored.
        """
        if service_ids is None:
            service_ids = list(self._DEFAULT_SERVICE_IDS)
        async with self.connection() as con:
            products_results = await con.fetch(sqls.list_products, service_ids)
            version_results = await con.fetch(
                sqls.list_all_products_versions, service_ids
            )

        products_by_id = {
            product["id"]: self._process_data(product) for product in products_results
        }

        for version in version_results:
            product = products_by_id.get(version["product_id"])
            if product is None:
                continue
            product.setdefault("versions", []).append(self._process_version(version))

        return list(products_by_id.values())

    async def find_product_active_version(
        self, product_id: int, dt: datetime
    ) -> Optional[dict]:
        """Return the product version active at ``dt`` as a dict, or None."""
        async with self.connection() as con:
            result = await con.fetchrow(
                sqls.find_product_active_version, product_id, dt
            )

        return dict(result) if result else None

    async def list_by_params(
        self,
        platforms: List[PlatformType],
        campaign_type: CampaignType,
        currency: CurrencyType,
        dt: datetime,
        client_id: Optional[int] = None,
        service_id: Optional[int] = None,
    ) -> List[Dict]:
        """List products matching the given parameters at moment ``dt``.

        Raises:
            ValueError: ``dt`` is a naive datetime.
        """
        if dt.tzinfo is None:
            raise ValueError("Must be tz-aware datetime")

        async with self.connection() as con:
            results = await con.fetch(
                sqls.list_products_by_params,
                list(map(DbEnumConverter.from_enum, platforms)),
                DbEnumConverter.from_enum(campaign_type),
                DbEnumConverter.from_enum(currency),
                dt,
                client_id,
                service_id,
            )

        return list(map(self._process_data, results))

    async def list_clients_bound_to_product(self, product_id):
        """Return the client binding rows of the product as dicts."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_clients_bound_to_product, product_id)

        return list(map(dict, results))

    async def bind_client_to_product(self, product_id, clients):
        """Bind ``clients`` to a ``YEARLONG`` product in one transaction.

        ``clients`` holds dicts with ``client_id`` and optional ``contract_id``.
        Binding a client without a contract is skipped when that client is
        already bound to the product without a contract.

        Raises:
            ProductDoesNotExist: the product is missing.
            ProductClientMismatch: the product type is not ``YEARLONG``.
            ClientsAlreadyBoundToOverlappingProducts: some clients are returned
                by ``sqls.list_clients_bound_to_overlapping_products`` for this
                product's campaign type, platforms and currency.
            ClientDoesNotExist, ContractDoesNotExist: reported through a
                foreign-key violation on insert.
        """
        try:
            async with self.connection() as con, con.transaction():
                product_data = await con.fetchrow(sqls.find_product_by_id, product_id)
                if not product_data:
                    raise ProductDoesNotExist(product_id=product_id)

                if product_data.get("type") != "YEARLONG":
                    raise ProductClientMismatch(
                        product_id=product_id,
                        client_ids=list(map(itemgetter("client_id"), clients)),
                    )

                all_bound_clients = await con.fetch(
                    sqls.list_clients_bound_to_product, product_id
                )
                clients_with_null_contract = {
                    row["client_id"]
                    for row in all_bound_clients
                    if row["contract_id"] is None
                }
                non_duplicate_clients = [
                    client
                    for client in clients
                    if client["client_id"] not in clients_with_null_contract
                    or client.get("contract_id") is not None
                ]

                if not non_duplicate_clients:
                    return

                clients_bound_to_overlapping_products = await con.fetch(
                    sqls.list_clients_bound_to_overlapping_products,
                    [client["client_id"] for client in non_duplicate_clients],
                    product_data["campaign_type"],
                    product_id,
                    product_data["platforms"],
                    product_data["currency"],
                )

                if clients_bound_to_overlapping_products:
                    raise ClientsAlreadyBoundToOverlappingProducts(
                        client_ids=[
                            client["client_id"]
                            for client in clients_bound_to_overlapping_products
                        ],
                        platforms=DbEnumConverter.to_enum(
                            PlatformType, product_data["platforms"]
                        ),
                        currency=DbEnumConverter.to_enum(
                            CurrencyType, product_data["currency"]
                        ),
                    )

                # The template takes a pre-rendered VALUES list; every value
                # is coerced to int (or NULL) so no raw input reaches the SQL.
                values = ",".join(
                    f"({int(product_id)}, {int(client['client_id'])}, "
                    f" {self._sql_int_or_null(client.get('contract_id'))})"
                    for client in non_duplicate_clients
                )
                await con.execute(sqls.bind_client_to_product.format(values))
        except ForeignKeyViolationError as exc:
            self._raise_for_fk_violation(exc)
            raise

    @classmethod
    def _unbind_condition(cls, client) -> str:
        """Build the WHERE fragment matching one client's binding.

        ``$1`` is the product id. A missing/None ``contract_id`` matches only
        the binding without a contract.
        """
        contract = cls._sql_int_or_null(client.get("contract_id"))
        return (
            f"(product_id = $1 AND client_id = {int(client['client_id'])} AND "
            f"(({contract}::integer IS NULL AND contract_id IS NULL) "
            f"OR contract_id = {contract}))"
        )

    async def unbind_client_from_product(self, product_id, clients):
        """Delete the bindings of ``clients`` to the product.

        ``clients`` holds dicts with ``client_id`` and optional ``contract_id``.
        Does nothing when ``clients`` is empty.

        Raises:
            ProductClientMismatch: some given client ids were not among the
                deleted rows. NOTE(review): no explicit transaction is opened
                here, so matching bindings may already be deleted when this is
                raised — confirm that is intended.
        """
        if not clients:
            return

        async with self.connection() as con:
            rows_deleted = await con.fetch(
                sqls.unbind_client_from_product.format(
                    " OR ".join(map(self._unbind_condition, clients))
                ),
                product_id,
            )

            ids_deleted = {row[0] for row in rows_deleted}
            ids_unknown = [
                client["client_id"]
                for client in clients
                if client["client_id"] not in ids_deleted
            ]

            if ids_unknown:
                raise ProductClientMismatch(
                    product_id=product_id, client_ids=ids_unknown
                )

    async def _add_product_version(self, con, **kw_args):
        """Insert one product version using connection ``con``."""
        await con.execute(
            sqls.insert_product_version,
            kw_args["product_id"],
            kw_args["active_from"],
            kw_args.get("active_to"),
            kw_args["billing_data"],
            kw_args["min_budget"],
            kw_args["cpm_filters"],
        )

    @staticmethod
    async def _check_product_for_duplicates(con, **kw_args):
        """Raise ConflictingProducts if a REGULAR product clashes with others.

        Only products with ``type == "REGULAR"`` are checked; the clash
        criteria live in ``sqls.check_for_duplicate_products``.
        """
        if kw_args["type"] == "REGULAR":
            duplicate_products = await con.fetch(
                sqls.check_for_duplicate_products,
                DbEnumConverter.from_enum(kw_args["campaign_type"]),
                DbEnumConverter.from_enum(kw_args["currency"]),
                kw_args["service_id"],
                kw_args["active_from"],
                kw_args.get("active_to"),
                kw_args.get("product_id"),
            )

            if duplicate_products:
                raise ConflictingProducts(
                    product_ids=list(map(itemgetter("product_id"), duplicate_products))
                )

    async def create_product(self, **kw_args):
        """Create a product and its first version in one transaction.

        Returns:
            ``{"product_id": <new id>}``.

        Raises:
            ConflictingProducts: see ``_check_product_for_duplicates``.
        """
        platforms = [
            DbEnumConverter.from_enum(platform) for platform in kw_args["platforms"]
        ]

        async with self.connection() as con, con.transaction():
            await self._check_product_for_duplicates(con, **kw_args)
            product = await con.fetchrow(
                sqls.insert_product,
                kw_args["oracle_id"],
                kw_args["title"],
                kw_args["act_text"],
                kw_args["description"],
                DbEnumConverter.from_enum(kw_args["currency"]),
                DbEnumConverter.from_enum(kw_args["billing_type"]),
                kw_args["vat_value"],
                DbEnumConverter.from_enum(kw_args["campaign_type"]),
                # Both a single ``platform`` and the ``platforms`` list are
                # stored by ``sqls.insert_product``.
                DbEnumConverter.from_enum(kw_args["platform"]),
                kw_args["comment"],
                platforms,
                kw_args["service_id"],
                kw_args["type"],
            )
            await self._add_product_version(con, product_id=product["id"], **kw_args)

            return {"product_id": product["id"]}

    async def update_product(self, **kw_args):
        """Update product fields and replace all its versions in one transaction.

        The duplicate check spans from the first version's ``active_from`` to
        the last version's ``active_to`` (assumes ``versions`` is ordered by
        time — TODO confirm with callers). The caller's version dicts are not
        modified.

        Raises:
            ConflictingProducts: see ``_check_product_for_duplicates``.
        """
        product_id = kw_args["product_id"]
        versions = kw_args["versions"]
        async with self.connection() as con, con.transaction():
            await self._check_product_for_duplicates(
                con,
                **{
                    **kw_args,
                    "active_from": versions[0]["active_from"],
                    "active_to": versions[-1].get("active_to"),
                },
            )

            await con.execute(
                sqls.update_product,
                product_id,
                kw_args["title"],
                kw_args["act_text"],
                kw_args["description"],
                kw_args["vat_value"],
                kw_args["comment"],
            )
            await con.execute(sqls.remove_product_versions, product_id)
            for version in versions:
                await self._add_product_version(
                    con, **{**version, "product_id": product_id}
                )