from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

import aiohttp
from google.protobuf.timestamp_pb2 import Timestamp
from marshmallow import fields, post_load
from smb.common.http_client import collect_errors

from smb.common.aiotvm import HttpClientWithTvm
from maps_adv.common.protomallow import PbDateTimeField, PbEnumField, ProtobufSchema
from maps_adv.geosmb.doorman.proto import (
    clients_pb2,
    common_pb2,
    errors_pb2,
    events_pb2,
    statistics_pb2,
)
from maps_adv.geosmb.doorman.proto.segments_pb2 import SegmentType as SegmentTypePb

from .enums import (
    ENUM_TO_PROTO_MAP,
    PROTO_TO_ENUM_MAP,
    ClientGender,
    OrderEvent,
    SegmentType,
    Source,
)
from .exceptions import BadRequest, Conflict, NotFound, UnexpectedNaiveDatetime


class OrderStatisticsSchema(ProtobufSchema):
    """Order counters of a client (``statistics_pb2.OrderStatistics``)."""

    class Meta:
        pb_message_class = statistics_pb2.OrderStatistics

    # Order counters.
    total = fields.Int()
    successful = fields.Int()
    unsuccessful = fields.Int()
    # Protobuf timestamp of the last order, converted by PbDateTimeField.
    last_order_timestamp = PbDateTimeField()


class ClientStatisticsSchema(ProtobufSchema):
    """Aggregated client statistics (``statistics_pb2.ClientStatistics``)."""

    # Nested order counters, see OrderStatisticsSchema.
    orders = fields.Nested(OrderStatisticsSchema)

    class Meta:
        pb_message_class = statistics_pb2.ClientStatistics


class ClientDataSchema(ProtobufSchema):
    """Client record (``clients_pb2.ClientData``).

    Used to parse the responses of the create/retrieve/update client calls
    and, nested, the client list.
    """

    class Meta:
        pb_message_class = clients_pb2.ClientData

    # Identity.
    id = fields.Int()
    biz_id = fields.Int()
    # Contact and profile data.
    phone = fields.Int()
    email = fields.Str()
    passport_uid = fields.Int()
    first_name = fields.Str()
    last_name = fields.Str()
    gender = PbEnumField(
        enum=ClientGender,
        pb_enum=common_pb2.ClientGender,
        values_map=PROTO_TO_ENUM_MAP["gender"],
    )
    comment = fields.Str()
    # Each protobuf SegmentType value is mapped onto the local SegmentType enum.
    segments = fields.List(
        PbEnumField(
            enum=SegmentType,
            pb_enum=SegmentTypePb,
            values_map=PROTO_TO_ENUM_MAP["segment_type"],
        )
    )
    statistics = fields.Nested(ClientStatisticsSchema)
    source = PbEnumField(
        enum=Source,
        pb_enum=common_pb2.Source,
        values_map=PROTO_TO_ENUM_MAP["source"],
    )
    registration_timestamp = PbDateTimeField()
    labels = fields.List(fields.Str())


class ClientsListOutputSchema(ProtobufSchema):
    """One page of clients plus the overall match count."""

    class Meta:
        pb_message_class = clients_pb2.ClientsListOutput

    # Page content.
    clients = fields.Nested(ClientDataSchema, many=True)
    # Count reported by the service alongside the page.
    total_count = fields.Int()


class ClientContactsSchema(ProtobufSchema):
    """Contact data of a single client (``clients_pb2.ClientContacts``)."""

    class Meta:
        pb_message_class = clients_pb2.ClientContacts

    # Mandatory identifiers.
    id = fields.Int(required=True)
    biz_id = fields.Int(required=True)
    # Optional contact data.
    passport_uid = fields.Int()
    phone = fields.Int()
    email = fields.Str()
    first_name = fields.Str()
    last_name = fields.Str()
    cleared_for_gdpr = fields.Bool()


