from copy import deepcopy
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from typing import List, Optional

import pytz

from maps_adv.billing_proxy.lib.data_manager import (
    exceptions as data_manager_exceptions,
)
from maps_adv.billing_proxy.lib.core.cpm_filters import (
    cpm_filters_registry,
    CpmData,
    CpmCoef,
)
from maps_adv.billing_proxy.lib.data_manager.clients import AbstractClientsDataManager
from maps_adv.billing_proxy.lib.data_manager.products import AbstractProductsDataManager
from maps_adv.billing_proxy.lib.domain import (
    BillingType,
    CampaignType,
    CreativeType,
    CurrencyType,
    FixTimeIntervalType,
    PlatformType,
)

from .enums import OrderSize, RubricName, CpmCoefFieldType
from .exceptions import (
    ClientDoesNotExist,
    ContractDoesNotExist,
    MultipleProductsMatched,
    NoActiveVersionsForProduct,
    NoProductVersionsSpecified,
    NonCPMProduct,
    NoProductsMatched,
    ProductDoesNotExist,
    ProductClientMismatch,
    InvalidBillingType,
    NoPlatformsSpecified,
    ConflictingProductVersionTimeSpans,
    ConflictingProducts,
    ProductHasZeroCost,
    ClientsAlreadyBoundToOverlappingProducts,
)


class BillingSetuper:
    """Translate product billing between its stored and its API representation.

    Stored form: ``billing_type`` (a ``BillingType``) plus ``billing_data``,
    a dict whose amounts are strings. API form: a single-entry ``billing``
    dict keyed by the lower-cased ``BillingType`` value.
    """

    @classmethod
    def setup_billing(cls, product):
        """Convert ``product`` and each entry of its ``versions`` in place."""
        cls._setup_billing(product)
        versions = product.get("versions", [])
        for ver in versions:
            cls._setup_billing(ver)

    @classmethod
    def _setup_billing(cls, product):
        """Replace ``billing_type``/``billing_data`` of one dict with ``billing``."""
        converted = cls._create_billing(
            product["billing_type"], product["billing_data"]
        )
        product["billing"] = converted
        # The stored-form keys are dropped only after conversion succeeded.
        for stale_key in ("billing_type", "billing_data"):
            del product[stale_key]

    @classmethod
    def _create_billing(cls, billing_type, billing_data):
        """Build the API ``billing`` dict; non CPM/FIX types get empty data."""
        parsers = {
            BillingType.CPM: cls._parse_cpm_data,
            BillingType.FIX: cls._parse_fix_data,
        }
        parser = parsers.get(billing_type)
        parsed = {} if parser is None else parser(billing_data)
        return {billing_type.value.lower(): parsed}

    @classmethod
    def _parse_cpm_data(cls, data):
        """Parse stored CPM data: the ``base_cpm`` value becomes a ``Decimal``."""
        base_cpm = Decimal(data["base_cpm"])
        return {"base_cpm": base_cpm}

    @classmethod
    def _parse_fix_data(cls, data):
        """Parse stored FIX data: ``cost`` and the ``time_interval`` enum name."""
        interval = FixTimeIntervalType[data["time_interval"]]
        cost = Decimal(data["cost"])
        return {"cost": cost, "time_interval": interval}

    @staticmethod
    def split_billing(billing):
        """Convert an API ``billing`` dict into ``(BillingType, billing_data)``.

        Only the first key of ``billing`` is considered; it is compared with
        the billing type values case-insensitively. Amounts are returned as
        strings and the FIX ``time_interval`` as its ``name``.

        :raises InvalidBillingType: if ``billing`` is empty or its key is not
            a CPM or FIX billing type.
        :raises ProductHasZeroCost: if the CPM / cost amount is not positive.
        """
        type_keys = list(billing.keys())
        if len(type_keys) == 0:
            raise InvalidBillingType(billing_type="")

        raw_type = type_keys[0]
        wanted = raw_type.lower()

        if wanted == BillingType.CPM.value.lower():
            base_cpm = Decimal(billing[raw_type]["base_cpm"])
            if base_cpm <= Decimal("0"):
                raise ProductHasZeroCost()
            return BillingType.CPM, {"base_cpm": str(base_cpm)}

        if wanted == BillingType.FIX.value.lower():
            params = billing[raw_type]
            cost = Decimal(params["cost"])
            if cost <= Decimal("0"):
                raise ProductHasZeroCost()
            interval_name = params["time_interval"].name
            return BillingType.FIX, {"cost": str(cost), "time_interval": interval_name}

        raise InvalidBillingType(billing_type=raw_type)


