from functools import wraps
from typing import Optional, Type

from marshmallow import Schema, ValidationError, fields

from maps_adv.geosmb.marksman.server.lib.domain import Domain

# Names exported by ``from <this module> import *``: only ``ApiProvider``.
__all__ = [
    "ApiProvider",
]


class OutputValidationError(Exception):
    """Raised by ``with_schemas`` when dumping a method's result fails validation.

    The original ``ValidationError`` is kept on ``inner_exception`` and is also
    the exception's single positional argument, so it appears in ``str(exc)``.
    """

    def __init__(self, inner_exception: "ValidationError"):
        # Forward the inner exception to Exception.__init__ so ``args``/``str()``
        # are populated and the exception survives pickling (unpickling calls
        # ``cls(*args)``; with empty ``args`` that would fail with TypeError).
        super().__init__(inner_exception)
        self.inner_exception = inner_exception


def with_schemas(
    input_schema: Optional[Type["Schema"]] = None,
    output_schema: Optional[Type["Schema"]] = None,
):
    """Wrap an async method with schema-based input loading and output dumping.

    The decorated coroutine is called as ``method(self, data=None, **kwargs)``.

    Args:
        input_schema: Schema class used to *load* (deserialize) ``data``. The
            loaded fields are merged into ``kwargs`` (overriding keys with the
            same name) before the wrapped method is called. When ``None``,
            ``data`` is ignored.
        output_schema: Schema class used to *dump* (serialize) the wrapped
            method's return value. When ``None``, the return value is passed
            through unchanged.

    Raises:
        RuntimeError: ``input_schema`` is set but ``data`` is ``None``.
        OutputValidationError: ``output_schema().dump`` raised
            ``ValidationError``; the original error is chained as the cause.

    Errors raised by ``input_schema().load`` are not caught here and propagate
    to the caller as-is.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(s, data=None, **kwargs):
            if input_schema is not None:
                if data is None:
                    raise RuntimeError("This api_provider requires input data")
                loaded = input_schema().load(data).data
                kwargs.update(loaded)

            got = await func(s, **kwargs)
            if output_schema is None:
                # No output schema: hand the raw result back instead of dropping it.
                return got
            try:
                return output_schema().dump(got).data
            except ValidationError as exc:
                raise OutputValidationError(exc) from exc

        return wrapper

    return decorator


class StrictSchema(Schema):
    """Base schema whose instances always have ``strict`` enabled.

    In marshmallow 2.x, a strict schema raises ``ValidationError`` from
    ``load``/``dump`` instead of only recording errors on the result object.
    """

    def __init__(self, *schema_args, **schema_kwargs):
        super(StrictSchema, self).__init__(*schema_args, **schema_kwargs)
        # Assigned after the base initializer so it overrides whatever value
        # ``Schema.__init__`` set (e.g. from a ``strict`` argument or ``Meta``).
        self.strict = True


class AddBusinessInputSchema(StrictSchema):
    """Input payload loaded for ``ApiProvider.add_business``."""

    biz_id = fields.Int(required=True)


class SegmentDataSchema(StrictSchema):
    """One item of ``ListBusinessSegmentsData.segments``; all fields are required."""

    segment_name = fields.Str(required=True)
    cdp_id = fields.Int(required=True)
    cdp_size = fields.Int(required=True)


class LabelDataSchema(StrictSchema):
    """One item of ``ListBusinessSegmentsData.labels``; all fields are required."""

    label_name = fields.Str(required=True)
    cdp_id = fields.Int(required=True)
    cdp_size = fields.Int(required=True)


class ListBusinessSegmentsData(StrictSchema):
    """Output schema dumped by ``ApiProvider.list_business_segments_data``.

    The identifier fields are required; ``segments`` and ``labels`` are
    optional lists of nested ``SegmentDataSchema`` / ``LabelDataSchema`` items.
    """

    biz_id = fields.Int(required=True)
    permalink = fields.Int(required=True)
    counter_id = fields.Int(required=True)
    segments = fields.Nested(SegmentDataSchema, many=True)
    labels = fields.Nested(LabelDataSchema, many=True)


class ApiProvider:
    """Schema-validating async front end that delegates to a ``Domain``.

    Each public method is wrapped by ``with_schemas``, so callers invoke it as
    ``method(data=None, **kwargs)``.
    """

    _domain: Domain

    def __init__(self, domain: Domain):
        self._domain = domain

    @with_schemas(input_schema=AddBusinessInputSchema)
    async def add_business(self, **kwargs):
        """Forward the loaded ``AddBusinessInputSchema`` fields to the domain.

        The domain call's result is not returned.
        """
        domain = self._domain
        await domain.add_business(**kwargs)

    @with_schemas(output_schema=ListBusinessSegmentsData)
    async def list_business_segments_data(self, **kwargs):
        """Return the domain's result, dumped with ``ListBusinessSegmentsData``."""
        segments_data = await self._domain.list_business_segments_data(**kwargs)
        return segments_data
