from collections import defaultdict

import ujson
from marshmallow import fields


class NestedWithKey(fields.Nested):
    """Nested list field whose external form is a mapping keyed by one nested field.

    On dump, every serialized nested item has ``key`` popped out and used as
    the mapping key:

    * ``many_per_key=False`` -> ``{key_value: item}``; an item with a repeated
      key overwrites the earlier one.
    * ``many_per_key=True``  -> ``{key_value: [item, ...]}``, grouping items
      that share a key.

    On load the transformation is reversed: each mapping key is written back
    into its item(s) under ``key`` and the flat list is handed to
    ``fields.Nested`` (which is always configured with ``many=True``).

    :param nested: Schema passed straight through to ``fields.Nested``.
    :param key: Name of the nested field used as the mapping key.
    :param many_per_key: Whether several items may share one key.
    """

    def __init__(self, nested, key, *args, many_per_key=False, **kwargs):
        self.key = key
        self.many_per_key = many_per_key
        # Always a list of nested objects; passing ``many`` explicitly raises
        # TypeError (duplicate keyword).
        super().__init__(nested, *args, many=True, **kwargs)

    def _serialize(self, nested_obj, attr, obj):
        nested_list = super()._serialize(nested_obj, attr, obj)
        # ``Nested._serialize`` may return None (e.g. for a None attribute);
        # propagate it instead of failing while iterating.
        if nested_list is None:
            return None
        if not self.many_per_key:
            return {item.pop(self.key): item for item in nested_list}
        grouped = defaultdict(list)
        for item in nested_list:
            grouped[item.pop(self.key)].append(item)
        # Hand back a plain dict so consumers don't get auto-inserting lookups.
        return dict(grouped)

    def _deserialize(self, value, attr, data):
        if not isinstance(value, dict):
            self.fail('type', input=value, type=value.__class__.__name__)

        if self.many_per_key:
            # Expect ``{key_value: [item_dict, ...]}``.
            groups_ok = all(
                isinstance(group, (list, tuple))
                and all(isinstance(item, dict) for item in group)
                for group in value.values()
            )
            if not groups_ok:
                self.fail('type', input=value, type=value.__class__.__name__)
            items = [
                {**item, self.key: k}
                for k, group in value.items()
                for item in group
            ]
        else:
            # Expect ``{key_value: item_dict}``.
            if not all(isinstance(v, dict) for v in value.values()):
                self.fail('type', input=value, type=value.__class__.__name__)
            items = [{**v, self.key: k} for k, v in value.items()]

        return super()._deserialize(value=items, attr=attr, data=data)


class QueryParamList(fields.List):
    """List field loaded from a repeated ``name[]`` query parameter.

    Unless the caller supplied an explicit ``load_from``, the field reads from
    ``<field_name>[]``. When the input mapping offers ``getall`` (presumably a
    ``multidict`` MultiDict holding the request query — confirm with callers),
    every occurrence of the parameter is collected. Otherwise the single value
    stored under the key is used.
    """

    def _add_to_schema(self, field_name, schema):
        super()._add_to_schema(field_name, schema)
        # Default to the bracketed name, but respect an explicit ``load_from``.
        # The check is also idempotent when this hook runs more than once.
        if not self.load_from:
            self.load_from = field_name + '[]'

    def _deserialize(self, value, attr, data):
        getall = getattr(data, 'getall', None)
        if isinstance(data, dict) or getall is None:
            # Plain mapping: re-read the value stored under ``attr``.
            value = data.get(attr)
        else:
            # Multi-value mapping: ``get`` would return only the first
            # occurrence, so collect all of them.
            try:
                value = getall(attr)
            except KeyError:
                self.fail('invalid')
        return super()._deserialize(value, attr, data)


class JSONNested(fields.Nested):
    """Nested field whose input arrives as a JSON-encoded string.

    ``str``/``bytes`` input is decoded with ``ujson`` before being loaded by
    ``fields.Nested``. Malformed JSON is reported through ``self.fail('type')``
    instead of escaping as a raw decoder exception. Input that is already
    decoded is handed to ``Nested`` unchanged so its own validation applies.

    Serialization is inherited unchanged from ``fields.Nested``, so dumped
    output is NOT re-encoded as a JSON string.
    """

    def _deserialize(self, value, attr, data):
        if isinstance(value, (str, bytes)):
            try:
                value = ujson.loads(value)
            except ValueError:
                self.fail('type', input=value, type=value.__class__.__name__)
        return super()._deserialize(value, attr, data)


class StripMixin:
    """Mixin that trims leading/trailing whitespace from raw input.

    Any value exposing a ``strip`` method is stripped before the next class in
    the MRO deserializes it. Other values (numbers, None, ...) pass through
    untouched. Extra keyword arguments are forwarded as-is.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        cleaned = value.strip() if hasattr(value, 'strip') else value
        return super()._deserialize(cleaned, attr, data, **kwargs)


class StripString(StripMixin, fields.String):
    """``fields.String`` that strips surrounding whitespace before deserializing."""


class StripEmail(StripMixin, fields.Email):
    """``fields.Email`` that strips surrounding whitespace before deserializing."""


class NullableQueryParamMixin:
    """Mixin that deserializes the literal string ``'null'`` to ``None``.

    Per its name, this is intended for query-string input, which has no native
    null value. Any other value is deserialized by the next class in the MRO.
    Extra keyword arguments are accepted and forwarded, matching ``StripMixin``.
    This lets the mixin compose with fields whose ``_deserialize`` takes
    ``**kwargs``.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        if value == 'null':
            return None
        return super()._deserialize(value, attr, data, **kwargs)


class NullableBoolean(NullableQueryParamMixin, fields.Boolean):
    """``fields.Boolean`` that also accepts the literal string ``'null'`` as ``None``."""
