from typing import List, NamedTuple

import aiohttp
from smb.common.http_client import collect_errors

from smb.common.aiotvm import HttpClientWithTvm

# Public names of this module: the CDP client and the exceptions it raises.
__all__ = [
    "ApiAccessDenied",
    "BadContacts",
    "BadFilteringParams",
    "CdpClient",
    "CdpException",
    "ContactsValidationFailed",
    "SchemaValidationFailed",
]


class CdpException(Exception):
    """Common base class for the CDP client's domain-specific errors."""


class ApiAccessDenied(CdpException):
    """Raised when the CDP API answers a request with HTTP 403."""


class SchemaValidationFailed(CdpException):
    """Raised when the API does not report success for a schema update."""


class BadContacts(CdpException):
    """Raised when a contact lacks an ``id`` or every identifying field."""


class ContactsValidationFailed(CdpException):
    """Raised when the API's validation status for an upload is not PASSED."""


class BadFilteringParams(CdpException):
    """Raised when segment filtering parameters cannot be turned into a filter."""


class CdpClient(HttpClientWithTvm):
    """Client for the CDP internal HTTP API (``/cdp/internal/v1``).

    All requests go through ``self.request`` inherited from
    ``HttpClientWithTvm``; an HTTP 403 response is turned into
    ``ApiAccessDenied`` by ``_handle_custom_errors``.
    """

    @collect_errors
    async def create_contacts_schema(
        self, counter_id: int, attributes: List[dict]
    ) -> None:
        """Add custom contact attributes to the counter's CDP schema.

        Each attribute dict gets ``multivalued=False`` and ``humanized=""``
        when those keys are missing. The caller's dicts are not modified.

        :param counter_id: counter whose schema is extended.
        :param attributes: attribute descriptions, sent as-is plus defaults.
        :raises SchemaValidationFailed: if the response's ``success`` field is
            missing or falsy.
        """
        prepared = []
        for attribute in attributes:
            # Copy each dict, not just the list: a shallow list copy would
            # still let ``setdefault`` write into the caller's objects.
            attribute = dict(attribute)
            attribute.setdefault("multivalued", False)
            attribute.setdefault("humanized", "")
            prepared.append(attribute)

        resp = await self.request(
            method="POST",
            uri=f"/cdp/internal/v1/counter/{counter_id}/schema/attributes",
            params={"entity_type": "contact"},
            expected_statuses=[200],
            json={"attributes": prepared},
            headers={"content-type": "application/x-yametrika+json"},
            metric_name="POST_/cdp/internal/v1/counter/{counter_id}/schema/attributes",
        )

        if not resp.get("success"):
            raise SchemaValidationFailed

    @collect_errors
    async def list_segments(self, counter_id: int) -> List[dict]:
        """Return every segment of the counter, fetched page by page.

        Pages of up to 100 segments are requested. The last received
        ``segment_id`` is passed as ``from_segment_id`` (0 for the first page),
        and the loop stops when the API returns an empty ``segments`` list.
        """
        # NOTE(review): this relies on ``from_segment_id`` being exclusive. If
        # the API treats it as inclusive, the last segment is returned again
        # and the loop never terminates — confirm against the API.
        segments = []
        limit, last_received_segment_id = 100, 0

        while True:
            resp = await self.request(
                method="GET",
                uri=f"/cdp/internal/v1/counter/{counter_id}/segments",
                params={"limit": limit, "from_segment_id": last_received_segment_id},
                expected_statuses=[200],
                metric_name="GET_/cdp/internal/v1/counter/{counter_id}/segments",
            )
            segments_chunk = resp["segments"]
            if not segments_chunk:
                break
            segments.extend(segments_chunk)
            last_received_segment_id = segments_chunk[-1]["segment_id"]

        return segments

    @collect_errors
    async def upload_contacts(
        self, counter_id: int, biz_id: int, contacts: List[dict]
    ) -> None:
        """Upload contacts to CDP in requests of at most 1000 contacts each.

        A contact needs an ``id`` and at least one non-empty value among
        ``email``, ``phone`` and ``client_ids``. ``biz_id``, ``segments`` and
        ``labels`` are sent as attribute values.

        Every contact is validated before the first request. An invalid contact
        therefore raises without anything having been uploaded.

        :raises BadContacts: if any contact fails validation (nothing is sent).
        :raises ContactsValidationFailed: if the API rejects a chunk. Chunks
            sent before the rejected one remain uploaded.
        """
        chunk_size = 1000
        main_fields = ("email", "phone", "client_ids")

        payloads = []
        for contact in contacts:
            if "id" not in contact or not any(
                contact.get(field) for field in main_fields
            ):
                raise BadContacts

            payloads.append(
                {
                    "uniq_id": contact["id"],
                    "emails": [contact["email"]] if contact.get("email") else [],
                    "phones": [contact["phone"]] if contact.get("phone") else [],
                    # Same falsy -> [] normalisation as emails/phones, so an
                    # explicit ``client_ids=None`` is not sent as null.
                    "client_ids": contact.get("client_ids") or [],
                    "attribute_values": {
                        "biz_id": biz_id,
                        "segments": contact.get("segments", []),
                        "labels": contact.get("labels", []),
                    },
                }
            )

        for start in range(0, len(payloads), chunk_size):
            await self._contacts_upload(
                counter_id, payloads[start:start + chunk_size]
            )

    @collect_errors
    async def create_segment(
        self, counter_id: int, segment_name: str, *, filtering_params: dict
    ) -> dict:
        """Create a contact segment filtered by custom attribute values.

        Each ``name: value`` pair in ``filtering_params`` becomes an equality
        condition on the custom attribute ``name``, and the conditions are
        joined with ``AND``. The attribute's ``(type_name, multivalued)`` pair
        comes from the counter's schema. It selects the filter prefix and the
        value formatting from ``attr_type_map``.

        :returns: the ``segment`` object from the API response, with an added
            ``size`` key taken from ``get_segment_size``.
        :raises BadFilteringParams: if ``filtering_params`` is empty, names an
            unknown attribute or an unsupported attribute kind, or holds a
            value of the wrong type (``bool`` is never accepted).
        """
        # Fail fast, before any HTTP call, when there is nothing to filter by.
        if not filtering_params:
            raise BadFilteringParams

        attributes = {
            attr["name"]: (attr["type_name"], attr["multivalued"])
            for attr in await self._list_contacts_attributes(counter_id=counter_id)
        }

        filters = []
        for name, value in filtering_params.items():
            try:
                attr_descr = attr_type_map[attributes[name]]
            except KeyError:
                raise BadFilteringParams from None

            # ``bool`` is a subclass of ``int``. Without the explicit check,
            # True/False would pass for a NUMERIC attribute and be rendered as
            # "True"/"False" in the filter.
            if isinstance(value, bool) or not isinstance(
                value, attr_descr.python_type
            ):
                raise BadFilteringParams

            # NOTE(review): string values are wrapped in single quotes without
            # escaping, so a value containing "'" produces a malformed filter.
            # Confirm the API's escaping rules.
            filters.append(
                "{spec}_{name}=={value}".format(
                    spec=attr_descr.filter_spec,
                    name=name,
                    value=attr_descr.format_str.format(value),
                )
            )

        resp = await self.request(
            method="POST",
            uri=f"/cdp/internal/v1/counter/{counter_id}/segment",
            json={
                "segment": {
                    "filter": " AND ".join(filters),
                    "name": segment_name,
                }
            },
            headers={"content-type": "application/x-yametrika+json"},
            expected_statuses=[200],
            metric_name="POST_/cdp/internal/v1/counter/{counter_id}/segment",
        )

        result = resp["segment"]
        result["size"] = await self.get_segment_size(
            counter_id=counter_id, segment_id=result["segment_id"]
        )

        return result

    async def _list_contacts_attributes(self, counter_id: int) -> List[dict]:
        """Return the ``custom_attributes`` list from the counter's schema."""
        # NOTE(review): entity_type is "CONTACT" here but "contact" in
        # create_contacts_schema. Confirm the API is case-insensitive.
        resp = await self.request(
            method="GET",
            uri=f"/cdp/internal/v1/counter/{counter_id}/schema/attributes",
            params={"entity_type": "CONTACT"},
            expected_statuses=[200],
            metric_name="GET_/cdp/internal/v1/counter/{counter_id}/schema/attributes",
        )

        return resp["custom_attributes"]

    @collect_errors
    async def get_segment_size(self, counter_id: int, segment_id: int) -> int:
        """Return the ``segment_size`` the API reports for the segment."""
        resp = await self.request(
            method="GET",
            uri=f"/cdp/internal/v1/counter/{counter_id}/segment/{segment_id}/size",
            expected_statuses=[200],
            metric_name="GET_/cdp/internal/v1/counter/{counter_id}/segment/{segment_id}/size",
        )

        return resp["segment_size"]

    async def _contacts_upload(self, counter_id: int, contacts: List[dict]) -> None:
        """POST one batch of already-formatted contacts with ``merge_mode=SAVE``.

        :raises ContactsValidationFailed: if the response's
            ``uploading.api_validation_status`` is not ``"PASSED"``.
        """
        resp = await self.request(
            method="POST",
            uri=f"/cdp/internal/v1/counter/{counter_id}/data/contacts",
            params={"merge_mode": "SAVE"},
            expected_statuses=[200, 201],
            json={"contacts": contacts},
            metric_name="POST_/cdp/internal/v1/counter/{counter_id}/data/contacts",
        )

        if resp["uploading"]["api_validation_status"] != "PASSED":
            raise ContactsValidationFailed

    async def _handle_custom_errors(self, response: aiohttp.ClientResponse) -> None:
        """Raise ``ApiAccessDenied`` on HTTP 403; otherwise defer to the parent."""
        if response.status == 403:
            raise ApiAccessDenied

        await super()._handle_custom_errors(response)


class AttributeDescr(NamedTuple):
    """How a segment filter is rendered for one kind of contact attribute."""

    # Prefix placed before "_<attribute name>" in the filter expression.
    filter_spec: str
    # Type that a filter value for this attribute must be an instance of.
    python_type: type
    # ``str.format`` template applied to the value (e.g. to add quotes).
    format_str: str


# Maps a schema attribute's (type_name, multivalued) pair to the way
# ``CdpClient.create_segment`` renders a filter on it. Attribute kinds missing
# from this table make create_segment raise BadFilteringParams.
attr_type_map = {
    # Single-valued attributes.
    ("NUMERIC", False): AttributeDescr("cdp:cn:attrNum", int, "{}"),
    ("TEXT", False): AttributeDescr("cdp:cn:attrStr", str, "'{}'"),
    # Multi-valued attributes.
    ("NUMERIC", True): AttributeDescr("cdp:cn:multiAttrNum", int, "{}"),
    ("TEXT", True): AttributeDescr("cdp:cn:multiAttrStr", str, "'{}'"),
}
