import logging
from operator import itemgetter
from typing import AsyncGenerator, List

from aioitertools.itertools import chain, groupby

from maps_adv.geosmb.clients.bvm import BvmClient, BvmNotFound
from maps_adv.geosmb.clients.cdp import ApiAccessDenied, CdpClient
from maps_adv.geosmb.clients.geosearch import GeoSearchClient
from maps_adv.geosmb.marksman.server.lib.data_manager import BaseDataManager
from maps_adv.geosmb.marksman.server.lib.enums import SegmentType
from maps_adv.geosmb.marksman.server.lib.exceptions import (
    BizNotAdded,
    BizNotFound,
    NoOrgInfo,
    OrgWithoutCounter,
)

# Public API of this module: only the Domain service is exported.
__all__ = ["Domain"]


class Domain:
    """Marksman business logic: registering businesses and syncing their
    client contacts and segments with CDP.

    Collaborators:
        dm: persistence of businesses and their segments.
        bvm: biz_id -> permalinks lookup.
        geosearch: permalink -> organisation info (Metrika counter).
        cdp: contacts schema, contacts upload and segment management.
    """

    _dm: BaseDataManager
    _bvm: BvmClient
    _geosearch: GeoSearchClient
    _cdp: CdpClient

    def __init__(
        self,
        dm: BaseDataManager,
        bvm: BvmClient,
        geosearch: GeoSearchClient,
        cdp: CdpClient,
    ):
        self._dm = dm
        self._bvm = bvm
        self._geosearch = geosearch
        self._cdp = cdp

    async def add_business(self, biz_id: int) -> None:
        """Register a business together with its Metrika counter.

        The first permalink BVM returns for ``biz_id`` is resolved through
        geosearch, and the organisation's Metrika counter is stored alongside
        the business via the data manager.

        Raises:
            BizNotFound: BVM raises ``BvmNotFound`` or returns no permalinks.
            NoOrgInfo: geosearch returns nothing for the permalink.
            OrgWithoutCounter: the organisation has no Metrika counter.
        """
        try:
            permalinks = await self._bvm.fetch_permalinks_by_biz_id(biz_id=biz_id)
        except BvmNotFound as exc:
            raise BizNotFound from exc

        # An empty result would otherwise surface as an IndexError instead of
        # the domain-level BizNotFound callers expect.
        if not permalinks:
            raise BizNotFound
        permalink = permalinks[0]

        orginfo = await self._geosearch.resolve_org(permalink=permalink)
        if orginfo is None:
            raise NoOrgInfo
        if orginfo.metrika_counter is None:
            raise OrgWithoutCounter

        await self._dm.add_business(
            biz_id=biz_id, permalink=permalink, counter_id=int(orginfo.metrika_counter)
        )

    async def list_business_segments_data(self, biz_id: int) -> dict:
        """Return the stored segments data of a business.

        Raises:
            BizNotAdded: the data manager has no data for ``biz_id``.
        """
        result = await self._dm.list_business_segments_data(biz_id=biz_id)
        if result is None:
            raise BizNotAdded

        return result

    async def sync_businesses_segments(
        self, generator: AsyncGenerator[List[dict], None]
    ) -> None:
        """Push client contacts and their segments/labels to CDP.

        ``generator`` yields batches (lists) of client rows. Each row is a dict
        with at least ``biz_id``, ``segments`` and ``labels`` keys. Rows are
        grouped by *consecutive* equal ``biz_id``. NOTE(review): this assumes
        the source yields rows ordered by ``biz_id``. Otherwise a business is
        processed once per run of its rows.

        For every business:

        * the CDP contacts schema of its counter is created;
        * a CDP segment is created and stored for every segment/label name
          that appears in the rows but is not stored for the business yet;
        * the rows, without ``biz_id``, are uploaded as contacts.

        Businesses unknown to the data manager and counters for which CDP
        answers 403 are logged and skipped, so one bad business does not
        abort the rest of the sync. The caller's rows are not modified.
        """
        logger = logging.getLogger("marksman.sync_businesses_segments")
        contact_attributes = [
            {"type_name": "numeric", "name": "biz_id", "multivalued": False},
            {"type_name": "text", "name": "segments", "multivalued": True},
            {"type_name": "text", "name": "labels", "multivalued": True},
        ]

        async for biz_id, clients in groupby(
            chain.from_iterable(generator), key=itemgetter("biz_id")
        ):
            business = await self._dm.list_business_segments_data(biz_id=biz_id)
            if business is None:
                # Not added via add_business, so there is no counter to sync to.
                logger.warning("Business %s is not added, skipping", biz_id)
                continue

            counter_id = business["counter_id"]
            try:
                await self._cdp.create_contacts_schema(
                    counter_id=counter_id, attributes=contact_attributes
                )
            except ApiAccessDenied:
                logger.error("Got 403 while accessing counter %d", counter_id)
                continue

            required_segment_names, required_label_names = set(), set()
            contacts = []
            for client in clients:
                required_segment_names.update(client["segments"])
                required_label_names.update(client["labels"])
                # Copy rather than pop() so the caller's rows stay intact.
                contacts.append({k: v for k, v in client.items() if k != "biz_id"})

            existing_segment_names = set(
                map(itemgetter("segment_name"), business["segments"])
            )
            existing_label_names = set(
                map(itemgetter("label_name"), business["labels"])
            )

            await self._create_missing_segments(
                biz_id=biz_id,
                counter_id=counter_id,
                names=required_segment_names - existing_segment_names,
                name_prefix="GEOSMB_SEG_",
                filter_key="segments",
                type_=SegmentType.SEGMENT,
            )
            await self._create_missing_segments(
                biz_id=biz_id,
                counter_id=counter_id,
                names=required_label_names - existing_label_names,
                name_prefix="GEOSMB_LABEL_",
                filter_key="labels",
                type_=SegmentType.LABEL,
            )

            await self._cdp.upload_contacts(
                counter_id=counter_id, biz_id=biz_id, contacts=contacts
            )

    async def _create_missing_segments(
        self,
        *,
        biz_id: int,
        counter_id: int,
        names: set,
        name_prefix: str,
        filter_key: str,
        type_: SegmentType,
    ) -> None:
        """Create a CDP segment per name (in sorted order) and store it.

        Each CDP segment is named ``name_prefix + name`` and filters contacts
        by ``biz_id`` and ``filter_key == name``. The created segment's id and
        size are saved via the data manager with the given ``type_``.
        """
        for name in sorted(names):
            created_segment = await self._cdp.create_segment(
                counter_id=counter_id,
                segment_name=f"{name_prefix}{name}",
                filtering_params={"biz_id": biz_id, filter_key: name},
            )
            await self._dm.add_business_segment(
                biz_id=biz_id,
                name=name,
                cdp_id=created_segment["segment_id"],
                cdp_size=created_segment["size"],
                type_=type_,
            )

    async def sync_segments_sizes(self) -> None:
        """Refresh the stored sizes of businesses' CDP segments.

        For each business from the data manager, asks CDP for the current size
        of every entry in its ``segments`` list. All sizes, keyed by CDP
        segment id, are then saved in a single data-manager call.
        """
        biz_ids = await self._dm.list_biz_ids()

        segment_sizes = {}
        for biz_id in biz_ids:
            biz_id_data = await self._dm.list_business_segments_data(biz_id=biz_id)

            # NOTE(review): only ``segments`` entries are refreshed. Labels are
            # also stored with CDP ids by sync_businesses_segments, but their
            # sizes are never updated here. Confirm this is intended.
            for segment in biz_id_data["segments"]:
                cdp_id = segment["cdp_id"]
                segment_sizes[cdp_id] = await self._cdp.get_segment_size(
                    counter_id=biz_id_data["counter_id"], segment_id=cdp_id
                )

        await self._dm.update_segments_sizes(sizes=segment_sizes)
