# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from copy import copy, deepcopy
from six import with_metaclass
from six.moves import zip_longest

from flask_restful import fields, marshal

from travel.rasp.bus.api.connectors.fields.validation import Validatable, ValidationErrorBundle, validate
from travel.rasp.bus.api.connectors.fields.fields import Managed


def merge(xs, ys, func=lambda x, y: x):
    """Recursively merge ``ys`` into ``xs`` and return the result.

    * ``xs is None``       -> a deep copy of ``ys``.
    * both are dicts       -> a shallow copy of ``xs`` (same mapping type) in
      which each key of ``ys`` maps to ``merge(xs.get(key), ys[key], func)``;
      keys present only in ``xs`` are kept unchanged.
    * both are list/tuple  -> a new list merged position by position; the
      shorter side is padded with ``None`` (so extra ``ys`` items are
      deep-copied and extra ``xs`` items go through ``func(x, None)``).
    * anything else        -> ``func(xs, ys)``; the default keeps ``xs``.

    ``merge`` itself never modifies ``xs`` or ``ys``; a custom ``func`` may.
    The dict branch used to update ``xs`` in place, which leaked merged keys
    back into the caller's object.

    :param xs: primary value; its non-container values win with the default ``func``.
    :param ys: secondary value supplying missing entries.
    :param func: combiner for two non-None values that are not both dicts or
        both lists/tuples.
    :return: the merged value.
    """
    if xs is None:
        return deepcopy(ys)
    if isinstance(xs, dict) and isinstance(ys, dict):
        # copy() keeps the mapping type (e.g. OrderedDict stays ordered)
        # while leaving the caller's dict untouched.
        merged = copy(xs)
        merged.update({k: merge(xs.get(k), v, func) for k, v in ys.items()})
        return merged
    if isinstance(xs, (list, tuple)) and isinstance(ys, (list, tuple)):
        return [merge(x, y, func) for x, y in zip_longest(xs, ys)]
    return func(xs, ys)


class EntityMeta(type):
    """Metaclass that combines a class's ``fields`` with its first base's.

    The new class's ``fields`` attribute becomes ``override(own, inherited)``.
    Entries declared in the class body take precedence and nested field maps
    are merged recursively. Entries present only on the base are deep-copied
    in (via :func:`merge`).
    """

    def __new__(meta, name, bases, attrs):
        obj_fields = attrs.get('fields', None)
        # Only the first base contributes inherited fields.  An empty ``bases``
        # tuple (e.g. ``EntityMeta('X', (), {})``) used to raise IndexError;
        # treat it as "no base fields" instead.
        first_base = bases[0] if bases else None
        base_fields = getattr(first_base, 'fields', None)
        attrs['fields'] = EntityMeta.override(obj_fields, base_fields)
        return super(EntityMeta, meta).__new__(meta, name, bases, attrs)

    @staticmethod
    def override(obj, base):
        """Return field definition ``obj`` merged over ``base`` (either may be None).

        Dicts and lists are merged structurally by :func:`merge`. When the
        two sides hold other values at the same position (typically two field
        objects), the local ``func`` decides:

        * a ``Managed`` base is unwrapped to its ``field``;
        * a ``Nested`` base is unwrapped to its ``nested`` field map;
        * a ``Nested`` override gets its own ``nested`` map merged with the base;
        * two ``List`` fields get their ``container`` item fields merged;
        * otherwise ``obj`` wins.
        """
        def func(obj, base):
            # Invoked by merge() when obj is not None and the two sides are
            # not both dicts or both lists/tuples.
            if isinstance(base, Managed):
                # Merge against the field the managed wrapper holds.
                return EntityMeta.override(obj, base.field)
            elif isinstance(base, fields.Nested):
                # Merge against the base's nested field map.
                return EntityMeta.override(obj, base.nested)
            elif isinstance(obj, fields.Nested):
                # NOTE(review): obj.nested may be a field map shared with
                # another class; confirm that merging into it here is intended.
                obj.nested = EntityMeta.override(obj.nested, base)
            elif isinstance(obj, fields.List) and isinstance(base, fields.List):
                obj.container = EntityMeta.override(obj.container, base.container)
            return obj
        return merge(obj, base, func)


class Entity(with_metaclass(EntityMeta, fields.Nested, Validatable)):
    """Base class for marshalled and validated API entities.

    Subclasses declare a class-level ``fields`` map, which EntityMeta merges
    with the parent's. An instance is itself a flask_restful ``Nested`` field
    over that map, so it can presumably be used as a field inside another
    entity's ``fields``. The classmethods marshal, enrich and validate raw
    data.
    """

    def __init__(self, **kwargs):
        # The class-level field map is handed to fields.Nested as its map.
        own_fields = self.fields
        super(Entity, self).__init__(own_fields, **kwargs)

    @classmethod
    def output(cls, key, obj):
        """Marshal ``obj`` with this entity's fields.

        With ``key=None`` the whole ``obj`` is marshalled. Otherwise the value
        is first looked up with ``fields.get_value(key, obj)``.

        NOTE(review): as a classmethod this does not consult per-instance
        Nested options passed to ``__init__`` (e.g. ``attribute``); confirm
        that is intended.
        """
        if key is None:
            return cls.format(obj)
        return cls.format(fields.get_value(key, obj))

    @classmethod
    def format(cls, value):
        """Marshal ``value`` against ``cls.fields``."""
        field_map = cls.fields
        return marshal(value, field_map)

    @classmethod
    def validate(cls, value, path):
        """Validate ``value`` against :meth:`validation_fields`.

        ``path`` is forwarded unchanged. The result is returned as-is;
        :meth:`init` treats a truthy result as a validation failure.
        """
        schema = cls.validation_fields()
        return validate(schema, value, path)

    @classmethod
    def enrich(cls, obj, **kwargs):
        """Combine ``obj`` with ``kwargs`` using merge()'s default combiner.

        For a dict ``obj``, keys from ``kwargs`` are filled in only where
        missing, so existing values win. For ``obj is None`` a deep copy of
        ``kwargs`` comes back. For other types ``obj`` is returned as-is.
        """
        return merge(obj, kwargs)

    @staticmethod
    def _enrich(cls, obj, **kwargs):
        # Enrich each element of a list/tuple separately; anything else is
        # enriched as a single object.
        if not isinstance(obj, (list, tuple)):
            return cls.enrich(obj, **kwargs)
        return [cls.enrich(item, **kwargs) for item in obj]

    @classmethod
    def init(cls, obj, **kwargs):
        """Enrich, marshal and validate ``obj``, returning the marshalled result.

        :raises ValidationErrorBundle: if validation reports any exceptions.
        """
        enriched = Entity._enrich(cls, obj, **kwargs)
        instance = cls.output(None, enriched)
        path = [cls.__module__, cls.__name__]
        exceptions = cls.validate(instance, path)
        if not exceptions:
            return instance
        raise ValidationErrorBundle(instance=instance, exceptions=exceptions)

    @classmethod
    def validation_fields(cls):
        """Return the ``fields`` of the top-most ancestor that defines them.

        The walk follows the ``__base__`` chain while the parent still has
        non-None ``fields``. Validation therefore uses the root entity's
        field map.
        """
        root = cls
        parent = root.__base__
        while getattr(parent, 'fields', None) is not None:
            root, parent = parent, parent.__base__
        return root.fields
