import json
import logging

from aiohttp import web
from marshmallow import ValidationError

from maps_adv.billing_proxy.lib.api.api_providers.exceptions import (
    BadDtParameter,
    exception_error_codes,
)
from maps_adv.billing_proxy.lib.core.balance_client import BalanceApiError
from maps_adv.billing_proxy.lib.domain.exceptions import (
    DomainException,
    WrongBalanceServiceID,
)
from maps_adv.billing_proxy.proto.common_pb2 import Error

logger = logging.getLogger(__name__)


@web.middleware
async def handle_domain_exception(request, handler):
    """Convert a ``DomainException`` raised by *handler* into a 422 response.

    The response body is a serialized ``Error`` protobuf. Its ``code`` comes from
    ``exception_error_codes``, looked up by the exact class of the raised
    exception. Its ``description`` is the exception's ``context_as_str``.

    If the exception's class has no entry in ``exception_error_codes``, the
    original exception is re-raised unchanged. Outer middlewares (or aiohttp
    itself) then see the real error, not a ``KeyError`` leaking out of the
    lookup below.
    """
    try:
        return await handler(request)
    except DomainException as exc:
        logger.exception(exc)
        exc_type = type(exc)
        if exc_type not in exception_error_codes:
            # Unmapped domain exception: propagate the original error instead
            # of masking it with a KeyError raised from inside this handler.
            raise
        error = Error(
            code=exception_error_codes[exc_type].value,
            description=exc.context_as_str,
        )
        return web.Response(status=422, body=error.SerializeToString())


@web.middleware
async def handle_balance_api_exception(request, handler):
    """Translate a ``BalanceApiError`` raised by *handler* into a 503 response.

    The failure is logged with its traceback. The client receives a serialized
    ``Error`` protobuf with code ``BALANCE_API_ERROR`` and a fixed, generic
    description.
    """
    try:
        response = await handler(request)
    except BalanceApiError as balance_error:
        logger.exception("Balance api error: %s", balance_error)

        payload = Error(
            code=Error.BALANCE_API_ERROR, description="Balance API error"
        ).SerializeToString()
        return web.Response(status=503, body=payload)
    else:
        return response


@web.middleware
async def handle_validation_exception(request, handler):
    """Translate a marshmallow ``ValidationError`` into a 400 response.

    The normalized validation messages are logged at error level. The client
    receives a serialized ``Error`` protobuf with code
    ``DATA_VALIDATION_ERROR``. Its description is the JSON-encoded
    ``messages`` of the validation error.
    """
    try:
        response = await handler(request)
    except ValidationError as validation_error:
        logger.error(
            "Serialization error: %s", validation_error.normalized_messages()
        )
        description = json.dumps(validation_error.messages)
        payload = Error(code=Error.DATA_VALIDATION_ERROR, description=description)
        return web.Response(status=400, body=payload.SerializeToString())
    else:
        return response


@web.middleware
async def handle_wrong_balance_service_id(request, handler):
    """Translate ``WrongBalanceServiceID`` into a 400 JSON response.

    Unlike the protobuf-based handlers, this one replies with a JSON body of the
    form ``{"error": "Wrong service id"}``. The exception is logged with its
    traceback.
    """
    try:
        response = await handler(request)
    except WrongBalanceServiceID as service_id_error:
        logger.exception(service_id_error)
        return web.json_response(data={"error": "Wrong service id"}, status=400)
    else:
        return response


@web.middleware
async def handle_invalid_reconciliation_params(request, handler):
    """Translate ``BadDtParameter`` into an empty-bodied 400 response.

    The exception is logged with its traceback before the response is returned.
    """
    try:
        response = await handler(request)
    except BadDtParameter as dt_error:
        logger.exception(dt_error)
        return web.Response(status=400)
    else:
        return response
