# coding: utf-8
import six

from awacs.model.util import find_rev_status_by_revision_id
from infra.awacs.proto import model_pb2
from .vector import L3BalancerVersion, BackendVersion, EndpointSetVersion


class L3RevStatus(object):
    """Thin mutable view over a single L3 revision status message."""
    __slots__ = ('pb', )

    def __init__(self, pb):
        self.pb = pb

    def set_validated(self, status, message=''):
        """Record a new validation status for this revision.

        The message is written only together with a status change; when the
        status is already equal to ``status`` nothing is touched.

        :param status: new ``validated.status`` value
        :param message: new ``validated.message`` value
        :return: True if the status message was modified, False otherwise
        """
        validated_pb = self.pb.validated
        if validated_pb.status == status:
            return False
        validated_pb.status = status
        validated_pb.message = message
        return True

    def modify_in_progress(self, modifier):
        """Apply ``modifier`` to the ``in_progress`` sub-message and return its result.

        :raises RuntimeError: if ``modifier`` is not callable
        """
        if callable(modifier):
            return modifier(self.pb.in_progress)
        raise RuntimeError('modifier is not callable')


class L3Entity(object):
    """Helper around a message that carries a repeated ``l3_statuses`` field.

    Used for the L3 balancer itself as well as for each backend and endpoint
    set held in an L3 balancer state. Every mutating method returns True when
    it actually changed the wrapped message and False otherwise.
    """
    __slots__ = ('pb', )

    def __init__(self, pb):
        self.pb = pb

    @staticmethod
    def _to_revision_id(version):
        """Return the revision id for a version object or a bare revision id string.

        :raises RuntimeError: if ``version`` is of an unsupported type
        """
        if isinstance(version, (L3BalancerVersion, BackendVersion, EndpointSetVersion)):
            return version.version
        if isinstance(version, six.string_types):
            return version
        raise RuntimeError('Unexpected version type: {}'.format(type(version)))

    def set_active_rev(self, version):
        """Mark the revision matching ``version`` as active and every other revision as inactive.

        :param version: version object or revision id string
        :rtype: bool
        """
        revision_id = self._to_revision_id(version)
        updated = False
        for l3_status_pb in self.pb.l3_statuses:
            desired = 'True' if l3_status_pb.revision_id == revision_id else 'False'
            if l3_status_pb.active.status != desired:
                l3_status_pb.active.status = desired
                updated = True
        return updated

    def omit_revs(self, filter_):
        """Drop every revision status for which ``filter_`` returns a truthy value.

        :rtype: bool
        """
        kept_pbs = [l3_status_pb for l3_status_pb in self.pb.l3_statuses if not filter_(l3_status_pb)]
        # kept_pbs is a subsequence of l3_statuses, so equal lengths mean nothing was dropped;
        # this avoids a deep value comparison of every protobuf message.
        if len(kept_pbs) == len(self.pb.l3_statuses):
            return False
        del self.pb.l3_statuses[:]
        self.pb.l3_statuses.extend(kept_pbs)
        return True

    def update_revs(self, cb):
        """Call ``cb`` on every revision status; True if any call returned a truthy value.

        ``cb`` is invoked for all statuses (no short-circuit), since it may mutate them.

        :rtype: bool
        """
        updated = False
        for l3_status_pb in self.pb.l3_statuses:
            if cb(l3_status_pb):
                updated = True
        return updated

    def add_new_if_missing(self, version):
        """Append a fresh revision status for ``version`` unless one already exists.

        The new status starts with validated='Unknown', in_progress='False' and
        active='False'. ``deleted`` is copied from backend / endpoint set versions.
        ``ctime`` is copied from version objects. A bare revision id string carries
        no ctime, so the field is left unset in that case.

        :param version: version object or revision id string
        :return: True if a status was appended, False if it was already present
        :raises RuntimeError: if ``version`` is of an unsupported type
        """
        # Validate (and possibly raise) before touching the message, so a bad
        # argument never leaves a half-initialised status behind.
        revision_id = self._to_revision_id(version)
        if self.select_rev(revision_id) is not None:
            return False

        l3_status_pb = self.pb.l3_statuses.add(revision_id=revision_id)
        if isinstance(version, (BackendVersion, EndpointSetVersion)):
            l3_status_pb.deleted = version.deleted
        if isinstance(version, (L3BalancerVersion, BackendVersion, EndpointSetVersion)):
            l3_status_pb.ctime.FromMicroseconds(version.ctime)
        l3_status_pb.validated.status = 'Unknown'
        l3_status_pb.in_progress.status = 'False'
        l3_status_pb.active.status = 'False'
        return True

    def select_rev(self, version):
        """Return an :class:`L3RevStatus` for ``version``, or None if it is not tracked.

        :param version: version object or revision id string
        :raises RuntimeError: if ``version`` is of an unsupported type
        """
        rev_pb = find_rev_status_by_revision_id(self.pb.l3_statuses, self._to_revision_id(version))
        if rev_pb is None:
            return None
        return L3RevStatus(rev_pb)


