# coding: utf-8
import collections
import itertools

import six
import ujson

from awacs.model import util


class L3Vector(object):
    """
    A set of versions: an optional L3 balancer version plus backend and
    endpoint set versions keyed by their ids.

    The ``replace_*`` / ``remove_*`` helpers and ``clone`` return new vectors
    built from copies of the dicts; ``omit_orphan_endpoint_sets`` is the only
    method here that mutates the vector in place.
    """
    __slots__ = ('balancer_version', 'backend_versions', 'endpoint_set_versions')

    def __init__(self, l3_balancer_version, backend_versions, endpoint_set_versions):
        """
        :type l3_balancer_version: L3BalancerVersion | None
        :type backend_versions: dict[six.text_type, BackendVersion]
        :type endpoint_set_versions: dict[six.text_type, EndpointSetVersion]
        """
        self.balancer_version = l3_balancer_version
        self.backend_versions = backend_versions
        self.endpoint_set_versions = endpoint_set_versions

    def __iter__(self):
        """
        Yield the balancer version (if set), then every backend version, then
        every endpoint set version.
        """
        if self.balancer_version:
            yield self.balancer_version
        for version in itertools.chain(
                six.itervalues(self.backend_versions),
                six.itervalues(self.endpoint_set_versions)
        ):
            yield version

    def __repr__(self):
        return '{}(backends={}, endpoint_sets={})'.format(
            util.version_to_str(self.balancer_version) if self.balancer_version else 'NULL',
            ujson.dumps({id_: util.version_to_str(version)
                         for id_, version in six.iteritems(self.backend_versions)}),
            ujson.dumps({id_: util.version_to_str(version)
                         for id_, version in six.iteritems(self.endpoint_set_versions)}),
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
                self.balancer_version == other.balancer_version and
                self.backend_versions == other.backend_versions and
                self.endpoint_set_versions == other.endpoint_set_versions
        )

    def __ne__(self, other):
        # Propagate NotImplemented instead of negating it: ``not NotImplemented``
        # is False, which would make a vector compare "not unequal" to objects of
        # unrelated types (and is deprecated/an error on recent Python 3).
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq

    def get_weak_hash(self):
        """
        Chain ``util.crc32`` over the weak hashes of the balancer version (if set),
        then backend versions ordered by id, then endpoint set versions ordered by id.

        :rtype: six.binary_type
        """
        h = 0
        if self.balancer_version:
            h = util.crc32(self.balancer_version.get_weak_hash(), h)
        for _, backend_version in sorted(six.iteritems(self.backend_versions)):
            h = util.crc32(backend_version.get_weak_hash(), h)
        for _, endpoint_set_version in sorted(six.iteritems(self.endpoint_set_versions)):
            h = util.crc32(endpoint_set_version.get_weak_hash(), h)
        return util.int_to_hex_bytes(h)

    def get_weak_hash_str(self):
        """
        :rtype: six.text_type
        """
        return self.get_weak_hash().decode('utf-8')

    def is_empty(self):
        """
        Return True if the vector has no balancer version, no backend versions
        and no endpoint set versions.
        """
        return (
                not self.balancer_version and
                not self.backend_versions and
                not self.endpoint_set_versions
        )

    @staticmethod
    def _version_ge(left, right):
        """
        Return ``left >= right``, treating None as older than any version.

        This matches Python 2 ordering (None sorts before everything) and avoids
        the TypeError Python 3 raises when ordering None against a tuple or None.
        """
        if right is None:
            return True
        if left is None:
            return False
        return left >= right

    def greater_than(self, other):
        """
        Return True if this vector differs from ``other`` and none of its versions
        is older than the corresponding version in ``other``.

        The balancer version is always compared (a missing one counts as the
        oldest); backend and endpoint set versions are compared only for ids
        present in both vectors.

        :type other: L3Vector
        :rtype: bool
        """
        return self != other and (
                self._version_ge(self.balancer_version, other.balancer_version) and
                all((self.backend_versions[backend_id] >= other.backend_versions[backend_id]
                     for backend_id in self.backend_versions if backend_id in other.backend_versions)) and
                all((self.endpoint_set_versions[endpoint_set_id] >= other.endpoint_set_versions[endpoint_set_id]
                     for endpoint_set_id in self.endpoint_set_versions if
                     endpoint_set_id in other.endpoint_set_versions))
        )

    @staticmethod
    def _diff_versions(from_versions, to_versions, updated, added, removed):
        """
        Add the differences between two id -> version dicts to the given sets.

        Ids present in both with different versions go to ``updated`` as
        ``(from_version, to_version)`` pairs; ids only in ``to_versions`` go to
        ``added``; ids only in ``from_versions`` go to ``removed``.
        """
        for id_, to_version in six.iteritems(to_versions):
            if id_ in from_versions:
                from_version = from_versions[id_]
                if from_version != to_version:
                    updated.add((from_version, to_version))
            else:
                added.add(to_version)
        for id_, from_version in six.iteritems(from_versions):
            if id_ not in to_versions:
                removed.add(from_version)

    def diff(self, to):
        """
        Compute what changed going from this vector to ``to``.

        :type to: L3Vector
        :rtype: util.Diff
        """
        updated = set()
        added = set()
        removed = set()
        if self.balancer_version != to.balancer_version:
            if self.balancer_version is None:
                added.add(to.balancer_version)
            elif to.balancer_version is None:
                removed.add(self.balancer_version)
            else:
                updated.add((self.balancer_version, to.balancer_version))

        self._diff_versions(self.backend_versions, to.backend_versions, updated, added, removed)
        self._diff_versions(self.endpoint_set_versions, to.endpoint_set_versions, updated, added, removed)

        return util.Diff(updated=updated, added=added, removed=removed)

    def replace_balancer_version(self, l3_balancer_version):
        """Return a copy of this vector with the balancer version replaced."""
        return L3Vector(l3_balancer_version=l3_balancer_version,
                        backend_versions=dict(self.backend_versions),
                        endpoint_set_versions=dict(self.endpoint_set_versions))

    def remove_balancer_version(self):
        """Return a copy of this vector without a balancer version."""
        return L3Vector(l3_balancer_version=None,
                        backend_versions=dict(self.backend_versions),
                        endpoint_set_versions=dict(self.endpoint_set_versions))

    def replace_backend_version(self, backend_id, backend_version):
        """Return a copy of this vector with ``backend_id`` set to ``backend_version``."""
        backend_versions = dict(self.backend_versions)
        backend_versions[backend_id] = backend_version
        return L3Vector(l3_balancer_version=self.balancer_version,
                        backend_versions=backend_versions,
                        endpoint_set_versions=dict(self.endpoint_set_versions))

    def remove_backend_version(self, backend_id):
        """
        Return a copy of this vector without ``backend_id``.

        :raises KeyError: if ``backend_id`` is not in the vector
        """
        backend_versions = dict(self.backend_versions)
        del backend_versions[backend_id]
        return L3Vector(l3_balancer_version=self.balancer_version,
                        backend_versions=backend_versions,
                        endpoint_set_versions=dict(self.endpoint_set_versions))

    def replace_endpoint_set_version(self, endpoint_set_id, endpoint_set_version):
        """Return a copy of this vector with ``endpoint_set_id`` set to ``endpoint_set_version``."""
        endpoint_set_versions = dict(self.endpoint_set_versions)
        endpoint_set_versions[endpoint_set_id] = endpoint_set_version
        return L3Vector(l3_balancer_version=self.balancer_version,
                        backend_versions=dict(self.backend_versions),
                        endpoint_set_versions=endpoint_set_versions)

    def remove_endpoint_set_version(self, endpoint_set_id):
        """
        Return a copy of this vector without ``endpoint_set_id``.

        :raises KeyError: if ``endpoint_set_id`` is not in the vector
        """
        endpoint_set_versions = dict(self.endpoint_set_versions)
        del endpoint_set_versions[endpoint_set_id]
        return L3Vector(l3_balancer_version=self.balancer_version,
                        backend_versions=dict(self.backend_versions),
                        endpoint_set_versions=endpoint_set_versions)

    def clone(self):
        """Return a copy of this vector with shallow-copied version dicts."""
        return L3Vector(self.balancer_version,
                        dict(self.backend_versions),
                        dict(self.endpoint_set_versions))

    def has_no_backends(self):
        """
        Return True if there is no balancer version, or every backend version
        is marked as deleted.
        """
        if not self.balancer_version:
            return True
        for backend_version in six.itervalues(self.backend_versions):
            if not backend_version.deleted:
                return False
        return True

    def omit_orphan_endpoint_sets(self):
        """
        Remove, in place, endpoint set versions whose id is not a key of
        ``backend_versions``.
        """
        # Iterate over a snapshot of the keys because we delete while looping.
        for endpoint_set_id in list(self.endpoint_set_versions):
            if endpoint_set_id not in self.backend_versions:
                del self.endpoint_set_versions[endpoint_set_id]


