import json
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, Iterable, List, Optional, Tuple

import aiohttp
from google.protobuf.timestamp_pb2 import Timestamp
from tenacity import RetryError

from maps_adv.billing_proxy.proto.common_pb2 import Error
from maps_adv.billing_proxy.proto.orders_charge_pb2 import (
    OrderChargeInput,
    OrdersChargeInput,
    OrdersChargeOutput,
)
from maps_adv.billing_proxy.proto.orders_for_stat_pb2 import (
    OrdersStatInfo,
    OrdersStatInfoInput,
    OrdersDiscountInfoInput,
    OrdersDiscountInfo,
    OrdersDebitsInfoInput,
    OrdersDebitsInfo,
)
from maps_adv.billing_proxy.proto.orders_pb2 import Order, OrderIds
from maps_adv.billing_proxy.proto.products_pb2 import (
    CpmCalculationInput,
    CpmCalculationResult,
)

from maps_adv.common.client import Client as BaseClient

from .enums import ENUMS_MAP, AdvRubric, CreativeType, OrderSize
from .exceptions import (
    OrderDoesNotExist,
    UnexpectedNaiveDatetime,
    UnknownResponse,
)


class Client(BaseClient):
    async def fetch_active_orders(self, *order_ids: int) -> List[int]:
        """POST the given order ids to ``/orders/active/`` and return the reply ids.

        Args:
            *order_ids: Order ids, passed as separate positional arguments.

        Returns:
            The ``order_ids`` field of the ``OrderIds`` message sent back by the
            service, copied into a plain ``list``.

        Raises:
            Whatever ``RetryError.reraise()`` raises once the retryer gives up
            (normally the exception of the last failed attempt).
        """
        request_pb = OrderIds(order_ids=order_ids)

        try:
            response_body = await self._retryer.call(
                self._request,
                "POST",
                "/orders/active/",
                expected_status=200,
                data=request_pb.SerializeToString(),
            )
        except RetryError as exc:
            # Surface the real failure instead of tenacity's wrapper.
            exc.reraise()

        # The repeated field is a protobuf container, not a list; copy it so the
        # caller gets exactly what the annotation promises.
        return list(OrderIds.FromString(response_body).order_ids)

    async def fetch_orders_balance(self, *order_ids: int) -> Dict[int, Decimal]:
        """POST the given order ids to ``/orders/stats/`` and return their balances.

        Args:
            *order_ids: Order ids, passed as separate positional arguments.

        Returns:
            A dict built from the ``orders_info`` entries of the
            ``OrdersStatInfo`` reply: ``order_id`` -> ``Decimal(balance)``.

        Raises:
            Whatever ``RetryError.reraise()`` raises once the retryer gives up
            (normally the exception of the last failed attempt).
        """
        request_pb = OrdersStatInfoInput(order_ids=order_ids)

        try:
            response_body = await self._retryer.call(
                self._request,
                "POST",
                "/orders/stats/",
                expected_status=200,
                data=request_pb.SerializeToString(),
            )
        except RetryError as exc:
            # Surface the real failure instead of tenacity's wrapper.
            exc.reraise()

        stat_info = OrdersStatInfo.FromString(response_body)
        return {
            order_info.order_id: Decimal(order_info.balance)
            for order_info in stat_info.orders_info
        }

    async def fetch_orders_discounts(
        self, billed_at: datetime, *order_ids: int
    ) -> Dict[int, Decimal]:
        """POST order ids and a billing moment to ``/orders/discounts/``.

        Args:
            billed_at: Timezone-aware datetime; it is sent as a protobuf
                ``Timestamp`` truncated to whole seconds.
            *order_ids: Order ids, passed as separate positional arguments.

        Returns:
            A dict built from the ``discount_info`` entries of the
            ``OrdersDiscountInfo`` reply: ``order_id`` -> ``Decimal(discount)``.

        Raises:
            UnexpectedNaiveDatetime: If ``billed_at`` has no ``tzinfo``.
            Whatever ``RetryError.reraise()`` raises once the retryer gives up
            (normally the exception of the last failed attempt).
        """
        if billed_at.tzinfo is None:
            raise UnexpectedNaiveDatetime(billed_at)

        request_pb = OrdersDiscountInfoInput(
            order_ids=order_ids,
            # int() drops the sub-second part of the timestamp.
            billed_at=Timestamp(seconds=int(billed_at.timestamp())),
        )

        try:
            response_body = await self._retryer.call(
                self._request,
                "POST",
                "/orders/discounts/",
                expected_status=200,
                data=request_pb.SerializeToString(),
            )
        except RetryError as exc:
            # Surface the real failure instead of tenacity's wrapper.
            exc.reraise()

        discounts = OrdersDiscountInfo.FromString(response_body)
        return {
            entry.order_id: Decimal(entry.discount) for entry in discounts.discount_info
        }

    async def submit_orders_charges(
        self, *, charges: Dict[int, Decimal], bill_due_to: datetime
    ) -> Tuple[bool, Dict[int, bool]]:
        """Send order charges to ``/orders/charge/`` and report the outcome.

        Args:
            charges: Mapping of order id to the amount to charge; every amount
                is sent as ``str(amount)``.
            bill_due_to: Timezone-aware datetime, sent as a protobuf
                ``Timestamp`` truncated to whole seconds.

        Returns:
            A two-element tuple:

            - ``True`` if the server has mutated data, ``False`` if the charges
              are fine but the mutation was skipped because the same charges
              had already been applied earlier.
            - A dict mapping each order id to the boolean result of its charge;
              ``False`` means the server declined the charge for that order.

        Raises:
            UnexpectedNaiveDatetime: If ``bill_due_to`` has no ``tzinfo``.
        """
        if bill_due_to.tzinfo is None:
            raise UnexpectedNaiveDatetime(bill_due_to)

        per_order_charges = []
        for order_id, amount in charges.items():
            per_order_charges.append(
                OrderChargeInput(order_id=order_id, charged_amount=str(amount))
            )

        bill_for = Timestamp(seconds=int(bill_due_to.timestamp()))
        request_pb = OrdersChargeInput(
            orders_charge=per_order_charges, bill_for_timestamp=bill_for
        )
        payload = request_pb.SerializeToString()

        try:
            raw_reply = await self._retryer.call(
                self._request,
                "POST",
                "/orders/charge/",
                expected_status=201,
                data=payload,
            )
        except RetryError as retry_error:
            # Surface the real failure instead of tenacity's wrapper.
            retry_error.reraise()

        reply = OrdersChargeOutput.FromString(raw_reply)
        per_order_success = {}
        for item in reply.charge_result:
            per_order_success[item.order_id] = item.success

        return reply.applied, per_order_success

    async def fetch_order(self, order_id: int) -> Dict:
        """GET ``/orders/{order_id}/`` and return a small dict describing the order.

        Returns:
            ``{"id": ..., "currency": ...}`` where ``id`` comes straight from the
            ``Order`` reply and ``currency`` is the reply's currency value
            looked up in ``ENUMS_MAP["currency"]``.
        """
        url = f"/orders/{order_id}/"

        try:
            raw_order = await self._retryer.call(
                self._request, "GET", url, expected_status=200
            )
        except RetryError as retry_error:
            # Surface the real failure instead of tenacity's wrapper.
            retry_error.reraise()

        order = Order.FromString(raw_order)
        currency_by_proto_value = ENUMS_MAP["currency"]
        return {
            "id": order.id,
            "currency": currency_by_proto_value[order.currency],
        }

    async def calculate_product_cpm(
        self,
        product_id: int,
        *,
        rubric: Optional[AdvRubric],
        targeting_query: Optional[dict],
        dt: Optional[datetime],
        order_size: Optional[OrderSize],
        creative_types: Optional[List[CreativeType]],
    ) -> Decimal:
        """POST CPM calculation parameters to ``/products/{product_id}/cpm/``.

        Args:
            product_id: Id substituted into the request path.
            rubric: Translated through ``ENUMS_MAP["rubric"]``; left unset when
                falsy.
            targeting_query: Always serialised with ``json.dumps`` (so ``None``
                is sent as the string ``"null"``).
            dt: Sent as a protobuf ``Timestamp`` truncated to whole seconds;
                left unset when falsy.
            order_size: Translated through ``ENUMS_MAP["order_size"]``; left
                unset when falsy.
            creative_types: Each item is translated through
                ``ENUMS_MAP["creative_type"]``. ``None`` is treated as an empty
                list.

        Returns:
            ``Decimal(cpm)`` taken from the ``CpmCalculationResult`` reply.

        Raises:
            Whatever ``RetryError.reraise()`` raises once the retryer gives up
            (normally the exception of the last failed attempt).
        """
        # NOTE(review): unlike fetch_orders_discounts, a naive `dt` is not
        # rejected here; datetime.timestamp() then assumes local time.
        creative_type_map = ENUMS_MAP["creative_type"]
        input_proto = CpmCalculationInput(
            rubric=ENUMS_MAP["rubric"][rubric] if rubric else None,
            targeting_query=json.dumps(targeting_query),
            dt=Timestamp(seconds=int(dt.timestamp())) if dt else None,
            order_size=ENUMS_MAP["order_size"][order_size] if order_size else None,
            # The parameter is Optional: iterating None would raise TypeError,
            # so fall back to an empty sequence.
            creative_types=[
                creative_type_map[creative_type]
                for creative_type in (creative_types or [])
            ],
        )
        try:
            response_body = await self._retryer.call(
                self._request,
                "POST",
                f"/products/{product_id}/cpm/",
                expected_status=200,
                data=input_proto.SerializeToString(),
            )
        except RetryError as exc:
            # Surface the real failure instead of tenacity's wrapper.
            exc.reraise()

        result_proto = CpmCalculationResult.FromString(response_body)
        return Decimal(result_proto.cpm)

    async def fetch_orders_debits(
        self, order_ids: List[int], billed_after: datetime
    ) -> dict:
        """POST order ids and a lower time bound to ``/orders/debits/``.

        Args:
            order_ids: Order ids to query.
            billed_after: Sent as a protobuf ``Timestamp`` truncated to whole
                seconds. NOTE(review): a naive value is not rejected here,
                unlike in fetch_orders_discounts — confirm this is intended.

        Returns:
            A dict mapping each ``order_id`` of the ``OrdersDebitsInfo`` reply
            to a list of ``{"amount": Decimal, "billed_at": datetime}`` dicts,
            where ``billed_at`` is a UTC-aware datetime built from the debit's
            ``billed_at.seconds``.
        """
        threshold = Timestamp(seconds=int(billed_after.timestamp()))
        request_pb = OrdersDebitsInfoInput(order_ids=order_ids, billed_after=threshold)
        payload = request_pb.SerializeToString()

        try:
            raw_reply = await self._retryer.call(
                self._request,
                "POST",
                "/orders/debits/",
                expected_status=200,
                data=payload,
            )
        except RetryError as retry_error:
            # Surface the real failure instead of tenacity's wrapper.
            retry_error.reraise()

        reply = OrdersDebitsInfo.FromString(raw_reply)

        debits_by_order = {}
        for order_entry in reply.orders_debits:
            debits = []
            for debit in order_entry.debits:
                billed_at = datetime.fromtimestamp(
                    debit.billed_at.seconds, tz=timezone.utc
                )
                debits.append({"amount": Decimal(debit.amount), "billed_at": billed_at})
            debits_by_order[order_entry.order_id] = debits

        return debits_by_order

    @staticmethod
    async def _check_response(response: aiohttp.ClientResponse):
        """Translate an HTTP response into an exception; never returns normally.

        A 404 body is parsed as an ``Error`` message and raised as
        ``OrderDoesNotExist`` carrying its description; any other status is
        raised as ``UnknownResponse`` carrying the status code.

        NOTE(review): presumably called by the base client for responses whose
        status differs from the expected one — confirm against BaseClient.
        """
        status = response.status
        if status != 404:
            raise UnknownResponse(status)

        payload = await response.read()
        not_found = Error.FromString(payload)
        raise OrderDoesNotExist(not_found.description)
