from sandbox.common import patterns


# Public API of this module: only the Serializable base class is exported by
# ``from <module> import *``.
__all__ = ("Serializable",)


class RegisterSubclasses(type):
    """Metaclass that records every class it creates in ``CLASS_REGISTRY``.

    The registry is whatever ``CLASS_REGISTRY`` attribute the new class
    resolves to (usually a dict inherited from a base class), keyed by the
    bare class name.

    Raises:
        RuntimeError: if a class with the same name is already registered.
    """

    def __init__(cls, name, bases, namespace):
        registry = cls.CLASS_REGISTRY
        if name not in registry:
            registry[name] = cls
        else:
            # Names must be unique: the registry is looked up by bare name.
            raise RuntimeError("'%s' is already registered" % (name,))
        super(RegisterSubclasses, cls).__init__(name, bases, namespace)


class Serializable(object):
    """Base class for objects that encode to / decode from nested plain data.

    Subclasses list the instance attributes that make up their state in
    ``SERIALIZABLE_ATTRS``. ``encode`` turns an object into nested
    ``{"__cls": <class name>, "__state": <state>}`` dicts, and ``decode`` maps
    the stored class name back to a class through ``CLASS_REGISTRY``.

    NOTE(review): ``__metaclass__`` is only honored by Python 2; under
    Python 3 the registry is not filled automatically.
    """

    # Names of instance attributes included by __getstate__/__setstate__.
    SERIALIZABLE_ATTRS = ()
    # class name -> class; filled by the metaclass and shared by every
    # subclass that does not override it.
    CLASS_REGISTRY = {}
    __metaclass__ = RegisterSubclasses

    class Prototype(object):
        # Bare placeholder that ``decode`` instantiates and then re-classes,
        # so the target class's __init__ is not run (non-Abstract classes).
        pass

    def __getstate__(self):
        """Return a new dict of the SERIALIZABLE_ATTRS set on this instance.

        Attributes listed in SERIALIZABLE_ATTRS but absent from the instance
        ``__dict__`` are left out.
        """
        return {
            attr: self.__dict__[attr]
            for attr in self.__class__.SERIALIZABLE_ATTRS
            if attr in self.__dict__
        }

    def __setstate__(self, state):
        """Copy the SERIALIZABLE_ATTRS present in ``state`` onto the instance.

        Keys in ``state`` that are not listed in SERIALIZABLE_ATTRS are
        ignored.
        """
        self.__dict__.update({
            attr: state[attr]
            for attr in self.__class__.SERIALIZABLE_ATTRS
            if attr in state
        })

    def encode(self):
        """Encode this object into nested ``{"__cls", "__state"}`` dicts.

        Any value that has a ``__getstate__`` is replaced by a wrapper dict
        holding its class name and its state. Dict and list states are
        encoded element-wise; all other values are returned unchanged.
        New containers are built at every level, so neither this object nor
        the containers returned by nested ``__getstate__`` calls are modified.

        Returns:
            The wrapper dict for ``self``.
        """
        def wrap(obj, state):
            # Classes are recorded by their own name, instances by their type's.
            cls = obj if isinstance(obj, type) else type(obj)
            return {"__cls": cls.__name__, "__state": state}

        def recurse(obj):
            if not hasattr(obj, "__getstate__"):
                return obj
            state = obj.__getstate__()
            if isinstance(state, dict):
                # Encode into a copy: a __getstate__ may return the object's
                # own __dict__, and writing wrapper dicts back into it would
                # overwrite the live attributes.
                state = state.copy()
                for key, value in list(state.items()):
                    state[key] = recurse(value)
            elif isinstance(state, list):
                state = [recurse(item) for item in state]
            return wrap(obj, state)

        return recurse(self)

    @classmethod
    def decode(cls, state):
        """Rebuild objects from the structure produced by ``encode``.

        Dicts that contain a ``"__cls"`` key are treated as wrappers and are
        turned back into objects, or into the class itself when their
        ``"__state"`` is None. Any other value is returned unchanged. The
        input is not modified.

        Args:
            state: The encoded value.

        Returns:
            The decoded object, class, or the unchanged value.

        Raises:
            TypeError: if a wrapper names a class missing from CLASS_REGISTRY.
        """
        def unwrap(class_name, state):
            kls = cls.CLASS_REGISTRY.get(class_name)
            if kls is None:
                # Fail loudly. Otherwise the result would be a silent None, or
                # an opaque error from issubclass(None, ...) below.
                raise TypeError(
                    "unknown serializable class {!r}".format(class_name))
            if state is None:
                return kls
            # patterns.Abstract subclasses are constructed normally. Anything
            # else starts as a bare Prototype and is re-classed, which skips
            # the class's __init__.
            obj = (kls if issubclass(kls, patterns.Abstract) else cls.Prototype)()
            obj.__class__ = kls
            obj.__setstate__(state)
            return obj

        def recurse(state):
            if not (isinstance(state, dict) and "__cls" in state):
                return state
            inner = state["__state"]
            if isinstance(inner, dict):
                # Decode into a copy so the caller's payload is left intact.
                inner = inner.copy()
                for key, value in list(inner.items()):
                    inner[key] = recurse(value)
            elif isinstance(inner, list):
                inner = [recurse(item) for item in inner]
            return unwrap(state["__cls"], inner)

        return recurse(state)
