# coding: utf-8
import collections

import gevent
import inject
import ujson

from awacs.lib.nannyclient import INannyClient, NannyClient
from awacs.lib.strutils import to_full_id
from awacs.model import cache
from awacs.model.util import clone_pb
from awacs.wrappers.base import Holder
from awacs.wrappers.main import IncludeUpstreams
from infra.awacs.proto import model_pb2
from .base import NamespaceAspectsUpdater, BalancerIsTooLargeError
from ..balancer.generator import get_included_full_backend_ids


class L3BalancerNode(object):
    """Graph node for an L3 balancer together with the balancers and backends it includes."""
    __slots__ = ('namespace_id', 'id', 'included_full_balancer_ids', 'included_full_backend_ids')

    def __init__(self, namespace_id, id,
                 included_full_balancer_ids=frozenset(), included_full_backend_ids=frozenset()):
        self.namespace_id, self.id = namespace_id, id
        self.included_full_balancer_ids = included_full_balancer_ids
        self.included_full_backend_ids = included_full_backend_ids

    def to_dict(self):
        """
        Serializes the node; included ids are rendered as sorted '/'-joined strings.

        :rtype: dict
        """
        def render(full_ids):
            return sorted('/'.join(full_id) for full_id in full_ids)

        # Key insertion order is kept stable because it shapes the dumped JSON
        result = {'type': 'l3_balancer', 'id': self.id, 'namespace_id': self.namespace_id}
        result['included_balancer_ids'] = render(self.included_full_balancer_ids)
        result['included_backend_ids'] = render(self.included_full_backend_ids)
        return result


class BalancerNode(object):
    """Graph node for an L7 balancer and the upstreams, backends and domains it includes."""
    __slots__ = (
        'namespace_id', 'id', 'included_full_upstream_ids', 'included_full_backend_ids',
        'included_full_domain_ids')

    def __init__(
            self, namespace_id, id, included_full_upstream_ids=frozenset(),
            included_full_backend_ids=frozenset(), included_full_domain_ids=frozenset()):
        self.namespace_id, self.id = namespace_id, id
        self.included_full_upstream_ids = included_full_upstream_ids
        self.included_full_backend_ids = included_full_backend_ids
        self.included_full_domain_ids = included_full_domain_ids

    def to_dict(self):
        """
        Serializes the node; included ids are rendered as sorted '/'-joined strings.

        :rtype: dict
        """
        def render(full_ids):
            return sorted('/'.join(full_id) for full_id in full_ids)

        # Key insertion order is kept stable because it shapes the dumped JSON
        result = {'type': 'balancer', 'id': self.id, 'namespace_id': self.namespace_id}
        result['included_upstream_ids'] = render(self.included_full_upstream_ids)
        result['included_backend_ids'] = render(self.included_full_backend_ids)
        result['included_domain_ids'] = render(self.included_full_domain_ids)
        return result


def full_id_list(full_ids):
    """
    Renders full ids as sorted '/'-joined strings (e.g. ('ns', 'id') -> 'ns/id').

    :param full_ids: iterable of id tuples
    :rtype: list[str]
    """
    rendered = ['/'.join(full_id) for full_id in full_ids]
    rendered.sort()
    return rendered


class UpstreamNode(object):
    """Graph node for an upstream and the backends it includes."""
    __slots__ = ('namespace_id', 'id', 'included_full_backend_ids')

    def __init__(self, namespace_id, id, included_full_backend_ids=frozenset()):
        self.namespace_id, self.id = namespace_id, id
        self.included_full_backend_ids = included_full_backend_ids

    def to_dict(self):
        """
        Serializes the node; included backend ids are rendered as sorted '/'-joined strings.

        :rtype: dict
        """
        backend_ids = sorted('/'.join(full_id) for full_id in self.included_full_backend_ids)
        # Key insertion order is kept stable because it shapes the dumped JSON
        result = {'type': 'upstream', 'id': self.id, 'namespace_id': self.namespace_id}
        result['included_backend_ids'] = backend_ids
        return result


class BackendNode(object):
    """Graph node for a backend; it carries no inclusions of its own."""
    __slots__ = ('namespace_id', 'id')

    def __init__(self, namespace_id, id):
        self.namespace_id, self.id = namespace_id, id

    def to_dict(self):
        """
        Serializes the node.

        :rtype: dict
        """
        # Key insertion order is kept stable because it shapes the dumped JSON
        result = {'type': 'backend'}
        result['id'] = self.id
        result['namespace_id'] = self.namespace_id
        return result


