import abc
import six

from awacs.lib.vectors import version as ver
from awacs.model import util


# Registry of every class created through VectorMeta, keyed by class name.
# VectorMeta.__new__ refuses to create a second class under a name that is already present here.
VECTORS_REGISTRY = {}


class ValidationError(Exception):
    """Exception that carries a message together with an optional underlying cause."""

    def __init__(self, message, cause=None):
        """
        :param message: passed on unchanged to the base ``Exception`` constructor
        :param cause: optional object kept as-is in ``self.cause`` (``None`` when omitted)
        """
        super(ValidationError, self).__init__(message)
        self.cause = cause


class VectorMeta(abc.ABCMeta):
    """
    Metaclass that registers every created class in VECTORS_REGISTRY and assembles its ``__slots__``.

    For a class whose ``__version_classes__`` is non-empty it also derives
    ``__version_field_names__`` and ``__main_version_field_name__`` from the
    ``vector_field_name`` attribute of the declared version classes.
    """

    def __new__(mcs, name, bases, attrs):
        """
        :param name: name of the class being created; must not be present in VECTORS_REGISTRY
        :param bases: base classes of the class being created
        :param attrs: class namespace; must contain the ``__version_classes__`` key
        :raises RuntimeError: if a class named ``name`` has already been registered
        :raises KeyError: if ``attrs`` does not contain ``__version_classes__``
        """
        if name in VECTORS_REGISTRY:
            raise RuntimeError('Vector {} already exists: {!r}'.format(name, VECTORS_REGISTRY[name]))

        slots = set()
        if '__slots__' in attrs:
            # we also want to preserve slots defined by users
            user_slots = attrs['__slots__']
            if isinstance(user_slots, six.string_types):
                # A bare string declares a single slot; set.update() would split it into characters.
                slots.add(user_slots)
            else:
                slots.update(user_slots)
        if bases == (object,):
            slots.add('__dict__')  # and we want our instances to be mockable (very useful in tests)

        version_classes = attrs['__version_classes__']
        # NOTE(review): the main version class is only processed when __version_classes__ is truthy;
        # a class declaring a main version but no other version classes gets no main version slot.
        if version_classes:
            attrs['__version_field_names__'] = []
            assert isinstance(version_classes, (tuple, list))  # assure deterministic order
            for version_class in version_classes:
                assert issubclass(version_class, ver.Version)
                slots.add(version_class.vector_field_name)
                attrs['__version_field_names__'].append(version_class.vector_field_name)
            main_version_class = attrs['__main_version_class__']
            assert issubclass(main_version_class, ver.Version)
            slots.add(main_version_class.vector_field_name)
            attrs['__main_version_field_name__'] = main_version_class.vector_field_name

        # Sorting turns the set into a list with a stable order.
        attrs['__slots__'] = sorted(slots)

        klass = VECTORS_REGISTRY[name] = abc.ABCMeta.__new__(mcs, name, bases, attrs)
        return klass