class L3BalancerVersion(collections.namedtuple('L3BalancerVersion', ['ctime', 'balancer_id', 'version'])):
    """
    Version record of an L3 balancer: ``(ctime, balancer_id, version)``.

    ``deleted`` is a constant False class attribute, mirroring the ``deleted``
    field that BackendVersion and EndpointSetVersion carry.
    """
    deleted = False

    def get_weak_hash(self):
        """
        Return ``b'<balancer_id>:<version>'`` with both parts UTF-8 encoded.

        :rtype: six.binary_type
        """
        parts = (self.balancer_id, self.version)
        return b':'.join(part.encode('utf-8') for part in parts)

    @classmethod
    def from_pb(cls, pb):
        """
        Build a version from an L3 balancer message, using its meta mtime as ctime.

        :type pb: infra.awacs.proto.model_pb2.L3Balancer
        :rtype: L3BalancerVersion
        """
        meta = pb.meta
        return cls(ctime=meta.mtime.ToMicroseconds(),
                   balancer_id=meta.id,
                   version=meta.version)

    @classmethod
    def from_rev_status_pb(cls, l3_balancer_id, pb):
        """
        Build a version from a revision status message for the given balancer id.

        :type l3_balancer_id: six.text_type
        :type pb: infra.awacs.proto.model_pb2.L3BalancerState.RevisionL3Status
        :rtype: L3BalancerVersion
        """
        return cls(ctime=pb.ctime.ToMicroseconds(),
                   balancer_id=l3_balancer_id,
                   version=pb.revision_id)

    def __repr__(self):
        # Only the first 8 characters of the version are shown.
        short_version = self.version[:8]
        return 'L3Bal(id={}, v={}, ctime={})'.format(self.balancer_id, short_version, self.ctime)