class DomainNode(object):
    """Graph node for a domain and the upstreams it includes."""
    __slots__ = ('namespace_id', 'id', 'included_upstream_ids')

    def __init__(self, namespace_id, id, included_upstream_ids):
        self.namespace_id, self.id = namespace_id, id
        self.included_upstream_ids = included_upstream_ids

    def to_dict(self):
        """
        Serializes the node; included upstream ids are rendered as sorted '/'-joined strings.

        :rtype: dict
        """
        upstream_ids = sorted('/'.join(full_id) for full_id in self.included_upstream_ids)
        # Key insertion order is kept stable because it shapes the dumped JSON
        result = {'type': 'domain', 'id': self.id, 'namespace_id': self.namespace_id}
        result['included_upstream_ids'] = upstream_ids
        return result


def l3_balancer_to_node(l3_balancer_pb, backend_ids_to_full_balancer_ids):
    """
    Builds a graph node for an L3 balancer.

    With BACKENDS real servers, a backend present in ``backend_ids_to_full_balancer_ids`` is replaced by
    the balancer ids it maps to; any other backend is recorded as an included (namespace_id, backend_id).
    With BALANCERS real servers, each listed balancer is recorded as (namespace_id, balancer_id) using
    the L3 balancer's own namespace. Other real server types include nothing.

    :type l3_balancer_pb: model_pb2.L3Balancer
    :param backend_ids_to_full_balancer_ids: backend id -> iterable of full balancer ids it resolves to
    :type backend_ids_to_full_balancer_ids: dict[str, set[tuple[str, str]]]
    :rtype: L3BalancerNode
    """
    namespace_id = l3_balancer_pb.meta.namespace_id
    real_servers = l3_balancer_pb.spec.real_servers
    balancer_full_ids = set()
    backend_full_ids = set()

    if real_servers.type == real_servers.BACKENDS:
        for backend_pb in real_servers.backends:
            backend_id = backend_pb.id
            # Membership test first: the mapping may be a defaultdict, and indexing would insert a key
            if backend_id not in backend_ids_to_full_balancer_ids:
                backend_full_ids.add((namespace_id, backend_id))
                continue
            balancer_full_ids.update(backend_ids_to_full_balancer_ids[backend_id])
    elif real_servers.type == real_servers.BALANCERS:
        balancer_full_ids.update((namespace_id, balancer_pb.id) for balancer_pb in real_servers.balancers)

    return L3BalancerNode(namespace_id=namespace_id,
                          id=l3_balancer_pb.meta.id,
                          included_full_balancer_ids=balancer_full_ids,
                          included_full_backend_ids=backend_full_ids)


def balancer_state_pb_to_node(balancer_pb, balancer_state_pb, all_ns_domains_full_ids):
    """
    Builds a graph node for an L7 balancer by walking its config chain (including branches).

    * If any module of the chain includes domains, all ``all_ns_domains_full_ids`` are attached.
    * If any module of the chain does not include domains, every upstream listed in
      ``balancer_state_pb.upstreams`` is attached as (namespace_id, upstream_id).
    * Backends are collected from every module that includes backends.

    :type balancer_pb: model_pb2.Balancer
    :type balancer_state_pb: model_pb2.BalancerState
    :param set[tuple[str, str]] all_ns_domains_full_ids: all domains to add to dict if balancer includes domains
    :rtype: BalancerNode
    """
    namespace_id = balancer_pb.meta.namespace_id
    balancer_id = balancer_pb.meta.id
    included_full_backend_ids = set()
    seen_domain_module = False
    seen_non_domain_module = False

    for module in Holder(balancer_pb.spec.yandex_balancer.config).walk_chain(visit_branches=True):
        # NOTE(review): upstreams are attached as soon as a single module does not include domains,
        # which is true for most chains; confirm this is the intended rule
        if module.includes_domains():
            seen_domain_module = True
        else:
            seen_non_domain_module = True
        # we use a different backend collection method than upstream_pb_to_node for historical reasons
        # to update RevisionGraphIndex'es (for the new api) we use the same method as in upstream_pb_to_node
        if module.includes_backends():
            included_full_backend_ids.update(
                module.include_backends.get_included_full_backend_ids(namespace_id))

    # Both sets depend only on the flags, so they are built once here instead of once per module
    included_full_domain_ids = set(all_ns_domains_full_ids) if seen_domain_module else set()
    if seen_non_domain_module:
        included_full_upstream_ids = {(namespace_id, upstream_id) for upstream_id in balancer_state_pb.upstreams}
    else:
        included_full_upstream_ids = set()

    return BalancerNode(namespace_id=namespace_id, id=balancer_id,
                        included_full_upstream_ids=included_full_upstream_ids,
                        included_full_backend_ids=included_full_backend_ids,
                        included_full_domain_ids=included_full_domain_ids,
                        )


