# encoding: UTF-8

import marshmallow as ma
import sqlalchemy as sa
import sqlalchemy.orm as orm


class ModelSchemaOpts(ma.SchemaOpts):
    """Schema options that additionally read ``model`` from ``Meta``.

    ``model`` is taken from the schema's ``Meta`` class and defaults to
    ``None`` when ``Meta`` does not define it.
    """

    def __init__(self, meta, *args, **kwargs):
        # Forward any extra arguments untouched: marshmallow 3 instantiates
        # the options class with an additional ``ordered`` keyword, while
        # marshmallow 2 passes ``meta`` only.
        super(ModelSchemaOpts, self).__init__(meta, *args, **kwargs)
        self.model = getattr(meta, 'model', None)


class ModelSchema(ma.Schema):
    """Schema that loads data into instances of ``Meta.model``.

    An optional ``provider`` callable may be passed as a keyword argument
    to the constructor.  On load it is called with the primary key found in
    the deserialized data and may return an existing instance to update;
    when it returns ``None`` (or no provider / primary key is available) a
    new ``Meta.model()`` instance is created instead.
    """

    OPTIONS_CLASS = ModelSchemaOpts

    def __init__(self, *args, **kwargs):
        # ``provider`` is our own option; remove it before marshmallow sees
        # the keyword arguments.
        self.provider = kwargs.pop('provider', None)
        super(ModelSchema, self).__init__(*args, **kwargs)

    def _get_pk(self, data):
        """Extract the primary key of ``Meta.model`` from ``data``.

        :param data: mapping of deserialized values.
        :return: the scalar key for a single-column primary key, a tuple
            for a composite one, or ``None`` if any part is missing/None.
        """
        mapper = sa.inspect(self.opts.model)  # type: orm.Mapper
        pk = []
        for column in mapper.primary_key:
            # ``data`` keys are later used as attribute names (see
            # ``_make_instance``), so look the value up by the mapped
            # attribute key first; it may differ from the column name.
            key = mapper.get_property_by_column(column).key
            value = data.get(key)
            if value is None:
                # Fall back to the raw column name for compatibility with
                # data keyed that way.
                value = data.get(column.name)
            if value is None:
                return None
            pk.append(value)
        return tuple(pk) if len(pk) > 1 else pk[0]

    @ma.pre_dump
    def _refresh_dirty(self, data, **kwargs):
        """Flush ``data`` to its session before dumping if it is modified.

        Returns ``data`` unchanged so the hook works both with marshmallow
        versions that ignore a ``None`` return value and with those that
        use the return value as the object to serialize.
        """
        session = orm.object_session(data)  # type: orm.Session
        if session and session.is_modified(data):
            session.flush([data])
        return data

    @ma.post_load
    def _make_instance(self, data, **kwargs):
        """Build (or update) a model instance from deserialized ``data``.

        :param data: mapping of attribute name to deserialized value.
        :return: the instance returned by ``provider`` for the primary key
            in ``data`` if there is one, otherwise a new ``Meta.model()``;
            in both cases every item of ``data`` is set as an attribute.
        """
        instance = None
        pk = self._get_pk(data)

        if pk is not None and self.provider is not None:
            instance = self.provider(pk)

        if instance is None:
            instance = self.opts.model()

        for attr, value in data.items():
            setattr(instance, attr, value)
        return instance


class EnumField(ma.fields.Field):
    """Field converting between members of ``enum_cls`` and plain data.

    By default members are (de)serialized by their ``value``; pass
    ``by_name=True`` to use the member ``name`` instead.

    :param enum_cls: the enum class whose members this field handles.
    :param by_name: keyword-only flag selecting name-based conversion.
    """

    default_error_messages = {
        '__invalid': 'Must be one of: {choice}'
    }

    def __init__(self, enum_cls, *args, **kwargs):
        # Consume ``by_name`` here so it is not forwarded to the base
        # ``Field`` as an unknown keyword argument.
        by_name = kwargs.pop('by_name', False)
        super(EnumField, self).__init__(*args, **kwargs)
        self.enum_cls = enum_cls
        self.by_name = by_name
        # Pre-rendered lists of accepted inputs, used in error messages.
        self.__name_choice = ', '.join(map(
            repr,
            self.enum_cls.__members__,
        ))
        self.__value_choice = ', '.join(map(
            repr,
            (v.value for v in self.enum_cls.__members__.values())
        ))

    def _serialize(self, value, attr, obj, **kwargs):
        """Return the name or value of the enum member for ``value``.

        ``value`` may be a member of ``enum_cls`` or one of its raw values.
        ``None`` is passed through as ``None``; anything else that is not a
        valid member/value fails with the ``__invalid`` error.
        """
        if value is None:
            return None
        try:
            member = self.enum_cls(value)
        except ValueError:
            # The lookup above is by value, so report the valid values.
            self.fail('__invalid', choice=self.__value_choice)

        return member.name if self.by_name else member.value

    def _deserialize(self, value, attr, data, **kwargs):
        """Return the member of ``enum_cls`` identified by ``value``.

        Fails with the ``__invalid`` error, listing the accepted names or
        values, when ``value`` does not identify a member.
        """
        if self.by_name:
            try:
                # Look up in ``__members__`` rather than with ``getattr`` so
                # only real members match (not methods or dunder
                # attributes); TypeError covers unhashable input.
                return self.enum_cls.__members__[value]
            except (KeyError, TypeError):
                self.fail('__invalid', choice=self.__name_choice)
        else:
            try:
                return self.enum_cls(value)
            except ValueError:
                self.fail('__invalid', choice=self.__value_choice)