class L3BalancerStateHandler(object):
    """Mutating helper around an L3 balancer state message.

    Provides revision-status bookkeeping for the L3 balancer itself and for
    the backends and endpoint sets stored in the state's maps.
    """
    __slots__ = ('pb', )

    def __init__(self, pb):
        """
        :type pb: model_pb2.L3BalancerState
        """
        self.pb = pb

    def iter_versions(self):
        """Yield a version object for every revision status held in the state.

        Order: the L3 balancer's own revisions, then backend revisions sorted by
        backend id, then endpoint set revisions sorted by endpoint set id.
        """
        for l3_status_pb in self.pb.l3_balancer.l3_statuses:
            yield L3BalancerVersion.from_rev_status_pb(l3_balancer_id=self.pb.l3_balancer_id, pb=l3_status_pb)
        for backend_id, backend_status_pb in sorted(self.pb.backends.items()):
            for l3_status_pb in backend_status_pb.l3_statuses:
                yield BackendVersion.from_rev_status_pb(backend_id=backend_id, pb=l3_status_pb)
        for endpoint_set_id, endpoint_set_status_pb in sorted(self.pb.endpoint_sets.items()):
            for l3_status_pb in endpoint_set_status_pb.l3_statuses:
                yield EndpointSetVersion.from_rev_status_pb(endpoint_set_id=endpoint_set_id, pb=l3_status_pb)

    def select(self, version):
        """Return an :class:`L3Entity` for the entity that ``version`` belongs to.

        :raises RuntimeError: if ``version`` is not a supported version object
        """
        if isinstance(version, L3BalancerVersion):
            return self.select_l3_balancer()
        elif isinstance(version, BackendVersion):
            return self.select_backend(version.backend_id)
        elif isinstance(version, EndpointSetVersion):
            return self.select_endpoint_set(version.endpoint_set_id)
        else:
            raise RuntimeError('Unexpected version type: {}'.format(type(version)))

    def set_active_rev(self, version):
        """Mark ``version`` active within its owning entity; see :meth:`L3Entity.set_active_rev`."""
        return self.select(version).set_active_rev(version)

    def select_rev(self, version):
        """Return the revision status for ``version`` or None; see :meth:`L3Entity.select_rev`."""
        return self.select(version).select_rev(version)

    def add_new_if_missing(self, version):
        """Track ``version`` in its owning entity; see :meth:`L3Entity.add_new_if_missing`."""
        return self.select(version).add_new_if_missing(version)

    def select_l3_balancer(self):
        return L3Entity(self.pb.l3_balancer)

    def select_backend(self, backend_id):
        # NOTE: indexing a protobuf message map with a missing key inserts a default entry,
        # so this also creates the backend status if it does not exist yet.
        return L3Entity(self.pb.backends[backend_id])

    def select_endpoint_set(self, endpoint_set_id):
        # NOTE: as with select_backend, a missing key is inserted by the map lookup.
        return L3Entity(self.pb.endpoint_sets[endpoint_set_id])

    @staticmethod
    def _remove_entries_wo_revs(entries):
        """Delete every entry of ``entries`` (id -> status message) that has no ``l3_statuses``.

        Ids are collected before deleting. Removing keys while iterating the
        mapping raises RuntimeError for Python 3 dicts and is not safe for
        protobuf map containers either.

        :rtype: bool
        """
        empty_ids = [entry_id for entry_id, state_pb in entries.items() if not state_pb.l3_statuses]
        for entry_id in empty_ids:
            del entries[entry_id]
        return bool(empty_ids)

    def remove_backends_wo_revs(self):
        """Remove backends that no longer track any revision.

        :rtype: bool
        """
        return self._remove_entries_wo_revs(self.pb.backends)

    def remove_endpoint_sets_wo_revs(self):
        """Remove endpoint sets that no longer track any revision.

        :rtype: bool
        """
        return self._remove_entries_wo_revs(self.pb.endpoint_sets)

    def delete_backend(self, backend_id):
        """Remove ``backend_id`` from the state if present.

        :param six.text_type backend_id:
        :rtype: bool
        """
        assert isinstance(backend_id, six.string_types)
        deleted = False
        if backend_id in self.pb.backends:
            del self.pb.backends[backend_id]
            deleted = True
        return deleted

    def delete_endpoint_set(self, endpoint_set_id):
        """Remove ``endpoint_set_id`` from the state if present.

        :param six.text_type endpoint_set_id:
        :rtype: bool
        """
        assert isinstance(endpoint_set_id, six.string_types)
        deleted = False
        if endpoint_set_id in self.pb.endpoint_sets:
            del self.pb.endpoint_sets[endpoint_set_id]
            deleted = True
        return deleted

    def set_ignore_existing_l3mgr_config(self, value, author, comment):
        """Set the ``ignore_existing_l3mgr_config`` flag, stamping mtime, author and comment.

        :return: False (and nothing is modified) if the flag already equals ``value``, True otherwise
        """
        if self.pb.ignore_existing_l3mgr_config.value == value:
            return False
        self.pb.ignore_existing_l3mgr_config.value = value
        self.pb.ignore_existing_l3mgr_config.mtime.GetCurrentTime()
        self.pb.ignore_existing_l3mgr_config.author = author
        self.pb.ignore_existing_l3mgr_config.comment = comment
        return True

    def increment_skip_count(self, vector_hash):
        """Increment the skip counter stored under ``vector_hash``."""
        self.pb.skip_counts[vector_hash] += 1