class BackendVersion(collections.namedtuple('BackendVersion',
                                            ('ctime', 'backend_id', 'version', 'deleted'))):
    """
    Version record of a backend: ``(ctime, backend_id, version, deleted)``.
    """

    def get_weak_hash(self):
        """
        Return ``b'<backend_id>:<version>'`` with both parts UTF-8 encoded.

        :rtype: six.binary_type
        """
        parts = (self.backend_id, self.version)
        return b':'.join(part.encode('utf-8') for part in parts)

    @classmethod
    def from_pb(cls, pb):
        """
        Build a version from a backend message; ctime is the meta mtime and
        ``deleted`` is taken from the spec.

        :type pb: infra.awacs.proto.model_pb2.Backend
        :rtype: BackendVersion
        """
        meta = pb.meta
        return cls(ctime=meta.mtime.ToMicroseconds(),
                   backend_id=meta.id,
                   version=meta.version,
                   deleted=pb.spec.deleted)

    @classmethod
    def from_rev_status_pb(cls, backend_id, pb):
        """
        Build a version from a revision status message for the given backend id.

        :type backend_id: six.text_type
        :type pb: infra.awacs.proto.model_pb2.L3BalancerState.RevisionL3Status
        :rtype: BackendVersion
        """
        ctime = pb.ctime.ToMicroseconds()
        return cls(ctime=ctime,
                   backend_id=backend_id,
                   version=pb.revision_id,
                   deleted=pb.deleted)

    def __repr__(self):
        # Only the first 8 characters of the version are shown.
        short_version = self.version[:8]
        return 'Back(id={}, v={}, ctime={})'.format(self.backend_id, short_version, self.ctime)


class EndpointSetVersion(collections.namedtuple('EndpointSetVersion',
                                                ('ctime', 'endpoint_set_id', 'version', 'deleted'))):
    """
    Version record of an endpoint set: ``(ctime, endpoint_set_id, version, deleted)``.
    """

    def get_weak_hash(self):
        """
        Return ``b'<endpoint_set_id>:<version>'`` with both parts UTF-8 encoded.

        :rtype: six.binary_type
        """
        return self.endpoint_set_id.encode('utf-8') + b':' + self.version.encode('utf-8')

    @classmethod
    def from_pb(cls, pb):
        """
        Build a version from an endpoint set message; ctime is the meta mtime
        and ``deleted`` is taken from the spec.

        :type pb: infra.awacs.proto.model_pb2.EndpointSet
        :rtype: EndpointSetVersion
        """
        return cls(ctime=pb.meta.mtime.ToMicroseconds(),
                   endpoint_set_id=pb.meta.id,
                   version=pb.meta.version,
                   deleted=pb.spec.deleted)

    @classmethod
    def from_rev_status_pb(cls, endpoint_set_id, pb):
        """
        Build a version from a revision status message for the given endpoint set id.

        :type endpoint_set_id: six.text_type
        :type pb: infra.awacs.proto.model_pb2.L3BalancerState.RevisionL3Status
        :rtype: EndpointSetVersion
        """
        return cls(ctime=pb.ctime.ToMicroseconds(),
                   endpoint_set_id=endpoint_set_id,
                   version=pb.revision_id,
                   deleted=pb.deleted)

    def __repr__(self):
        # Only the first 8 characters of the version are shown.
        return 'ESet(id={}, v={}, ctime={})'.format(self.endpoint_set_id, self.version[:8], self.ctime)