# Seasonal coefficient for each calendar month, keyed by month number
# (1 = January ... 12 = December). ProductsDomain attaches the value for the
# month as a MONTHLY CPM coefficient: January 0.7, February 0.8,
# March-August 1.0, September-December 1.3.
MONTHLY_DISCOUNTS = {
    **dict.fromkeys((1,), Decimal("0.7")),
    **dict.fromkeys((2,), Decimal("0.8")),
    **dict.fromkeys(range(3, 9), Decimal("1.0")),
    **dict.fromkeys(range(9, 13), Decimal("1.3")),
}


def seasonal_coefs_apply(product: dict) -> bool:
    """Tell whether monthly seasonal coefficients apply to ``product``.

    They apply only to products of type ``"REGULAR"`` whose campaign type is
    not ``CATEGORY_SEARCH_PIN`` and whose currency is RUB, EUR or USD.
    """
    if product["type"] != "REGULAR":
        return False
    if product["campaign_type"] == CampaignType.CATEGORY_SEARCH_PIN:
        return False
    eligible_currencies = (CurrencyType.RUB, CurrencyType.EUR, CurrencyType.USD)
    return product["currency"] in eligible_currencies


class ProductsDomain:
    """Domain operations on billing products.

    Delegates storage to the products and clients data managers, translates
    data-manager exceptions into the domain exceptions of this package, and
    converts product billing between its stored form (``billing_type`` +
    ``billing_data``) and its API form (``billing``) via ``BillingSetuper``.
    """

    # Every CPM value returned by ``calculate_cpm`` is floored to this quantum.
    _CPM_QUANTUM = Decimal("1e-4")

    def __init__(
        self,
        products_dm: AbstractProductsDataManager,
        clients_dm: AbstractClientsDataManager,
        balance_service_id: int,
        seasonal_coefs_since: datetime,
    ):
        """Create the domain.

        :param products_dm: data manager for products, versions and client bindings.
        :param clients_dm: data manager for clients and contracts.
        :param balance_service_id: service id assigned to products created here
            and preferred by ``advise_product`` when several products match.
        :param seasonal_coefs_since: moment from which monthly seasonal
            coefficients are applied. NOTE(review): presumably tz-aware; a naive
            value would be read as host-local time by ``astimezone``.
        """
        self._products_dm = products_dm
        self._clients_dm = clients_dm
        self._balance_service_id = balance_service_id
        self._seasonal_coefs_since = seasonal_coefs_since
        # Seasonal (monthly) coefficients follow calendar months in Moscow time.
        self._moscow_tz = pytz.timezone("Europe/Moscow")

    async def retrieve_product(self, product_id: int) -> dict:
        """Return a product with its (and its versions') billing in API form.

        :raises ProductDoesNotExist: if there is no product with ``product_id``.
        """
        product = await self._products_dm.find_product(product_id)
        if product is None:
            raise ProductDoesNotExist(product_id=product_id)

        BillingSetuper.setup_billing(product)
        return product

    async def list_products(
        self, service_ids: Optional[List[int]] = None
    ) -> List[dict]:
        """List products (optionally filtered by ``service_ids``) in API form."""
        products = await self._products_dm.list_products(service_ids=service_ids)
        for product in products:
            BillingSetuper.setup_billing(product)
        return products

    async def advise_product(
        self,
        platforms: List[PlatformType],
        campaign_type: CampaignType,
        service_id: Optional[int] = None,
        contract_id: Optional[int] = None,
        client_id: Optional[int] = None,
        currency: Optional[CurrencyType] = None,
    ) -> dict:
        """Return the single product currently matching the given parameters.

        ``contract_id`` requires ``client_id``. ``currency`` must be passed if
        and only if ``contract_id`` is not; with a contract, the currency is
        taken from the contract.

        When several products match, the tie-breaking rules of
        ``_narrow_down_products`` are applied.

        :returns: the chosen product with its billing in API form.
        :raises ValueError: on an inconsistent combination of arguments.
        :raises ClientDoesNotExist: if ``client_id`` is unknown.
        :raises ContractDoesNotExist: if ``contract_id`` is unknown.
        :raises NoProductsMatched: if no product matches.
        :raises MultipleProductsMatched: if tie-breaking does not leave
            exactly one product.
        """
        if contract_id is not None and client_id is None:
            raise ValueError("contract_id requires client_id")

        # After the check above, "both ids given" is the same as "contract given".
        has_contract = client_id is not None and contract_id is not None
        if (currency is None) != has_contract:
            raise ValueError(
                "currency must be passed if and only if contract_id is not"
            )

        if client_id is not None:
            if not await self._clients_dm.client_exists(client_id):
                raise ClientDoesNotExist(client_id=client_id)

        if contract_id is not None:
            contract = await self._clients_dm.find_contract(contract_id)
            if contract is None:
                raise ContractDoesNotExist(contract_id=contract_id)
            currency = contract["currency"]

        products = await self._products_dm.list_by_params(
            platforms=platforms,
            campaign_type=campaign_type,
            service_id=service_id,
            currency=currency,
            dt=datetime.now(tz=timezone.utc),
            client_id=client_id,
        )

        if not products:
            raise NoProductsMatched
        if len(products) > 1:
            products = self._narrow_down_products(products, service_id)
            if len(products) != 1:
                raise MultipleProductsMatched

        advised = products[0]
        BillingSetuper.setup_billing(advised)
        return advised

    def _narrow_down_products(
        self, products: List[dict], service_id: Optional[int]
    ) -> List[dict]:
        """Apply the tie-breaking rules to several matching products.

        With an explicit ``service_id`` only client-specific products are kept.
        Otherwise client-specific products are preferred when there are any,
        and if that still leaves several, only products of
        ``balance_service_id`` are kept.
        """
        client_specific = [pr for pr in products if pr["is_client_specific"]]
        if service_id is not None:
            return client_specific

        if client_specific:
            products = client_specific
        if len(products) != 1:
            # Ignore products with third-party service_ids
            products = [
                pr for pr in products if pr["service_id"] == self._balance_service_id
            ]
        return products

    def _monthly_discounts_to_coefs(
        self, active_from: datetime, active_to: datetime
    ) -> list[CpmCoef]:
        """Split ``[active_from, active_to)`` into per-month seasonal coefficients.

        The interval is clipped from below by ``seasonal_coefs_since``. Each
        resulting piece lies within one Moscow calendar month and carries that
        month's ``MONTHLY_DISCOUNTS`` value. Consecutive pieces are contiguous,
        and their bounds are returned in UTC.
        """
        # Seasonal coefs use the Moscow timezone
        moscow_tz = self._moscow_tz
        active_from = active_from.astimezone(moscow_tz)
        active_to = active_to.astimezone(moscow_tz)

        coefs = []
        current_from = max(
            active_from, self._seasonal_coefs_since.astimezone(moscow_tz)
        )
        while current_from < active_to:
            # A pytz zone must be attached with ``localize``. Passing
            # ``tzinfo=current_from.tzinfo`` would reuse current_from's UTC
            # offset even if the next month starts under a different one.
            start_of_the_next_month = moscow_tz.localize(
                datetime(
                    current_from.year + current_from.month // 12,
                    current_from.month % 12 + 1,
                    1,
                )
            )
            current_to = min(active_to, start_of_the_next_month)
            coefs.append(
                CpmCoef(
                    CpmCoefFieldType.MONTHLY,
                    current_from.month,
                    MONTHLY_DISCOUNTS[current_from.month],
                    current_from.astimezone(timezone.utc),
                    current_to.astimezone(timezone.utc),
                )
            )
            current_from = current_to
        return coefs

    @staticmethod
    def _merge_periods(periods: list) -> list:
        """Coalesce consecutive periods that share the same ``final_cpm``.

        The input dicts are copied, not modified.
        """
        result = []
        for period in periods:
            if result and result[-1]["final_cpm"] == period["final_cpm"]:
                result[-1]["active_to"] = period["active_to"]
            else:
                result.append(dict(period))
        return result

    @classmethod
    def _floor_cpm(cls, value: Decimal) -> Decimal:
        """Floor a CPM value to ``_CPM_QUANTUM`` (4 decimal places)."""
        # "ROUND_FLOOR" is the value of decimal.ROUND_FLOOR.
        return value.quantize(cls._CPM_QUANTUM, rounding="ROUND_FLOOR")

    @classmethod
    def _build_periods(
        cls,
        final_cpm: Decimal,
        monthly_coefs: list,
        active_from: datetime,
        active_to: datetime,
    ) -> List[dict]:
        """Split the requested interval into periods with a constant final CPM.

        Without seasonal coefficients, the whole ``[active_from, active_to)``
        range is one period at ``final_cpm``. Otherwise each monthly coefficient
        yields a period at ``final_cpm * rate``. Any part before the first
        coefficient is priced at ``final_cpm``. Adjacent equal-CPM periods are
        then merged.
        """
        if not monthly_coefs:
            return cls._merge_periods(
                [
                    {
                        "active_from": active_from,
                        "active_to": active_to,
                        "final_cpm": cls._floor_cpm(final_cpm),
                    }
                ]
            )

        periods = [
            {
                "active_from": monthly_coef.active_from,
                "active_to": monthly_coef.active_to,
                "final_cpm": cls._floor_cpm(final_cpm * monthly_coef.rate),
            }
            for monthly_coef in monthly_coefs
        ]
        first_seasonal_from = monthly_coefs[0].active_from
        if active_from < first_seasonal_from:
            # End exactly where the first seasonal period starts, keeping the
            # periods contiguous like the monthly ones. Subtracting a day here
            # would leave a gap, or even an inverted interval when active_from
            # is less than a day before the seasonal start.
            periods.insert(
                0,
                {
                    "active_from": active_from,
                    "active_to": first_seasonal_from,
                    "final_cpm": cls._floor_cpm(final_cpm),
                },
            )
        return cls._merge_periods(periods)

    async def calculate_cpm(
        self,
        *,
        product_id: int,
        targeting_query: Optional[dict] = None,
        rubric_name: Optional[RubricName] = None,
        order_size: Optional[OrderSize] = None,
        creative_types: Optional[List[CreativeType]] = None,
        dt: Optional[datetime] = None,
        active_from: Optional[datetime] = None,
        active_to: Optional[datetime] = None,
    ) -> dict:
        """Calculate the CPM of a CPM product for the given targeting and interval.

        :param dt: used as ``active_from`` when that is not given; defaults to
            the current UTC time.
        :param active_from: start of the interval; selects the product version.
        :param active_to: end of the interval; defaults to ``active_from`` plus
            one second. It must not exceed the version's ``active_to``.
        :returns: dict with ``cpm`` (final CPM) and ``base_cpm``, both floored
            to 4 decimal places; ``coefs``, the applied coefficients including
            monthly seasonal ones; and ``periods``, a list of ``active_from`` /
            ``active_to`` / ``final_cpm`` dicts.
        :raises ProductDoesNotExist: if the product is unknown.
        :raises NonCPMProduct: if the product is not billed by CPM.
        :raises NoActiveVersionsForProduct: if no version covers the interval.
        """
        product = await self._products_dm.find_product(product_id)
        if product is None:
            raise ProductDoesNotExist(product_id=product_id)

        if product["billing_type"] is not BillingType.CPM:
            raise NonCPMProduct(product_id=product_id)

        if dt is None:
            dt = datetime.now(tz=timezone.utc)

        if active_from is None:
            active_from = dt

        active_version = await self._products_dm.find_product_active_version(
            product_id, active_from
        )
        if active_version is None:
            raise NoActiveVersionsForProduct(product_id=product_id)

        if active_to is None:
            # Make a 1 second interval to calculate seasonal coef for
            active_to = active_from + timedelta(seconds=1)

        if (
            active_version["active_to"] is not None
            and active_to > active_version["active_to"]
        ):
            raise NoActiveVersionsForProduct(product_id=product_id)

        # Work on a copy: base_cpm is popped, and the remaining billing keys
        # are passed to every CPM filter as keyword arguments.
        filter_params = deepcopy(active_version["billing_data"])
        cpm_data = CpmData(base_cpm=Decimal(filter_params.pop("base_cpm")))
        input_data = {
            "targeting_query": targeting_query,
            "rubric_name": rubric_name,
            "order_size": order_size.value if order_size else None,
            "creative_types": creative_types,
        }
        for filter_key in active_version["cpm_filters"]:
            cpm_filter = cpm_filters_registry[filter_key]
            cpm_data = cpm_filter(cpm_data=cpm_data, **filter_params, **input_data)

        monthly_coefs = (
            self._monthly_discounts_to_coefs(active_from, active_to)
            if seasonal_coefs_apply(product)
            else []
        )
        cpm_data.coefs.extend(monthly_coefs)

        periods = self._build_periods(
            cpm_data.final_cpm, monthly_coefs, active_from, active_to
        )

        return {
            "cpm": self._floor_cpm(cpm_data.final_cpm),
            "base_cpm": self._floor_cpm(cpm_data.base_cpm),
            "coefs": cpm_data.coefs,
            "periods": periods,
        }

    async def list_clients_bound_to_product(self, *, product_id):
        """Return whatever the data manager lists as bound to ``product_id``."""
        return await self._products_dm.list_clients_bound_to_product(product_id)

    async def bind_client_to_product(self, *, product_id, clients):
        """Bind ``clients`` to a product, translating data-manager errors."""
        try:
            await self._products_dm.bind_client_to_product(product_id, clients)
        except data_manager_exceptions.ClientDoesNotExist as exc:
            raise ClientDoesNotExist(client_id=exc.client_id) from exc
        except data_manager_exceptions.ProductDoesNotExist as exc:
            raise ProductDoesNotExist(product_id=exc.product_id) from exc
        except data_manager_exceptions.ContractDoesNotExist as exc:
            # NOTE(review): the other branches read ``*_id`` attributes; confirm
            # this data-manager exception really exposes ``.contract``.
            raise ContractDoesNotExist(contract_id=exc.contract) from exc
        except data_manager_exceptions.ProductClientMismatch as exc:
            raise ProductClientMismatch(
                product_id=exc.product_id, client_ids=exc.client_ids
            ) from exc
        except data_manager_exceptions.ClientsAlreadyBoundToOverlappingProducts as exc:
            raise ClientsAlreadyBoundToOverlappingProducts(
                client_ids=exc.client_ids,
                platforms=[platform.name for platform in exc.platforms],
                currency=exc.currency,
            ) from exc

    async def unbind_client_from_product(self, *, product_id, clients):
        """Unbind ``clients`` from a product, translating data-manager errors."""
        try:
            await self._products_dm.unbind_client_from_product(product_id, clients)
        except data_manager_exceptions.ProductClientMismatch as exc:
            raise ProductClientMismatch(
                product_id=exc.product_id, client_ids=exc.client_ids
            ) from exc

    async def create_product(self, **kw_args):
        """Create a product from API-form arguments.

        The first of ``platforms`` becomes ``platform``, the service id is set
        to ``balance_service_id``, and ``billing`` is split into
        ``billing_type``/``billing_data``.

        :raises NoPlatformsSpecified: if ``platforms`` is empty.
        :raises InvalidBillingType, ProductHasZeroCost: on invalid ``billing``.
        :raises ConflictingProducts: if the data manager reports a conflict.
        """
        if not kw_args["platforms"]:
            raise NoPlatformsSpecified()
        kw_args["platform"] = kw_args["platforms"][0]
        kw_args["service_id"] = self._balance_service_id
        kw_args["billing_type"], kw_args["billing_data"] = BillingSetuper.split_billing(
            kw_args["billing"]
        )
        try:
            return await self._products_dm.create_product(**kw_args)
        except data_manager_exceptions.ConflictingProducts as exc:
            raise ConflictingProducts(product_ids=exc.product_ids) from exc

    @staticmethod
    def _validate_versions(product, versions):
        """Validate product versions and add stored-form ``billing_data`` to each.

        ``versions`` is sorted in place by ``active_from``. Each version's
        billing type must equal the product's. Each version must start exactly
        where the previous one ends, so only the last may have no ``active_to``.

        :raises InvalidBillingType: on a billing type mismatch or invalid billing.
        :raises ConflictingProductVersionTimeSpans: on gaps, overlaps or an
            open-ended version that is not the last.
        """
        versions.sort(key=lambda version: version["active_from"])
        latest_to = versions[0]["active_from"]
        for version in versions:
            billing_type, billing_data = BillingSetuper.split_billing(
                version["billing"]
            )
            if product["billing_type"] != billing_type:
                raise InvalidBillingType(billing_type=billing_type)
            version["billing_data"] = billing_data
            if latest_to is None or version["active_from"] != latest_to:
                raise ConflictingProductVersionTimeSpans(
                    to=latest_to, from_=version["active_from"]
                )
            latest_to = version.get("active_to")

    async def update_product(self, **kw_args):
        """Replace the versions (and other fields) of an existing product.

        ``service_id``, ``currency``, ``campaign_type`` and ``type`` are taken
        from the stored product, overriding any values passed by the caller.

        :raises ProductDoesNotExist: if the product is unknown.
        :raises NoProductVersionsSpecified: if ``versions`` is empty.
        :raises ConflictingProducts: if the data manager reports a conflict.
        """
        product = await self._products_dm.find_product(kw_args["product_id"])
        if product is None:
            raise ProductDoesNotExist(product_id=kw_args["product_id"])

        if not kw_args["versions"]:
            raise NoProductVersionsSpecified()

        self._validate_versions(product, kw_args["versions"])

        kw_args["service_id"] = product["service_id"]
        kw_args["currency"] = product["currency"]
        kw_args["campaign_type"] = product["campaign_type"]
        kw_args["type"] = product["type"]

        try:
            await self._products_dm.update_product(**kw_args)
        except data_manager_exceptions.ConflictingProducts as exc:
            raise ConflictingProducts(product_ids=exc.product_ids) from exc