def upstream_pb_to_node(upstream_pb):
    """
    Builds a graph node for an upstream; its backends come from get_included_full_backend_ids.

    :type upstream_pb: model_pb2.Upstream
    :rtype: UpstreamNode
    """
    meta = upstream_pb.meta
    return UpstreamNode(
        namespace_id=meta.namespace_id,
        id=meta.id,
        included_full_backend_ids=get_included_full_backend_ids(upstream_pb),
    )


def backend_pb_to_node(backend_pb):
    """
    Builds a graph node for a backend from its meta.

    :type backend_pb: model_pb2.Backend
    :rtype: BackendNode
    """
    meta = backend_pb.meta
    return BackendNode(namespace_id=meta.namespace_id, id=meta.id)


def domain_pb_to_node(domain_pb, all_ns_upstream_full_ids):
    """
    Builds a graph node for a domain, resolving its included upstreams via IncludeUpstreams.

    :type domain_pb: model_pb2.Domain
    :param set[tuple[six.text_type, six.text_type]] all_ns_upstream_full_ids: All upstreams in the namespace to return if upstreams.type is `ALL`
    :rtype: DomainNode
    """
    namespace_id = domain_pb.meta.namespace_id
    include_upstreams_pb = domain_pb.spec.yandex_balancer.config.include_upstreams
    upstream_ids = IncludeUpstreams(include_upstreams_pb).get_included_upstream_ids(
        str(namespace_id), all_ns_upstream_full_ids)
    return DomainNode(namespace_id=namespace_id, id=domain_pb.meta.id, included_upstream_ids=upstream_ids)


