from datetime import datetime, timezone
from math import ceil
from typing import List, Optional, Tuple

import aiohttp
import marshmallow
from google.protobuf.timestamp_pb2 import Timestamp
from smb.common.http_client import HttpClient, collect_errors

from maps_adv.geosmb.clients.market.proto import common_pb2, orders_pb2

from .exceptions import BadMarketIntResponse, InvalidDatetimeTZ
from .schemas import (
    MAP_PROTO_ERROR_TO_EXCEPTION,
    ActualOrdersOutputSchema,
    FilterServiceCategoriesParametersSchema,
    FilterServicesParametersSchema,
    ServiceCategoryListSchema,
    ServiceListSchema,
)


class MarketIntClient(HttpClient):
    """HTTP client for the market-int service.

    Request and response bodies are protobuf-encoded. The list endpoints
    (services, service categories) are read page by page.
    """

    # Page size requested from the paginated endpoints.
    DEFAULT_LIMIT = 500
    # NOTE(review): presumably the per-request timeout consumed by HttpClient,
    # in seconds — confirm against the base class.
    request_timeout = 5

    @collect_errors
    async def fetch_services(
        self,
        *,
        biz_id: int,
        ordering: Optional[List[dict]] = None,
        filtering: Optional[dict] = None,
    ) -> List[dict]:
        """
        Collects all services of a business, walking through every result page.
        :param biz_id: Business id
        :param ordering: List of ordering params. Can be empty. Each is a dict with params:
            :param type: Required. OrderType enum
            :param field: Required. FilterServiceOrderingFields enum
        :param filtering: Optional. Filtering params. Described as dict with params:
            :param query: Optional. String.
            :param category_ids: List of ints. Can be empty.
            :param statuses: List of ServiceStatus enums. Can be empty.
            :param include_without_categories: Optional. Bool.
        :return: List of dicts with services info.
        """
        page_size = self.DEFAULT_LIMIT
        # Everything except paging stays the same from page to page.
        common_params = dict(biz_id=biz_id, ordering=ordering, filtering=filtering)

        # The first page doubles as the source of the total item count.
        paging, first_page = await self._fetch_services(
            limit=page_size, offset=0, **common_params
        )
        services = list(first_page)

        page_count = ceil(paging["total"] / page_size)
        for page_number in range(1, page_count):
            _, next_page = await self._fetch_services(
                limit=page_size, offset=page_number * page_size, **common_params
            )
            services += next_page

        return services

    async def _fetch_services(
        self,
        *,
        biz_id: int,
        limit: int,
        offset: int,
        ordering: Optional[List[dict]] = None,
        filtering: Optional[dict] = None,
    ) -> Tuple[dict, List[dict]]:
        """
        Requests a single page of services from /v1/filter_services.
        :param biz_id: Business id
        :param limit: Page size sent in the paging section of the request.
        :param offset: Offset sent in the paging section of the request.
        :param ordering: Optional. Ordering params; an empty list is sent when None.
        :param filtering: Optional. Filtering params, passed to the schema as is.
        :return: Tuple of (paging, items) taken from the decoded response.
        :raises BadMarketIntResponse: if the response fails schema validation.
        """
        input_data = dict(
            biz_id=biz_id,
            paging=dict(limit=limit, offset=offset),
            ordering=ordering if ordering is not None else [],
            filtering=filtering,
        )

        got = await self.request(
            method="POST",
            uri="/v1/filter_services",
            expected_statuses=[200],
            headers=self._make_headers(),
            data=FilterServicesParametersSchema().to_bytes(input_data),
            metric_name="/v1/filter_services",
        )

        # Guard only the decoding step: it is the one place a ValidationError
        # can come from, and the original error is kept as the explicit cause.
        try:
            result = ServiceListSchema().from_bytes(got)
        except marshmallow.ValidationError as e:
            raise BadMarketIntResponse(e.messages) from e

        # TODO: remove this when market-int starts appending postfixes (GEOSMB-2777)
        self._add_avatars_image_postfix(result["items"])
        return result["paging"], result["items"]

    @collect_errors
    async def fetch_service_categories(
        self,
        *,
        biz_id: int,
        ordering: Optional[List[dict]] = None,
        query: Optional[str] = None,
    ) -> dict:
        """
        Collects all service categories of a business, walking through every result page.
        :param biz_id: Business id
        :param query: Optional. String.
        :param ordering: List of ordering params. Can be empty.
            Each is a dict with params:
            :param type: Required. OrderType enum
            :param field: Required. FilterServiceCategoriesOrderingFields enum
        :return: {category_id: category_name}
        """
        page_size = self.DEFAULT_LIMIT
        # Everything except paging stays the same from page to page.
        common_params = dict(biz_id=biz_id, ordering=ordering, query=query)

        # The first page doubles as the source of the total item count.
        paging, first_page = await self._fetch_service_categories(
            limit=page_size, offset=0, **common_params
        )
        categories = {item["category_id"]: item["name"] for item in first_page}

        page_count = ceil(paging["total"] / page_size)
        for page_number in range(1, page_count):
            _, next_page = await self._fetch_service_categories(
                limit=page_size, offset=page_number * page_size, **common_params
            )
            # A category id seen again on a later page overwrites the earlier name.
            for item in next_page:
                categories[item["category_id"]] = item["name"]

        return categories

    async def _fetch_service_categories(
        self,
        *,
        biz_id: int,
        ordering: Optional[List[dict]] = None,
        limit: int,
        offset: int,
        query: Optional[str] = None,
    ) -> Tuple[dict, List[dict]]:
        """
        Requests a single page of categories from /v1/filter_service_categories.
        :param biz_id: Business id
        :param ordering: Optional. Ordering params; an empty list is sent when None.
        :param limit: Page size sent in the paging section of the request.
        :param offset: Offset sent in the paging section of the request.
        :param query: Optional. String, passed to the schema as is.
        :return: Tuple of (paging, items) taken from the decoded response.
        :raises BadMarketIntResponse: if the response fails schema validation.
        """
        input_data = dict(
            biz_id=biz_id,
            paging=dict(limit=limit, offset=offset),
            ordering=ordering if ordering is not None else [],
            query=query,
        )

        got = await self.request(
            method="POST",
            uri="/v1/filter_service_categories",
            expected_statuses=[200],
            headers=self._make_headers(),
            data=FilterServiceCategoriesParametersSchema().to_bytes(input_data),
            metric_name="/v1/filter_service_categories",
        )

        # Guard only the decoding step: it is the one place a ValidationError
        # can come from, and the original error is kept as the explicit cause.
        try:
            result = ServiceCategoryListSchema().from_bytes(got)
        except marshmallow.ValidationError as e:
            raise BadMarketIntResponse(e.messages) from e

        return result["paging"], result["items"]

    @collect_errors
    async def fetch_actual_orders(self, actual_on: datetime) -> List[dict]:
        """
        Returns orders reported by /v1/fetch_actual_orders for the given moment.
        :param actual_on: Timezone-aware datetime whose UTC offset is zero.
        :return: The "orders" list of the decoded response.
        :raises InvalidDatetimeTZ: if actual_on is naive or has a non-zero UTC offset.
        """
        # Compare offsets rather than tzinfo objects, so any UTC implementation
        # is accepted, not only datetime.timezone.utc. Naive datetimes report
        # no offset at all and are still rejected.
        utc_offset = actual_on.utcoffset()
        if utc_offset is None or utc_offset.total_seconds() != 0:
            raise InvalidDatetimeTZ(
                f"actual_on param must have UTC tz: {actual_on.isoformat()}"
            )

        response_body = await self.request(
            method="POST",
            uri="/v1/fetch_actual_orders",
            expected_statuses=[200],
            headers=self._make_headers(),
            data=orders_pb2.ActualOrdersInput(
                # int() truncates the sub-second part of the timestamp.
                actual_on=Timestamp(seconds=int(actual_on.timestamp()))
            ).SerializeToString(),
            metric_name="/v1/fetch_actual_orders",
        )

        # NOTE(review): unlike the services/categories fetchers, decoding errors
        # are not converted to BadMarketIntResponse here — confirm it is intended.
        return ActualOrdersOutputSchema().from_bytes(response_body)["orders"]

    async def _handle_custom_errors(self, response: aiohttp.ClientResponse) -> None:
        """
        Converts an error response into an exception.

        A 400 response body is decoded as a protobuf Error and the exception
        class mapped to its code is raised. Any other status, and a 400 whose
        code has no mapping, is passed to _raise_unknown_response.
        :param response: Response to turn into an exception.
        """
        if response.status != 400:
            await self._raise_unknown_response(response)
            return

        response_body = await response.read()
        error = common_pb2.Error.FromString(response_body)

        try:
            exception_cls = MAP_PROTO_ERROR_TO_EXCEPTION[error.code]
        except KeyError:
            # An unmapped code would otherwise escape as a bare KeyError and
            # hide the response; report it through the generic path instead.
            await self._raise_unknown_response(response)
            return

        raise exception_cls(error)

    def _make_headers(self) -> dict:
        return {"Content-Type": "application/x-protobuf"}

    @staticmethod
    def _add_avatars_image_postfix(items: List[dict]) -> None:
        """Add %s postfix to avatar url templates if missing"""
        for item in items:
            try:
                if not item["main_image_url_template"].endswith("/%s"):
                    item["main_image_url_template"] += "/%s"
            except KeyError:
                continue
