"""
Copyright 2016-2017 Maxim Kulkin
Copyright 2018 Alex Rothberg and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import Any, ClassVar, Dict, Optional

from marshmallow import Schema, ValidationError

from sendr_utils.schemas.base import BaseSchema


class SchemaMappingWithFallback(Dict[str, Schema]):
    """
    Name-to-schema mapping that returns a fallback schema for unknown keys.

    Use strictly for its intended purpose: when a fallback schema is wanted
    for unrecognized values of the schema discriminator.

    Details
    OneOfSchema expects type_schemas: ClassVar[Dict[str, Schema]].
    For custom "dictionaries" collections.abc.MutableMapping is usually the
    right base. However `not issubclass(MutableMapping[str, Schema], Dict[str, Schema])`.
    Inheriting from MutableMapping would give a cleaner solution: overriding
    __getitem__ alone would be enough. dict.get, however, does not call
    __getitem__ (apparently it goes straight to native code), so `get` has to
    be overridden here as well.

    Only `__getitem__` and `get` apply the fallback. Membership tests (`in`),
    iteration and `len` are inherited from dict unchanged and therefore only
    reflect the explicitly registered keys.
    """
    def __init__(self, default_schema: Schema, /, *args: Any, **kwargs: Any) -> None:
        """
        :param default_schema: schema returned for any key that is not registered
            (positional-only, so it cannot clash with a mapping key passed by keyword).
        :param args: forwarded unchanged to the dict constructor.
        :param kwargs: forwarded unchanged to the dict constructor.
        """
        super().__init__(*args, **kwargs)
        self._default_schema = default_schema

    def __getitem__(self, k: str) -> Schema:
        """Return the schema registered under `k`, or the fallback schema if there is none.

        A TypeError raised by dict for an unhashable key is not intercepted.
        """
        try:
            return super().__getitem__(k)
        except KeyError:
            return self._default_schema

    def get(self, k: str, default: Optional[Schema] = None, /) -> Schema:
        """Return the schema for `k`, falling back to the default schema of the mapping.

        :param k: key to look up.
        :param default: must be left as None; a per-call default would conflict
            with the mapping-wide fallback schema.
        :raises AssertionError: if a non-None `default` is passed.
        """
        # Raised explicitly rather than via `assert` so that the check is not
        # stripped when Python runs with -O; the exception type is kept as
        # AssertionError for backward compatibility.
        if default is not None:
            raise AssertionError(
                "SchemaMappingWithFallback.get() does not accept a per-call default"
            )
        return self[k]


# Inherit from sendr_utils.BaseSchema rather than marshmallow.Schema in order to inherit its Meta
class OneOfSchema(BaseSchema):
    """
    This is a special kind of schema that actually multiplexes other schemas
    based on object type. When serializing values, it uses get_obj_type() method
    to get object type name. Then it uses `type_schemas` name-to-Schema mapping
    to get schema for that particular object type, serializes object using that
    schema and adds an extra "type" field with name of object type.
    Deserialization is reverse.
    Example:
        class Foo(object):
            def __init__(self, foo):
                self.foo = foo
        class Bar(object):
            def __init__(self, bar):
                self.bar = bar
        class FooSchema(marshmallow.Schema):
            foo = marshmallow.fields.String(required=True)
            @marshmallow.post_load
            def make_foo(self, data, **kwargs):
                return Foo(**data)
        class BarSchema(marshmallow.Schema):
            bar = marshmallow.fields.Integer(required=True)
            @marshmallow.post_load
            def make_bar(self, data, **kwargs):
                return Bar(**data)
        class MyUberSchema(OneOfSchema):
            type_schemas = {
                'foo': FooSchema,
                'bar': BarSchema,
            }
            def get_obj_type(self, obj):
                if isinstance(obj, Foo):
                    return 'foo'
                elif isinstance(obj, Bar):
                    return 'bar'
                else:
                    raise Exception('Unknown object type: %s' % repr(obj))
        MyUberSchema().dump([Foo(foo='hello'), Bar(bar=123)], many=True)
        # => [{'type': 'foo', 'foo': 'hello'}, {'type': 'bar', 'bar': 123}]
    You can control type field name added to serialized object representation by
    setting `type_field` class property.
    """

    # Name of the key that carries the type discriminator in serialized data.
    type_field: ClassVar[str] = "type"
    # When true, the discriminator key is removed from the data before it is
    # handed to the per-type schema on load.
    type_field_remove: ClassVar[bool] = True
    # Discriminator value -> schema. A value may be a Schema instance or a
    # callable returning one (e.g. a Schema class); see _dump/_load.
    type_schemas: ClassVar[Dict[str, Schema]] = {}

    def get_obj_type(self, obj):
        """Returns name of object schema"""
        return obj.__class__.__name__

    def dump(self, obj, *, many=None, **kwargs):
        """Serialize `obj` (or each item of `obj` when `many`) with its per-type schema.

        :param obj: object to serialize, or an iterable of objects when `many` is true.
        :param many: overrides `self.many` when not None.
        :param kwargs: forwarded to `_dump`.
        :return: the `_dump` result for a single object, or a list of such results.
        :raises ValidationError: when `many` is true and at least one item raised
            ValidationError; messages are keyed by item index. For a single
            object a ValidationError from `_dump` propagates unchanged.
        """
        many = self.many if many is None else bool(many)

        result_data = []
        result_errors = {}
        if not many:
            result_data = self._dump(obj, **kwargs)
        else:
            for idx, o in enumerate(obj):
                try:
                    result_data.append(self._dump(o, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()

        if result_errors:
            raise ValidationError(result_errors, data=obj)
        return result_data

    def _dump(self, obj, *, update_fields=True, **kwargs):
        """Serialize a single object and stamp the result with its type name.

        `update_fields` is accepted but not forwarded to the per-type schema.

        NOTE(review): when the type cannot be resolved this returns a
        `(None, errors)` tuple instead of raising, and `dump` passes that tuple
        through as if it were a regular result — confirm callers expect this.
        """
        obj_type = self.get_obj_type(obj)
        if not obj_type:
            return (
                None,
                {"_schema": "Unknown object class: %s" % obj.__class__.__name__},
            )

        type_schema = self.type_schemas.get(obj_type)
        if not type_schema:
            # 1-tuple wrap: a tuple-valued obj_type must be formatted as one
            # value, not unpacked as the %-argument list.
            return None, {"_schema": "Unsupported object type: %s" % (obj_type,)}

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        # Propagate this schema's context to the per-type schema. When
        # type_schemas holds Schema instances, that instance's context is
        # mutated in place.
        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)
        if result is not None:
            # NOTE(review): this assumes the per-type schema's dump() returns an
            # object exposing a `.data` mapping (marshmallow 2 MarshalResult
            # style); a plain dict result would fail here — confirm against the
            # marshmallow version behind BaseSchema.
            result.data[self.type_field] = obj_type
        return result

    def load(self, data, *, many=None, partial=None, **kwargs):
        """Deserialize `data` (or each item of `data` when `many`) with its per-type schema.

        :param data: mapping to deserialize, or an iterable of mappings when `many` is true.
        :param many: overrides `self.many` when not None.
        :param partial: overrides `self.partial` when not None.
        :param kwargs: forwarded to the per-type schema's load().
        :return: the per-type load result, or a list of such results when `many`.
        :raises ValidationError: with the normalized messages of the failure;
            when `many` is true, messages are keyed by item index.
        """
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial

        result_data = []
        result_errors = {}
        if not many:
            try:
                result_data = self._load(data, partial=partial, **kwargs)
            except ValidationError as error:
                result_errors = error.normalized_messages()
        else:
            for idx, item in enumerate(data):
                try:
                    result_data.append(self._load(item, partial=partial, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()

        if result_errors:
            raise ValidationError(result_errors, data=data)
        return result_data

    def _load(self, data, *, partial=None, **kwargs):
        """Deserialize a single mapping using the schema selected by `type_field`.

        :raises ValidationError: if `data` is not a dict, the type field is
            missing or falsy, its value is unhashable, or no schema is
            registered for it.
        """
        if not isinstance(data, dict):
            # 1-tuple wrap: `"%s" % data` would unpack a tuple-valued `data` as
            # the argument list and raise TypeError instead of ValidationError.
            raise ValidationError({"_schema": "Invalid data type: %s" % (data,)})

        # Work on a shallow copy so the caller's mapping is not modified when
        # the type field is popped.
        data = dict(data)

        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)

        if not data_type:
            raise ValidationError(
                {self.type_field: ["Missing data for required field."]}
            )

        try:
            type_schema = self.type_schemas.get(data_type)
        except TypeError:
            # data_type could be unhashable
            raise ValidationError({self.type_field: ["Invalid value: %s" % (data_type,)]})
        if not type_schema:
            raise ValidationError(
                {self.type_field: ["Unsupported value: %s" % (data_type,)]}
            )

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        # Propagate this schema's context to the per-type schema (see _dump).
        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, **kwargs)

    def validate(self, data, *, many=None, partial=None):
        """Run `load` and return the validation messages, or an empty dict on success."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}