class GraphAspectsUpdater(NamespaceAspectsUpdater):
    """Maintains the 'graph' namespace aspect: a JSON dump of the namespace's inclusion graph."""
    _nanny_client = inject.attr(INannyClient)  # type: NannyClient

    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    # Namespaces with more upstreams than this are rejected with BalancerIsTooLargeError
    MAX_UPSTREAMS_NUMBER = 1000

    def get_aspects_name(self):
        return 'graph'

    def _make_graph_node(self, c, item_type, namespace_id, item_id,
                         backend_ids_to_full_balancer_ids, all_domains_full_ids, all_ns_upstream_ids):
        """
        Builds the graph node for a single (item_type, namespace_id, item_id) frontier item.

        :returns: the node, or None when the item must be skipped (incomplete spec or system backend)
        :raises AssertionError: on an unknown item type; cache lookups may raise as well
        """
        if item_type == 'l3_balancer':
            l3_balancer_pb = c.must_get_l3_balancer(namespace_id=namespace_id, l3_balancer_id=item_id)
            if l3_balancer_pb.spec.incomplete:
                return None
            return l3_balancer_to_node(l3_balancer_pb,
                                       backend_ids_to_full_balancer_ids=backend_ids_to_full_balancer_ids)
        if item_type == 'balancer':
            balancer_pb = c.must_get_balancer(namespace_id=namespace_id, balancer_id=item_id)
            if balancer_pb.spec.incomplete:
                return None
            balancer_state_pb = c.must_get_balancer_state(namespace_id=namespace_id, balancer_id=item_id)
            return balancer_state_pb_to_node(balancer_pb, balancer_state_pb, all_domains_full_ids)
        if item_type == 'upstream':
            upstream_pb = c.must_get_upstream(namespace_id=namespace_id, upstream_id=item_id)
            return upstream_pb_to_node(upstream_pb)
        if item_type == 'backend':
            backend_pb = c.get_backend(namespace_id=namespace_id, backend_id=item_id)
            # A backend absent from the cache still gets a node; system backends are left out of the graph
            if backend_pb is not None and backend_pb.meta.is_system.value:
                return None
            return BackendNode(namespace_id=namespace_id, id=item_id)
        if item_type == 'domain':
            domain_pb = c.must_get_domain(namespace_id=namespace_id, domain_id=item_id)
            if domain_pb.spec.incomplete:
                return None
            return domain_pb_to_node(domain_pb, all_ns_upstream_ids)
        raise AssertionError('unknown item type {}'.format(item_type))

    def get_inclusion_graph_json(self, namespace_id):
        """
        Builds the inclusion graph of a namespace and serializes it to JSON.

        All complete balancers, L3 balancers, upstreams, domains and backends of the namespace are seeded
        into the frontier. Backends and upstreams they include are then traversed breadth-first. Items whose
        node cannot be built are logged and skipped.

        :param str namespace_id: namespace to build the graph for
        :returns: JSON list of node dicts sorted by (node class name, namespace id, id)
        :raises BalancerIsTooLargeError: if the namespace has more than MAX_UPSTREAMS_NUMBER upstreams
        """
        c = cache.IAwacsCache.instance()

        nodes = []
        frontier = set()
        visited = set()

        # Both maps store full (namespace_id, balancer_id) tuples as values
        service_ids_to_full_balancer_ids = collections.defaultdict(set)
        backend_ids_to_full_balancer_ids = collections.defaultdict(set)

        for balancer_pb in c.list_all_balancers(namespace_id=namespace_id):
            if balancer_pb.spec.incomplete:
                continue
            balancer_id = balancer_pb.meta.id
            service_id = balancer_pb.spec.config_transport.nanny_static_file.service_id
            service_ids_to_full_balancer_ids[service_id].add((namespace_id, balancer_id))
            frontier.add(('balancer', namespace_id, balancer_id))

        for l3_balancer_pb in c.list_all_l3_balancers(namespace_id=namespace_id):
            frontier.add(('l3_balancer', namespace_id, l3_balancer_pb.meta.id))

        upstream_pbs = c.list_all_upstreams(namespace_id=namespace_id)
        if len(upstream_pbs) > self.MAX_UPSTREAMS_NUMBER:
            raise BalancerIsTooLargeError(
                'Balancer is too big (number of upstreams is larger than {})'.format(self.MAX_UPSTREAMS_NUMBER))
        all_ns_upstream_ids = set()
        for upstream_pb in upstream_pbs:
            frontier.add(('upstream', namespace_id, upstream_pb.meta.id))
            all_ns_upstream_ids.add(to_full_id(namespace_id, upstream_pb.meta.id))

        all_domains_full_ids = set()
        for domain_pb in c.list_all_domains(namespace_id=namespace_id):
            frontier.add(('domain', namespace_id, domain_pb.meta.id))
            all_domains_full_ids.add(to_full_id(namespace_id, domain_pb.meta.id))

        for backend_pb in c.list_all_backends(namespace_id=namespace_id):
            backend_id = backend_pb.meta.id
            frontier.add(('backend', namespace_id, backend_id))
            selector_pb = backend_pb.spec.selector
            if selector_pb.type == selector_pb.NANNY_SNAPSHOTS:
                for snapshot_pb in selector_pb.nanny_snapshots:
                    service_id = snapshot_pb.service_id
                    if service_id in service_ids_to_full_balancer_ids:
                        backend_ids_to_full_balancer_ids[backend_id].update(
                            service_ids_to_full_balancer_ids[service_id])
            if selector_pb.type == selector_pb.BALANCERS:
                # Store full ids like the NANNY_SNAPSHOTS branch does: L3BalancerNode.to_dict joins every
                # entry with '/', so a bare string id would be rendered character by character
                backend_ids_to_full_balancer_ids[backend_id].update(
                    (namespace_id, backend_balancer_pb.id) for backend_balancer_pb in selector_pb.balancers
                )

        while frontier:
            new_frontier = set()
            for item in frontier:
                gevent.idle()
                if item in visited:
                    continue
                item_type, item_namespace_id, item_id = item
                try:
                    node = self._make_graph_node(c, item_type, item_namespace_id, item_id,
                                                 backend_ids_to_full_balancer_ids=backend_ids_to_full_balancer_ids,
                                                 all_domains_full_ids=all_domains_full_ids,
                                                 all_ns_upstream_ids=all_ns_upstream_ids)
                except Exception:
                    self._log.warn('Failed to make a graph node from %s %s/%s', item_type, item_namespace_id,
                                   item_id, exc_info=True)
                    continue
                if node is None:
                    continue

                if item_type in ('l3_balancer', 'balancer', 'upstream'):
                    new_frontier.update(('backend', ns_id, backend_id)
                                        for ns_id, backend_id in node.included_full_backend_ids
                                        if ('backend', ns_id, backend_id) not in visited)
                if item_type == 'balancer':
                    new_frontier.update(('upstream', ns_id, upstream_id)
                                        for ns_id, upstream_id in node.included_full_upstream_ids
                                        if ('upstream', ns_id, upstream_id) not in visited)

                nodes.append(node)
                visited.add(item)
            frontier = new_frontier
        nodes.sort(key=lambda n: (type(n).__name__, n.namespace_id, n.id))
        data = [node.to_dict() for node in nodes]
        gevent.idle()
        return ujson.dumps(data)

    def update(self, namespace_pb, aspects_set_content_pb):
        """
        Recomputes the graph aspect content of a namespace in place.

        :type namespace_pb: model_pb2.Namespace
        :type aspects_set_content_pb: model_pb2.NamespaceAspectsSetContent
        :returns: True if the graph content differs from what it was before the call
        """
        namespace_id = namespace_pb.meta.id

        content_pb = aspects_set_content_pb.graph.content
        prev_content_pb = clone_pb(content_pb)
        content_pb.Clear()

        content_pb.inclusion_graph_json = self.get_inclusion_graph_json(namespace_id)

        return prev_content_pb != content_pb