class ClientContactsListSchema(ProtobufSchema):
    """Parses ``ClientContactsList`` into a dict keyed by client id.

    Each value is the client's contacts dict (see ``ClientContactsSchema``)
    with the ``id`` key removed.
    """

    class Meta:
        pb_message_class = clients_pb2.ClientContactsList

    clients = fields.Nested(ClientContactsSchema, many=True)

    @post_load
    def to_dict(self, data: dict, **kwargs) -> dict:
        """Re-key the loaded client list by client id.

        ``**kwargs`` absorbs the ``many``/``partial`` arguments that
        marshmallow 3 passes to processors (marshmallow 2 passes none).
        """
        # An empty repeated field may be absent from ``data`` depending on
        # how ProtobufSchema serializes defaults, so fall back to no clients.
        return {client.pop("id"): client for client in data.get("clients", [])}


class ClearedClientSchema(ProtobufSchema):
    """One client affected by the GDPR cleanup (business id + client id)."""

    class Meta:
        pb_message_class = clients_pb2.ClearClientsForGdprOutput.ClearedClient

    biz_id = fields.Int(required=True)
    client_id = fields.Int(required=True)


class ClearClientsForGdprOutputSchema(ProtobufSchema):
    """Response of the GDPR cleanup call: the list of cleared clients."""

    class Meta:
        pb_message_class = clients_pb2.ClearClientsForGdprOutput

    cleared_clients = fields.List(
        fields.Nested(ClearedClientSchema),
        required=True,
    )


class SearchClientsForGdprOutputSchema(ProtobufSchema):
    """Response of the GDPR search call: a single existence flag."""

    class Meta:
        pb_message_class = clients_pb2.SearchClientsForGdprOutput

    clients_exist = fields.Bool(required=True)


class BulkCreateClientsOutputSchema(ProtobufSchema):
    """Result counters of a bulk client creation."""

    class Meta:
        pb_message_class = clients_pb2.BulkCreateClientsOutput

    # How many clients were newly created vs. merged into existing ones.
    total_created = fields.Int()
    total_merged = fields.Int()


