import logging

from asyncio import Semaphore, gather
from functools import partial
from itertools import chain
from typing import List

from maps_adv.export.lib.core.client import PointsCacheClient
from maps_adv.export.lib.core.enum import CampaignType
from maps_adv.export.lib.pipeline.exceptions import StepException
from maps_adv.export.lib.pipeline.param import Param
from maps_adv.points.client.lib import Client as PointsClient


# Campaign types for which ResolvePointsStep resolves points; campaigns of any
# other type are filtered out at the start of ResolvePointsStep.__call__.
_supported_campaign_types = {
    CampaignType.BILLBOARD,
}

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CampaignException(Exception):
    """Error tied to a single campaign, identified by ``campaign_id``.

    Attributes:
        campaign_id: Identifier of the campaign the error relates to.
    """

    def __init__(self, campaign_id):
        # Forward campaign_id to the base class so that ``args`` is populated.
        # ``copy`` and ``pickle`` rebuild exceptions via ``cls(*args)``; with
        # empty ``args`` that call would fail because campaign_id is required.
        super().__init__(campaign_id)
        self.campaign_id = campaign_id

    def __str__(self):
        # Make the campaign visible in tracebacks and log output.
        return f"campaign_id={self.campaign_id!r}"


class ResolvePointsStep:
    """Pipeline step that fetches points for campaigns placed by area.

    For every campaign of a supported type whose ``placing`` contains an
    ``"area"`` entry, the step asks the points client for the points of the
    area polygons, appends the point ids to ``campaign["places"]`` and stores
    all fetched points, keyed by id, in the ``places`` param.
    """

    def __init__(self, places: Param, cache_folder, config):
        """Create the step.

        Args:
            places: Param whose ``value`` receives the ``{point.id: point}``
                mapping built by ``__call__``.
            cache_folder: Passed to ``PointsCacheClient`` when caching is on.
            config: Mapping read for ``POINTS_URL`` (required),
                ``EXPERIMENT_WITHOUT_CACHE_POINTS`` and
                ``POINTS_LIMIT_REQUESTS`` (both optional).
        """
        client = PointsClient(config["POINTS_URL"])
        # The caching wrapper is used unless the experiment flag disables it.
        if not config.get("EXPERIMENT_WITHOUT_CACHE_POINTS"):
            client = PointsCacheClient(client, cache_folder)
        self._client = client
        self._places = places
        # A missing/None/0 limit means "no limit" (see __call__).
        self._limit_requests = config.get("POINTS_LIMIT_REQUESTS") or 0

    async def __call__(self, campaigns):
        """Resolve points for all eligible campaigns concurrently.

        Args:
            campaigns: Iterable of campaign dicts. Reads ``campaign_type``,
                ``placing`` and ``id``; ``places`` is extended in place for
                campaigns that got points.

        Raises:
            StepException: If fetching failed for at least one campaign. It
                carries the failed ids and the ids of the eligible campaigns
                that were processed successfully. ``places.value`` is still
                populated with the successfully fetched points beforehand.
            BaseException: Any non-``CampaignException`` error returned by a
                per-campaign task is re-raised as is.
        """
        # Keep only supported campaigns that are placed by area. The "placing"
        # lookup is evaluated only for supported types (short-circuit `and`).
        campaigns = [
            campaign
            for campaign in campaigns
            if campaign["campaign_type"] in _supported_campaign_types
            and "area" in campaign["placing"]
        ]

        # A limit of 0 means unlimited: allow every campaign to run at once.
        limit = self._limit_requests
        if limit == 0:
            limit = len(campaigns)
        semaphore = Semaphore(limit)

        # return_exceptions=True so that one failing campaign does not cancel
        # the bookkeeping for the others; failures come back as result values.
        results = await gather(
            *(
                self.get_points_for_campaign(semaphore, campaign)
                for campaign in campaigns
            ),
            return_exceptions=True,
        )

        troublesome_ids = []
        for result in results:
            if isinstance(result, CampaignException):
                troublesome_ids.append(result.campaign_id)
            elif isinstance(result, BaseException):
                # Unexpected error (not a per-campaign failure): propagate.
                raise result

        # Merge the successful results. Falsy results (no points / None) are
        # skipped, mirroring the `if points:` guard in
        # get_points_for_campaign, so a None result cannot break iteration.
        self._places.value = {
            point.id: point
            for result in results
            if not isinstance(result, BaseException) and result
            for point in result
        }

        if troublesome_ids:
            # Set membership keeps this linear in the number of campaigns.
            failed = set(troublesome_ids)
            raise StepException(
                troublesome_ids=troublesome_ids,
                processed_ids=[
                    campaign["id"]
                    for campaign in campaigns
                    if campaign["id"] not in failed
                ],
            )

    @staticmethod
    def _extract_data(campaign):
        """Return ``(version, polygons)`` taken from ``campaign["placing"]["area"]``.

        ``polygons`` is the list of the ``"points"`` entries of every item in
        ``area["areas"]``.
        """
        area = campaign["placing"]["area"]
        polygons = [polygon["points"] for polygon in area["areas"]]
        return area["version"], polygons

    async def get_points_for_campaign(
        self, semaphore: Semaphore, campaign: dict
    ) -> List[dict]:
        """Fetch the points for one campaign and record their ids on it.

        Args:
            semaphore: Limits how many campaigns are fetched at once.
            campaign: Campaign dict; its ``places`` list is extended with the
                ids of the fetched points.

        Returns:
            Whatever the client returned for the campaign's polygons.
            NOTE(review): elements are accessed via ``.id``, so they look like
            objects rather than plain dicts - confirm the annotation.

        Raises:
            CampaignException: On any ``Exception`` while extracting the data
                or fetching the points; the original error is logged and kept
                as ``__cause__``.
        """
        try:
            async with semaphore:
                points_version, polygons = self._extract_data(campaign)
                async with self._client as client:
                    points = await client(
                        points_version=points_version, polygons=polygons
                    )
                if points:
                    campaign["places"].extend(point.id for point in points)

            return points
        except Exception as exc:
            logger.exception(f"Failed to fetch points for campaign {campaign['id']}")
            raise CampaignException(campaign_id=campaign["id"]) from exc
