from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, Iterable, Optional

import aiohttp
from google.protobuf.timestamp_pb2 import Timestamp
from tenacity import RetryError

from maps_adv.common.client import Client as BaseClient
from maps_adv.common.helpers.enums import CampaignTypeEnum
from maps_adv.common.proto import campaign_pb2
from maps_adv.statistics.dashboard.proto.campaign_stat_pb2 import (
    CampaignChargedSumInput,
    CampaignChargedSumOutput,
    CampaignEventsForPeriodInput,
    CampaignEventsInputPart,
    CampaignEventsForPeriodOutput,
)

from .exceptions import (
    NoCampaignsPassed,
    NoStatistics,
    UnknownResponse,
)


# Converters from internal enums to their protobuf counterparts, keyed by the
# request field they are used for. Each campaign type below maps the
# CampaignTypeEnum member to the campaign_pb2.CampaignType value of the same
# name; a CampaignTypeEnum member not listed here has no entry, so looking it
# up raises KeyError.
ENUM_CONVERTERS = {
    "campaign_type": {
        getattr(CampaignTypeEnum, type_name): getattr(
            campaign_pb2.CampaignType, type_name
        )
        for type_name in (
            "PIN_ON_ROUTE",
            "BILLBOARD",
            "ZERO_SPEED_BANNER",
            "CATEGORY_SEARCH",
            "ROUTE_BANNER",
            "VIA_POINTS",
            "OVERVIEW_BANNER",
            "PROMOCODE",
        )
    }
}


class Client(BaseClient):
    """Client for the statistics dashboard campaign endpoints.

    Request and response bodies are serialized protobuf messages. Transport and
    retrying go through ``self._request`` and ``self._retryer``, which are not
    defined in this class (presumably provided by ``BaseClient``).
    """

    async def campaigns_charged_sum(
        self, *campaign_ids: int, on_datetime: Optional[datetime] = None
    ) -> Dict[int, Decimal]:
        """Fetch the charged sum of each given campaign.

        Args:
            *campaign_ids: ids of the campaigns to query; at least one is
                required.
            on_datetime: moment to query the sums for; defaults to the current
                UTC time. It is sent as ``int(on_datetime.timestamp())``, i.e.
                truncated to whole seconds. NOTE(review): a naive datetime is
                interpreted as local time by ``datetime.timestamp``.

        Returns:
            Mapping of campaign id to charged sum for every campaign present in
            the response.

        Raises:
            NoCampaignsPassed: if no campaign ids are given.
        """
        if not campaign_ids:
            raise NoCampaignsPassed()

        if on_datetime is None:
            on_datetime = datetime.now(tz=timezone.utc)

        request_pb = CampaignChargedSumInput(
            campaign_ids=campaign_ids, on_timestamp=int(on_datetime.timestamp())
        )
        response_body = await self._post_with_retry(
            "/statistics/campaigns/charged_sum/", request_pb.SerializeToString()
        )
        result_pb = CampaignChargedSumOutput.FromString(response_body)

        return {
            el.campaign_id: Decimal(el.charged_sum)
            for el in result_pb.campaigns_charged_sums
        }

    async def campaigns_events(
        self,
        campaigns: Iterable[dict],
        period_from: Optional[datetime] = None,
        period_to: Optional[datetime] = None,
    ) -> Dict[int, int]:
        """Fetch the number of events of each given campaign for a period.

        Args:
            campaigns: campaign descriptions; each must have an ``"id"`` key
                and a ``"type"`` key whose value is a key of
                ``ENUM_CONVERTERS["campaign_type"]``.
            period_from: optional period start, sent as a protobuf
                ``Timestamp`` truncated to whole seconds (``None`` if not
                given).
            period_to: optional period end, encoded the same way.

        Returns:
            Mapping of campaign id to event count for every campaign present
            in the response. Empty, without making a request, when
            ``campaigns`` is ``None`` or yields no items.

        Raises:
            KeyError: if a campaign lacks ``"id"``/``"type"`` or its type has
                no protobuf counterpart in ``ENUM_CONVERTERS``.
        """
        if not campaigns:
            return {}

        type_converter = ENUM_CONVERTERS["campaign_type"]
        campaign_parts = [
            CampaignEventsInputPart(
                campaign_id=campaign["id"],
                campaign_type=type_converter[campaign["type"]],
            )
            for campaign in campaigns
        ]
        # Iterators are truthy even when they yield nothing, so re-check the
        # materialized parts instead of sending an empty request.
        if not campaign_parts:
            return {}

        request_pb = CampaignEventsForPeriodInput(
            campaigns=campaign_parts,
            period_from=self._timestamp_pb(period_from),
            period_to=self._timestamp_pb(period_to),
        )
        response_body = await self._post_with_retry(
            "/statistics/campaigns/events/", request_pb.SerializeToString()
        )
        result_pb = CampaignEventsForPeriodOutput.FromString(response_body)

        return {el.campaign_id: el.events for el in result_pb.campaigns_events}

    async def _post_with_retry(self, url: str, data: bytes) -> bytes:
        """POST ``data`` to ``url`` through ``self._retryer``; return the body.

        ``expected_status=200`` is forwarded to ``self._request``. When the
        retryer gives up, ``RetryError.reraise()`` re-raises the last attempt's
        exception, so callers see the underlying error rather than tenacity's
        wrapper.
        """
        try:
            return await self._retryer.call(
                self._request, "POST", url, data=data, expected_status=200
            )
        except RetryError as exc:
            exc.reraise()

    @staticmethod
    def _timestamp_pb(value: Optional[datetime]) -> Optional[Timestamp]:
        """Convert ``value`` to a whole-second protobuf ``Timestamp``.

        ``None`` is passed through unchanged.
        """
        if value is None:
            return None
        return Timestamp(seconds=int(value.timestamp()))

    @staticmethod
    def _check_response(response: aiohttp.ClientResponse):
        """Raise the client exception matching an unexpected response status.

        Always raises. NOTE(review): presumably invoked by ``BaseClient`` when
        the status differs from the expected one — confirm.

        Raises:
            NoStatistics: on HTTP 204.
            UnknownResponse: on any other status.
        """
        if response.status == 204:
            raise NoStatistics()

        raise UnknownResponse(response.status)