class DoormanClient(HttpClientWithTvm):
    """HTTP client for the doorman service.

    Every call POSTs a serialized protobuf message and parses the response
    with the schemas declared in this module. ``_handle_custom_errors``
    converts 400/404/409 responses into ``BadRequest``/``NotFound``/
    ``Conflict`` (presumably invoked by ``HttpClientWithTvm`` for responses
    outside ``expected_statuses`` — confirm against the base class).
    """

    @collect_errors
    async def create_client(
        self,
        *,
        biz_id: int,
        source: Source,
        metadata: Optional[dict] = None,
        phone: Optional[int] = None,
        email: Optional[str] = None,
        passport_uid: Optional[int] = None,
        first_name: Optional[str] = None,
        last_name: Optional[str] = None,
        gender: Optional[ClientGender] = None,
        comment: Optional[str] = None,
        initiator_id: Optional[int] = None,
    ) -> dict:
        """Create a client of business ``biz_id``.

        ``metadata`` holds extra ``SourceMetadata`` fields sent alongside
        ``source``. Returns the created client parsed by ``ClientDataSchema``.
        """
        request_pb = self._make_setup_data(
            biz_id=biz_id,
            source=source,
            metadata=metadata,
            phone=phone,
            email=email,
            passport_uid=passport_uid,
            first_name=first_name,
            last_name=last_name,
            gender=gender,
            comment=comment,
            initiator_id=initiator_id,
        )

        response_body = await self.request(
            method="POST",
            uri="/v1/create_client/",
            expected_statuses=[201],
            data=request_pb.SerializeToString(),
            metric_name="/v1/create_client/",
        )

        return ClientDataSchema().from_bytes(response_body)

    @collect_errors
    async def create_clients(
        self,
        *,
        biz_id: int,
        source: Source,
        label: Optional[str] = None,
        clients: List[dict],
    ) -> Tuple[int, int]:
        """Creates clients by bulk. Each item in `clients` is dict
            with optional fields:
                - first_name="Ivan",
                - last_name="Volkov",
                - phone=1111111111,
                - email="ivan@yandex.ru",
                - comment="ivan comment",
        returns: (total_created, total_merged)
        """
        request_pb = clients_pb2.BulkCreateClientsInput(
            biz_id=biz_id,
            source=ENUM_TO_PROTO_MAP["source"][source],
            label=label,
            clients=[
                clients_pb2.BulkCreateClientsInput.BulkClient(**client)
                for client in clients
            ],
        )

        response_body = await self.request(
            method="POST",
            uri="/v1/create_clients/",
            expected_statuses=[201],
            data=request_pb.SerializeToString(),
            metric_name="/v1/create_clients/",
        )

        response = BulkCreateClientsOutputSchema().from_bytes(response_body)

        return response["total_created"], response["total_merged"]

    @collect_errors
    async def retrieve_client(self, *, biz_id: int, client_id: int) -> dict:
        """Fetch one client; returns it parsed by ``ClientDataSchema``."""
        request_pb = clients_pb2.ClientRetrieveInput(biz_id=biz_id, id=client_id)

        response_body = await self.request(
            method="POST",
            uri="/v1/retrieve_client/",
            expected_statuses=[200],
            data=request_pb.SerializeToString(),
            metric_name="/v1/retrieve_client/",
        )

        return ClientDataSchema().from_bytes(response_body)

    @collect_errors
    async def list_clients(
        self,
        *,
        biz_id: int,
        search_string: Optional[str] = None,
        limit: int,
        offset: int,
    ) -> Tuple[List[dict], int]:
        """List clients of ``biz_id`` page by page.

        Returns ``(clients, total_count)`` as reported by the service.
        """
        request_pb = clients_pb2.ClientsListInput(
            biz_id=biz_id,
            search_string=search_string,
            pagination=common_pb2.Pagination(limit=limit, offset=offset),
        )

        response_body = await self.request(
            method="POST",
            uri="/v1/list_clients/",
            expected_statuses=[200],
            data=request_pb.SerializeToString(),
            metric_name="/v1/list_clients/",
        )

        result = ClientsListOutputSchema().from_bytes(response_body)

        return result["clients"], result["total_count"]

    @collect_errors
    async def list_contacts(self, client_ids: Iterable[int]) -> Dict[int, dict]:
        """Fetch contacts for ``client_ids``; returns a dict keyed by client id."""
        response_body = await self.request(
            method="POST",
            uri="/v1/list_contacts/",
            expected_statuses=[200],
            data=clients_pb2.ListContactsInput(
                client_ids=client_ids
            ).SerializeToString(),
            metric_name="/v1/list_contacts/",
        )

        return ClientContactsListSchema().from_bytes(response_body)

    @collect_errors
    async def update_client(
        self,
        *,
        client_id: int,
        biz_id: int,
        source: Source,
        metadata: Optional[dict] = None,
        phone: Optional[int] = None,
        email: Optional[str] = None,
        passport_uid: Optional[int] = None,
        first_name: Optional[str] = None,
        last_name: Optional[str] = None,
        gender: Optional[ClientGender] = None,
        comment: Optional[str] = None,
        initiator_id: Optional[int] = None,
    ) -> dict:
        """Update client ``client_id``; returns it parsed by ``ClientDataSchema``.

        Takes the same payload fields as ``create_client``.
        """
        request_pb = clients_pb2.ClientUpdateData(
            id=client_id,
            data=self._make_setup_data(
                biz_id=biz_id,
                source=source,
                metadata=metadata,
                phone=phone,
                email=email,
                passport_uid=passport_uid,
                first_name=first_name,
                last_name=last_name,
                gender=gender,
                comment=comment,
                initiator_id=initiator_id,
            ),
        )

        response_body = await self.request(
            method="POST",
            uri="/v1/update_client/",
            expected_statuses=[200],
            data=request_pb.SerializeToString(),
            metric_name="/v1/update_client/",
        )

        return ClientDataSchema().from_bytes(response_body)

    @collect_errors
    async def add_order_event(
        self,
        *,
        biz_id: int,
        client_id: int,
        event_type: OrderEvent,
        event_timestamp: datetime,
        source: Source,
        order_id: int,
    ) -> None:
        """Record an order event for a client.

        Raises:
            UnexpectedNaiveDatetime: if ``event_timestamp`` has no tzinfo.
        """
        if event_timestamp.tzinfo is None:
            raise UnexpectedNaiveDatetime(event_timestamp)

        await self._add_event(
            events_pb2.AddEventInput(
                biz_id=biz_id,
                client_id=client_id,
                # Only whole seconds are sent; sub-second precision is dropped.
                timestamp=Timestamp(seconds=int(event_timestamp.timestamp())),
                source=ENUM_TO_PROTO_MAP["source"][source],
                order_event=events_pb2.OrderEvent(
                    type=ENUM_TO_PROTO_MAP["order_event_type"][event_type],
                    order_id=order_id,
                ),
            )
        )

    @collect_errors
    async def clear_clients_for_gdpr(self, passport_uid: int) -> List[dict]:
        """Clear the clients linked to ``passport_uid``.

        Returns the cleared clients as ``{"biz_id", "client_id"}`` dicts.
        """
        response_body = await self.request(
            method="POST",
            uri="/internal/v1/clear_clients_for_gdpr/",
            expected_statuses=[200],
            data=clients_pb2.ClearClientsForGdprInput(
                passport_uid=passport_uid
            ).SerializeToString(),
            metric_name="/internal/v1/clear_clients_for_gdpr/",
        )

        data = ClearClientsForGdprOutputSchema().from_bytes(response_body)
        return data["cleared_clients"]

    @collect_errors
    async def search_clients_for_gdpr(self, passport_uid: int) -> bool:
        """Return whether any client is linked to ``passport_uid``."""
        response_body = await self.request(
            method="POST",
            uri="/internal/v1/search_clients_for_gdpr/",
            expected_statuses=[200],
            data=clients_pb2.SearchClientsForGdprInput(
                passport_uid=passport_uid
            ).SerializeToString(),
            metric_name="/internal/v1/search_clients_for_gdpr/",
        )

        data = SearchClientsForGdprOutputSchema().from_bytes(response_body)
        return data["clients_exist"]

    @staticmethod
    def _make_setup_data(
        *,
        biz_id: int,
        source: Source,
        metadata: Optional[dict],
        phone: Optional[int],
        email: Optional[str],
        passport_uid: Optional[int],
        first_name: Optional[str],
        last_name: Optional[str],
        gender: Optional[ClientGender],
        comment: Optional[str],
        initiator_id: Optional[int],
    ) -> clients_pb2.ClientSetupData:
        """Build the ``ClientSetupData`` payload shared by create and update."""
        return clients_pb2.ClientSetupData(
            biz_id=biz_id,
            metadata=clients_pb2.SourceMetadata(
                source=ENUM_TO_PROTO_MAP["source"][source], **(metadata or {})
            ),
            phone=phone,
            email=email,
            passport_uid=passport_uid,
            first_name=first_name,
            last_name=last_name,
            # Compare with None explicitly: enum members must not be
            # filtered out by their truthiness.
            gender=(
                ENUM_TO_PROTO_MAP["gender"][gender] if gender is not None else None
            ),
            comment=comment,
            initiator_id=initiator_id,
        )

    async def _add_event(self, request_pb: events_pb2.AddEventInput) -> None:
        """POST an ``AddEventInput`` to the events endpoint (expects 201)."""
        await self.request(
            method="POST",
            uri="/v1/add_event/",
            expected_statuses=[201],
            data=request_pb.SerializeToString(),
            metric_name="/v1/add_event/",
        )

    async def _handle_custom_errors(self, response: aiohttp.ClientResponse):
        """Turn an error response into a domain exception.

        400/404/409 bodies are parsed as ``errors_pb2.Error`` and raised as
        ``BadRequest``/``NotFound``/``Conflict``; any other status is handed
        to ``_raise_unknown_response``.
        """
        exceptions_map = {400: BadRequest, 404: NotFound, 409: Conflict}
        # Consume the body only when it is parsed here, so an unmapped
        # response reaches ``_raise_unknown_response`` with its stream unread.
        if response.status in exceptions_map:
            self._raise_for_matched_exception(
                response.status, exceptions_map, await response.content.read()
            )
        await self._raise_unknown_response(response)

    @staticmethod
    def _raise_for_matched_exception(
        status: int, exceptions_map: dict, exception_body: bytes
    ) -> None:
        """Raise the exception mapped to ``status``; do nothing if unmapped.

        ``exception_body`` is decoded as ``errors_pb2.Error`` and its
        ``description`` is passed to the exception constructor.
        """
        # Narrow lookup instead of ``try/except KeyError`` around the raise,
        # which would swallow any mapped exception deriving from KeyError.
        exception_cls = exceptions_map.get(status)
        if exception_cls is None:
            return
        error_pb = errors_pb2.Error.FromString(exception_body)
        raise exception_cls(error_pb.description)
