from typing import List

import aiohttp
from marshmallow import fields
from smb.common.http_client import collect_errors
from smb.common.http_client.lib.exceptions import BaseHttpClientException

from smb.common.aiotvm import HttpClientWithTvm
from maps_adv.common.protomallow import ProtobufSchema
from maps_adv.geosmb.promoter.proto import errors_pb2, leads_pb2

# Export the client together with the exceptions it raises, so callers that
# rely on the public API can catch them.
__all__ = ["PromoterClient", "PromoterClientException", "BadRequest"]


class PromoterClientException(BaseHttpClientException):
    """Common base for the promoter-client-specific errors defined in this module."""


class BadRequest(PromoterClientException):
    """The promoter service answered with HTTP 400.

    The exception message is the ``description`` decoded from the
    ``errors_pb2.Error`` body of that response.
    """


class RemovedLeadSchema(ProtobufSchema):
    """Schema for one ``RemoveLeadsForGdprOutput.RemovedLead`` protobuf entry."""

    class Meta:
        pb_message_class = leads_pb2.RemoveLeadsForGdprOutput.RemovedLead

    # Both identifiers are mandatory on every removed-lead entry.
    biz_id = fields.Int(required=True)
    lead_id = fields.Int(required=True)


class RemoveLeadsForGdprOutputSchema(ProtobufSchema):
    """Schema for the ``RemoveLeadsForGdprOutput`` response message.

    Its only field, ``removed_leads``, is a required list whose items are
    parsed with ``RemovedLeadSchema``.
    """

    class Meta:
        pb_message_class = leads_pb2.RemoveLeadsForGdprOutput

    removed_leads = fields.Nested(
        RemovedLeadSchema,
        required=True,
        many=True,
    )


class PromoterClient(HttpClientWithTvm):
    """Client for the promoter service's internal GDPR endpoints.

    Request and response bodies are protobuf messages from ``leads_pb2``.
    Error responses with a mapped status (currently only HTTP 400) are
    converted to client exceptions carrying the ``description`` from the
    service's ``errors_pb2.Error`` body.
    """

    @collect_errors
    async def remove_leads_for_gdpr(self, *, passport_uid: int) -> List[dict]:
        """Ask the service to remove the leads linked to ``passport_uid``.

        Args:
            passport_uid: Passport user id whose leads should be removed.

        Returns:
            The ``removed_leads`` list deserialized by
            ``RemoveLeadsForGdprOutputSchema``. Each item carries ``biz_id``
            and ``lead_id``.

        Raises:
            BadRequest: If the service answers HTTP 400. This presumably
                happens via ``_handle_custom_errors``, called by the base
                client for unexpected statuses (confirm against
                ``HttpClientWithTvm``).
        """
        response_body = await self.request(
            method="POST",
            uri="/internal/v1/remove_leads_for_gdpr/",
            expected_statuses=[200],
            data=leads_pb2.RemoveLeadsForGdprInput(
                passport_uid=passport_uid
            ).SerializeToString(),
            metric_name="/internal/v1/remove_leads_for_gdpr/",
        )

        data = RemoveLeadsForGdprOutputSchema().from_bytes(response_body)
        return data["removed_leads"]

    @collect_errors
    async def search_leads_for_gdpr(self, *, passport_uid: int) -> bool:
        """Report whether the service holds any leads for ``passport_uid``.

        Args:
            passport_uid: Passport user id to look up.

        Returns:
            The ``leads_exist`` flag of the ``SearchLeadsForGdprOutput``
            response.
        """
        response_body = await self.request(
            method="POST",
            uri="/internal/v1/search_leads_for_gdpr/",
            expected_statuses=[200],
            # NOTE(review): the search request is built from the *Remove*
            # input message. Presumably both messages only carry
            # ``passport_uid``, so the wire bytes are identical. Confirm, and
            # switch to a dedicated search input message if leads_pb2 has one.
            data=leads_pb2.RemoveLeadsForGdprInput(
                passport_uid=passport_uid
            ).SerializeToString(),
            metric_name="/internal/v1/search_leads_for_gdpr/",
        )

        return leads_pb2.SearchLeadsForGdprOutput.FromString(response_body).leads_exist

    async def _handle_custom_errors(self, response: aiohttp.ClientResponse):
        """Convert an error response into a client exception.

        A status present in the local map (400 -> ``BadRequest``) raises the
        mapped exception, using the description decoded from the body. Any
        other status is delegated to ``_raise_unknown_response``.
        """
        exceptions_map = {400: BadRequest}
        # Consume the body only for statuses decoded here. Reading the stream
        # unconditionally could leave ``_raise_unknown_response`` with an
        # already-exhausted stream if it also reads the body.
        if response.status in exceptions_map:
            self._raise_for_matched_exception(
                response.status, exceptions_map, await response.content.read()
            )
        await self._raise_unknown_response(response)

    @staticmethod
    def _raise_for_matched_exception(
        status: int, exceptions_map: dict, exception_body: bytes
    ) -> None:
        """Raise ``exceptions_map[status]`` if ``status`` is mapped, else return.

        The exception is constructed with the ``description`` field of the
        ``errors_pb2.Error`` message decoded from ``exception_body``.
        """
        # Use a plain lookup rather than ``try/except KeyError`` around the
        # whole body, so that a KeyError raised while decoding or building
        # the exception is not silently swallowed.
        exception_cls = exceptions_map.get(status)
        if exception_cls is None:
            return
        error_pb = errors_pb2.Error.FromString(exception_body)
        raise exception_cls(error_pb.description)