class Vector(six.with_metaclass(VectorMeta, object)):
    """
    Base class for vectors: a single "main" version plus one collection of versions per version class.

    The main version is stored in the attribute named by ``__main_version_field_name__`` and may be None.
    Every name in ``__version_field_names__`` refers to an attribute holding a dict-like collection
    of versions (it is iterated with ``six.iteritems`` and queried with ``.get(version_id)``).
    """

    # inheritors MUST define these fields
    __main_version_class__ = None  # type: type(ver.Version)
    __version_classes__ = None  # type: list[type(ver.Version)] or tuple[type(ver.Version)]

    # inheritors MUST NOT define these fields
    __main_version_field_name__ = None  # type: six.text_type
    __version_field_names__ = None  # type: list[six.text_type]

    def __init__(self, *args, **kwargs):
        """
        Check that all version fields are already present on the instance.

        The arguments are accepted but not used here.
        Presumably inheritors assign the fields before calling this constructor -- TODO confirm.

        :raises RuntimeError: if the main version attribute does not exist, or if any
            other version attribute is missing or None
        """
        # The main version may legitimately be None, so only its presence is checked.
        if not hasattr(self, self.__main_version_field_name__):
            raise RuntimeError('Vector is missing a required field {}'.format(self.__main_version_field_name__))
        for version_field_name in self.__version_field_names__:
            if getattr(self, version_field_name, None) is None:
                raise RuntimeError('Vector is missing a required field {}'.format(version_field_name))

    @property
    def main_version(self):
        """
        Value of the attribute named by ``__main_version_field_name__``.

        :rtype: ver.Version
        """
        return getattr(self, self.__main_version_field_name__)

    @main_version.setter
    def main_version(self, value):
        """
        :type value: ver.Version
        """
        setattr(self, self.__main_version_field_name__, value)

    def __iter__(self):
        """
        Yield the main version (when it is not None), then every version of every version field.
        """
        if self.main_version is not None:
            yield self.main_version
        for version_field_name in self.__version_field_names__:
            for version in six.itervalues(self.must_get_version_dict(version_field_name)):
                yield version

    def __repr__(self):
        """
        Render as ``<main version>(<pb_field_name>=<versions>, ...)`` using the ``util`` string helpers.
        """
        contents = []
        for version_class in self.__version_classes__:
            versions = self.must_get_version_dict(version_class.vector_field_name)
            contents.append('{}={}'.format(version_class.pb_field_name, util.versions_dict_to_str(versions)))
        return '{}({})'.format(util.version_to_str(self.main_version), ', '.join(contents))

    def __eq__(self, other):
        """
        Two vectors are equal when their main versions and all of their version collections are equal.

        Only the fields listed in this vector's ``__version_field_names__`` are compared.
        """
        if not isinstance(other, Vector):
            # Returning NotImplemented (instead of raising NotImplementedError) lets Python try the
            # reflected comparison and then fall back to its default, so `vector == None` is simply
            # False rather than an exception.
            return NotImplemented
        if self.main_version != other.main_version:
            return False
        for version_field_name in self.__version_field_names__:
            if self.must_get_version_dict(version_field_name) != other.must_get_version_dict(version_field_name):
                return False
        return True

    def __ne__(self, other):
        """Negation of ``==``; defined explicitly because Python 2 does not derive it from ``__eq__``."""
        return not (self == other)

    def get_weak_hash(self):
        """
        Fold the weak hashes of all contained versions into one value with ``util.crc32``.

        :rtype: six.binary_type
        """
        h = 0
        if self.main_version is not None:
            h = util.crc32(self.main_version.get_weak_hash(), h)
        for version_field_name in self.__version_field_names__:
            # Items are visited in sorted order so the result does not depend on dict iteration order.
            for _, version in sorted(six.iteritems(self.must_get_version_dict(version_field_name))):
                h = util.crc32(version.get_weak_hash(), h)
        return util.int_to_hex_bytes(h)

    def get_weak_hash_str(self):
        """
        Same as :meth:`get_weak_hash`, decoded as UTF-8.

        :rtype: six.text_type
        """
        return self.get_weak_hash().decode('utf-8')

    def is_empty(self):
        """
        :return: True when every version collection is empty and the main version is None
        :rtype: bool
        """
        for version_field_name in self.__version_field_names__:
            if self.must_get_version_dict(version_field_name):
                return False
        return self.main_version is None

    def diff(self, to):
        """
        Compute which versions have to be added, removed or updated to get from this vector to ``to``.

        Updated entries are ``(from_version, to_version)`` pairs; added and removed entries are
        single versions. Versions of the same field are matched by their key in the collection.

        :type to: Vector
        :return: ``util.Diff`` built from the three sets
        :raises NotImplementedError: if ``to`` is not a Vector
        """
        if not isinstance(to, Vector):
            raise NotImplementedError
        updated = set()
        added = set()
        removed = set()
        if self.main_version != to.main_version:
            if self.main_version is None:
                added.add(to.main_version)
            elif to.main_version is None:
                removed.add(self.main_version)
            else:
                updated.add((self.main_version, to.main_version))

        for version_field_name in self.__version_field_names__:
            to_versions = to.must_get_version_dict(version_field_name)
            from_versions = self.must_get_version_dict(version_field_name)
            for version_id, to_version in six.iteritems(to_versions):
                if version_id in from_versions:
                    from_version = from_versions[version_id]
                    if from_version != to_version:
                        updated.add((from_version, to_version))
                else:
                    added.add(to_version)
            for version_id, from_version in six.iteritems(from_versions):
                if version_id not in to_versions:
                    removed.add(from_version)

        return util.Diff(updated=updated, added=added, removed=removed)

    def must_get_version_dict(self, version_field_name):
        """
        Return the attribute called ``version_field_name``.

        :raises AttributeError: if the instance has no such attribute
        """
        return getattr(self, version_field_name)

    def must_get_version_dict_by_version(self, version):
        """
        Return the attribute stored under the ``vector_field_name`` of the class of ``version``.

        NOTE(review): for a version of the main version class this returns whatever is stored
        in the main version field, not a collection of versions.

        :raises RuntimeError: if the class of ``version`` is neither the main version class
            nor one of ``__version_classes__``
        """
        version_class = type(version)
        if version_class != self.__main_version_class__ and version_class not in self.__version_classes__:
            raise RuntimeError('Unsupported version class {}'.format(version_class))
        return self.must_get_version_dict(version_class.vector_field_name)

    def get_version_item(self, version_field_name, version_id):
        """
        Look up a single version.

        :param version_field_name: the main version field name or one of the other version field names
        :param version_id: key of the version in its collection; ignored for the main version field
        :return: the main version, or the collection entry for ``version_id`` (None when absent)
        """
        if version_field_name == self.__main_version_field_name__:
            return self.main_version
        assert version_id is not None, 'version_id must be present if ' \
                                       'version_field_name != self.__main_version_field_name__'
        return self.must_get_version_dict(version_field_name).get(version_id)

    def get_version_item_by_version(self, version):
        """
        Look up the version stored in this vector under the same field name and id as ``version``.
        """
        return self.get_version_item(version.vector_field_name, version.id)
