import logging

from infra.oxcart.lib import awacsutil
from infra.oxcart.proto import cluster_pb2
import envoy.api.v2.core.address_pb2 as address_pb2
import envoy.api.v2.endpoint.endpoint_pb2 as endpoint_pb2
# import envoy.api.v2.discovery_pb2 as discovery_pb2

# envoy
import envoy.api.v2.cds_pb2 as cds_pb2
import envoy.api.v2.lds_pb2 as lds_pb2

# Module-level logger registered under the fixed name 'cluster_maintainer'.
log = logging.getLogger('cluster_maintainer')


class ClusterState(object):
    """Hold a cluster configuration and render it as Envoy clusters/listeners.

    The instance keeps a ``cluster_pb2.Cluster`` message in ``cluster_config``
    whose backends get their ``endpoints`` rebuilt on every :meth:`refresh`,
    and an integer ``version_info`` that is incremented each time a refresh
    produces a message different from the stored one.
    """

    def _validate(self, backend):
        """Reject a backend whose type is neither AWACS nor STATIC.

        Args:
            backend: a backend message exposing a ``type`` field.

        Raises:
            NotImplementedError: if ``backend.type`` is any other value.
        """
        supported = (cluster_pb2.Backend.AWACS, cluster_pb2.Backend.STATIC)
        if backend.type not in supported:
            # ``NotImplemented`` is a non-callable constant; calling it fails
            # with TypeError. The exception class is ``NotImplementedError``.
            raise NotImplementedError(
                'Only AWACS and STATIC backends are supported, got type {}'.format(backend.type)
            )

    def __init__(self, config):
        """Store ``config.cluster`` and validate every backend it lists.

        Args:
            config: object exposing a ``cluster`` attribute with ``backends``.

        Raises:
            NotImplementedError: if any backend has an unsupported type.
        """
        self.cluster_config = config.cluster
        # Not read or written by any other method visible in this class.
        self.cluster = None
        self.version_info = 0

        for backend in self.cluster_config.backends:
            self._validate(backend)

    @staticmethod
    def _validate_weight(weight):
        """Return ``weight`` as an int clamped to the range 1-128.

        Envoy supports weights in range 1-128.
        """
        return max(1, min(128, int(weight)))

    def refresh_awacs(self, backend):
        """Rebuild ``backend.endpoints`` from the AWACS endpoint sets.

        For every id in ``backend.awacs.ids`` the endpoint set is fetched via
        ``awacsutil.AwacsClient().get_endpoint_set`` and each instance of the
        first returned element is appended as an endpoint (IPv6 address plus
        clamped weight).
        """
        backend.ClearField('endpoints')
        for back_id in backend.awacs.ids:
            awacs_data = awacsutil.AwacsClient().get_endpoint_set(
                backend.awacs.namespace,
                back_id,
            )
            for instance in awacs_data[0].endpoint_set.spec.instances:
                endpoint = backend.endpoints.add()
                endpoint.addr = instance.ipv6_addr
                endpoint.weight = self._validate_weight(instance.weight)

    def refresh_static(self, backend):
        """Rebuild ``backend.endpoints`` from ``backend.static_backends``."""
        backend.ClearField('endpoints')
        for static_endpoint in backend.static_backends.endpoints:
            endpoint = backend.endpoints.add()
            endpoint.addr = static_endpoint.addr
            endpoint.weight = self._validate_weight(static_endpoint.weight)

    def refresh(self):
        """Re-resolve endpoints of every backend and publish any change.

        The work is done on a copy of ``cluster_config``, so a failure part
        way through leaves the stored configuration and version untouched.

        Returns:
            True if the refresh completed (whether or not anything changed),
            False if an exception was raised; the exception is logged.
        """
        try:
            candidate = cluster_pb2.Cluster()
            candidate.MergeFrom(self.cluster_config)

            for backend in candidate.backends:
                if backend.type == cluster_pb2.Backend.AWACS:
                    self.refresh_awacs(backend)
                elif backend.type == cluster_pb2.Backend.STATIC:
                    self.refresh_static(backend)

            if candidate != self.cluster_config:
                self.version_info += 1
                self.cluster_config = candidate
                log.info('Config updated:')
                log.info(self.cluster_config)

        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt and
            # SystemExit must propagate rather than be reported as a failed
            # refresh.
            log.exception('cluster state refresh failed')
            return False

        return True

    def render_envoy_clusters(self, version_info):
        """Render one ``cds_pb2.Cluster`` per backend.

        Args:
            version_info: the version the caller already has.

        Returns:
            A ``(version_info, payload)`` tuple carrying the current version.
            ``payload`` is None when the caller's version equals the current
            one, otherwise a list of STATIC, ROUND_ROBIN clusters with a
            1 second connect timeout.
        """
        if self.version_info == version_info:
            return (self.version_info, None)

        if not self.cluster_config.backends:
            # NOTE(review): this returns a single empty message while the
            # normal path returns a list - confirm what callers expect.
            return (self.version_info, cds_pb2.Cluster())

        clusters = []

        for backend in self.cluster_config.backends:
            cluster = cds_pb2.Cluster()
            cluster.name = backend.id
            cluster.connect_timeout.seconds = 1

            cluster.type = cds_pb2.Cluster.STATIC
            cluster.lb_policy = cds_pb2.Cluster.ROUND_ROBIN
            cluster.load_assignment.cluster_name = cluster.name

            locality_lb_endpoints = endpoint_pb2.LocalityLbEndpoints()

            for endpoint in backend.endpoints:
                lb_endpoint = locality_lb_endpoints.lb_endpoints.add()
                socket_address = lb_endpoint.endpoint.address.socket_address
                socket_address.address = endpoint.addr
                socket_address.port_value = backend.local_port
                socket_address.protocol = address_pb2.SocketAddress.TCP
                lb_endpoint.load_balancing_weight.value = endpoint.weight

            cluster.load_assignment.endpoints.extend([locality_lb_endpoints])
            clusters.append(cluster)

        return (self.version_info, clusters)

    def render_envoy_listeners(self, version_info):
        """Render one TCP-proxy ``lds_pb2.Listener`` per backend.

        Listeners are named ``listener_<index>`` in backend order, bound to
        ``backend.listen_addr``:``backend.local_port`` and proxy to the
        cluster named after ``backend.id``.

        Args:
            version_info: the version the caller already has.

        Returns:
            A ``(version_info, payload)`` tuple carrying the current version.
            ``payload`` is None when the caller's version equals the current
            one, otherwise a list of listeners.
        """
        if self.version_info == version_info:
            return (self.version_info, None)

        if not self.cluster_config.backends:
            # NOTE(review): this returns a single empty message while the
            # normal path returns a list - confirm what callers expect.
            return (self.version_info, lds_pb2.Listener())

        listeners = []

        for index, backend in enumerate(self.cluster_config.backends):
            listener = lds_pb2.Listener()
            listener.name = "listener_{}".format(index)
            listener.address.socket_address.protocol = address_pb2.SocketAddress.TCP
            listener.address.socket_address.address = backend.listen_addr
            listener.address.socket_address.port_value = backend.local_port

            filter_chain = listener.filter_chains.add()
            tcp_proxy = filter_chain.filters.add()
            tcp_proxy.name = 'envoy.tcp_proxy'
            tcp_proxy.config["stat_prefix"] = "ingress_tcp"
            tcp_proxy.config["cluster"] = backend.id
            listeners.append(listener)

        return (self.version_info, listeners)
