import glob
import json
import os
import re
import shutil
import socket
import tarfile
import tempfile

import ipaddress
import requests
import six
import vcr
from awacs.resolver.util import _resolve_host
from google.protobuf import field_mask_pb2
from six.moves import cPickle as pickle
from six.moves import range


class Node(object):
    """A vertex of the traffic map.

    Nodes are compared and hashed by their string form (type name plus the
    namespace id, IP or endpoint set id), so two nodes of the same type that
    name the same object are equal. IP addresses are normalized to their
    exploded form first, so different spellings of one address match.
    """
    EXTERNAL_USERS = 1
    NAMESPACE = 2
    IP = 3
    ENDPOINT_SET = 4
    GENCFG_WORLD = 5

    # Built once rather than on every __str__ call; __eq__ and __hash__ call
    # __str__ for every set/dict operation on nodes.
    _TYPE_NAMES = {
        EXTERNAL_USERS: 'EXTERNAL_USERS',
        IP: 'IP',
        ENDPOINT_SET: 'ENDPOINT_SET',
        NAMESPACE: 'NAMESPACE',
        GENCFG_WORLD: 'GENCFG_WORLD',
    }

    def __init__(self, type, ns_id=None, ip=None, es_id=None):
        """
        :param type: one of the node type constants defined on this class.
        :param ns_id: awacs namespace id; required for NAMESPACE nodes.
        :param ip: IPv4 or IPv6 address; required for IP nodes.
        :param es_id: endpoint set id; required for ENDPOINT_SET nodes.
        """
        assert type in self._TYPE_NAMES
        if type == self.NAMESPACE:
            assert ns_id is not None
        if type == self.IP:
            assert ip is not None
        if type == self.ENDPOINT_SET:
            assert es_id is not None
        self.type = type
        self.ns_id = ns_id
        self.ip = ip
        self.es_id = es_id

    def __repr__(self):
        return str(self)

    def __str__(self):
        ip = self.ip
        if ip is not None:
            ip = six.text_type(ip)
            # Canonicalize so that e.g. "::1" and "0:0:0:0:0:0:0:1" compare equal.
            if ':' in ip:
                ip = ipaddress.IPv6Address(ip).exploded
            else:
                ip = ipaddress.IPv4Address(ip).exploded
        return '{} {}'.format(self._TYPE_NAMES[self.type], self.ns_id or ip or self.es_id or '')

    def __eq__(self, other):
        # Only nodes are comparable; a plain string that happens to match our
        # string form must not be considered equal.
        if not isinstance(other, Node):
            return NotImplemented
        return str(self) == str(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(str(self))


class Edge(object):
    """A directed edge ``a -> b`` of the traffic map.

    Like nodes, edges are compared and hashed by their string form.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self):
        return str(self)

    def __str__(self):
        return '{} -> {}'.format(self.a, self.b)

    def __eq__(self, other):
        # Only edges are comparable; a matching plain string is not an edge.
        if not isinstance(other, Edge):
            return NotImplemented
        return str(self) == str(other)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(str(self))


class Processor(object):
    """Builds a traffic map between awacs namespaces, gencfg balancers and external users.

    ``process`` crawls awacs, l3mgr, nanny and puncher (HTTP interactions go
    through VCR cassettes) and pickles the raw edges and portals.
    ``process_pickled`` turns that pickle into a namespace-level graph dumped as JSON.
    """

    # Compiled once: these run over every balancer's Lua config / pod status.
    # ``[^"\n]+`` instead of a greedy ``.+`` keeps a capture from running past
    # the closing quote when several values share one line.
    _ENDPOINT_SET_ID_RE = re.compile(r'endpoint_set_id = "([^"\n]+)"')
    # awacs balancers: backend tuples of the form { "host"; port; weight; "ip"; }
    _AWACS_BACKEND_IP_RE = re.compile(r'\{ "[^"\n]+"; \d+; [^;\n]+; "([^"\n]+)"; \}')
    # gencfg balancers
    _CACHED_IP_RE = re.compile(r'cached_ip = "([^"\n]+)";')
    _POD_ADDRESS_RE = re.compile(r'address: "([^"\n]+)"')

    # How many times flaky remote calls are attempted before giving up.
    _RETRY_ATTEMPTS = 3

    def __init__(self, awacs_client, l3mgr_client, nanny_client):
        """
        :type awacs_client: infra.awacs.tools.awacstoolslib.awacsclient.AwacsClient
        :type l3mgr_client: awacs.lib.l3mgrclient.L3MgrClient
        :type nanny_client: infra.awacs.tools.awacstoolslib.nannyclient.NannyClient
        """
        self.awacs_client = awacs_client
        self.l3mgr_client = l3mgr_client
        self.nanny_client = nanny_client
        # OAuth token for the puncher API (see is_there_a_rule_from_any_to).
        self.puncher_token = os.getenv('PUNCHER_TOKEN')
        # Cassettes live in ./awacsmap-cassettes; authorization headers are
        # filtered out of the recordings. Pass record_mode='new_episodes' to
        # add requests that are missing from existing cassettes.
        self.stat_vcr = vcr.VCR(
            cassette_library_dir='./awacsmap-cassettes',
            filter_headers=['authorization'],
            match_on=['method', 'scheme', 'host', 'port', 'path', 'query', 'raw_body'],
        )

    def _with_retries(self, failure_label, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` up to ``_RETRY_ATTEMPTS`` times.

        Every failure is printed after ``failure_label``; the exception of the
        last attempt is re-raised.

        :returns: the result of the first successful call.
        """
        for attempt in range(self._RETRY_ATTEMPTS):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                print(failure_label, e)
                if attempt == self._RETRY_ATTEMPTS - 1:
                    raise

    def process_pickled(self, pickle_path='./awacsmap.pickle'):
        """Build a namespace-level graph from a pickle written by ``process``.

        Prints which namespaces each namespace addresses, which namespaces the
        gencfg balancers address, and the "problematic" namespaces that are
        reached both by external users and by other balancers. Writes
        ``./awacsmap-graph.json`` with sorted ``nodes`` (names) and ``edges``
        (``[from, to]`` pairs).
        """
        # pickle needs a binary file. Only load pickles this tool produced:
        # unpickling untrusted data can execute arbitrary code.
        with open(pickle_path, 'rb') as f:
            data = pickle.load(f)
        ns_ids_to_edges = data['ns_ids_to_edges']
        gencfg_world_edges = data['gencfg_world_edges']
        portals_to_ns_id = data['portals_to_ns_id']
        all_edges = data['all_edges']
        print('all edges', len(all_edges))
        print('gencfg_world_edges', len(gencfg_world_edges))

        g_nodes = {'gencfg balancers', 'external users'}
        g_edges = set()

        # Namespaces with an edge from EXTERNAL_USERS.
        external_ns_ids = set()
        # Namespaces addressed by another namespace or by gencfg balancers.
        internal_ns_ids = set()
        for ns_id, edges in six.iteritems(ns_ids_to_edges):
            g_nodes.add('{}'.format(ns_id))
            target_ns_ids = set()
            for edge in edges:
                assert edge.a.ns_id == ns_id or edge.b.ns_id == ns_id
                if edge.b.ns_id == ns_id and edge.a.type == Node.EXTERNAL_USERS:
                    external_ns_ids.add(ns_id)
                # An outgoing edge that hits another namespace's portal means
                # this namespace sends traffic to that namespace.
                if edge.a.ns_id == ns_id and edge.b in portals_to_ns_id:
                    target_ns_ids.add(portals_to_ns_id[edge.b])
            if target_ns_ids:
                print(ns_id, 'addresses', ' '.join(sorted(target_ns_ids)))
                for target_ns_id in target_ns_ids:
                    g_edges.add((ns_id, target_ns_id))
            internal_ns_ids.update(target_ns_ids)

        gencfg_target_ns_ids = set()
        for edge in gencfg_world_edges:
            assert edge.a.type == Node.GENCFG_WORLD
            if edge.b in portals_to_ns_id:
                target_ns_id = portals_to_ns_id[edge.b]
                gencfg_target_ns_ids.add(target_ns_id)
                g_edges.add(('gencfg balancers', target_ns_id))
                internal_ns_ids.add(target_ns_id)
        print('gencfg balancers address', ' '.join(sorted(gencfg_target_ns_ids)))

        for target_ns_id in external_ns_ids:
            g_edges.add(('external users', target_ns_id))

        print()
        print('PROBLEMATIC NS IDS')
        # Sorted so that the report is stable between runs.
        for ns_id in sorted(internal_ns_ids & external_ns_ids):
            print(ns_id)
        with open('./awacsmap-graph.json', 'w') as f:
            json.dump({
                'nodes': sorted(g_nodes),
                'edges': sorted(g_edges),
            }, f)

    def process(self, pickle_path='./awacsmap.pickle'):
        """Crawl every awacs namespace and the gencfg balancer configs, then pickle the result.

        The pickle written to ``pickle_path`` is a dict with the keys
        ``portals_to_ns_id``, ``ns_ids_to_portals``, ``ns_ids_to_edges``,
        ``all_edges`` and ``gencfg_world_edges``.

        :raises RuntimeError: if two namespaces claim the same portal.
        """
        ns_ids_to_edges = {}
        ns_ids_to_portals = {}
        portals_to_ns_id = {}
        all_edges = set()
        gencfg_world_edges = set()

        with self.stat_vcr.use_cassette('ids.yml'):
            namespace_ids = list(self.awacs_client.list_namespace_ids())
        total = len(namespace_ids)
        for i, namespace_id in enumerate(namespace_ids, start=1):
            with self.stat_vcr.use_cassette('ns_{}.yml'.format(namespace_id)):
                edges, portals = self._with_retries(
                    'failed to process {}'.format(namespace_id), self.process_ns, namespace_id)
                print(namespace_id, ':', len(edges), 'edges', len(portals), 'portals', '{}/{}'.format(i, total))
                ns_ids_to_edges[namespace_id] = edges
                ns_ids_to_portals[namespace_id] = portals
                for p in portals:
                    # portals_to_ns_id maps each portal to a single namespace,
                    # so a shared portal would make the map ambiguous.
                    if p in portals_to_ns_id:
                        raise RuntimeError('portal {} is already used by {}'.format(p, portals_to_ns_id[p]))
                    portals_to_ns_id[p] = namespace_id
                all_edges.update(edges)

        with self.stat_vcr.use_cassette('gencfg_balancers.yml'):
            edges, _ = self.process_gencfg_balancers()
            all_edges.update(edges)
            gencfg_world_edges.update(edges)

        # pickle produces bytes, so the file must be opened in binary mode.
        with open(pickle_path, 'wb') as f:
            pickle.dump({
                'portals_to_ns_id': portals_to_ns_id,
                'ns_ids_to_portals': ns_ids_to_portals,
                'ns_ids_to_edges': ns_ids_to_edges,
                'all_edges': all_edges,
                'gencfg_world_edges': gencfg_world_edges,
            }, f)

    def get_config_lua(self, service_id):
        """Download the ``config.lua`` url file of a nanny service.

        :returns: the file contents as text, or None if the service's runtime
            attrs list no url file with local path ``config.lua``.
        :raises requests.HTTPError: if the download returns an error status.
        """
        data = self.nanny_client.get_service_runtime_attrs(service_id)
        for url_file in data['content']['resources']['url_files']:
            if url_file['local_path'] == 'config.lua':
                resp = requests.get(url_file['url'])
                resp.raise_for_status()
                # .text, not .content: the regex helpers use str patterns,
                # which cannot be applied to bytes on Python 3.
                return resp.text
        return None

    def find_target_endpoint_set_ids(self, content):
        """Return the set of endpoint set ids referenced in a balancer config.

        :param content: Lua
        """
        return set(self._ENDPOINT_SET_ID_RE.findall(content))

    def find_target_ips(self, content):
        """Return the set of backend IPs referenced in a balancer config.

        Understands both awacs-style backend tuples and gencfg-style ``cached_ip``.

        :param content: Lua
        """
        rv = set(self._AWACS_BACKEND_IP_RE.findall(content))  # awacs balancers
        rv |= set(self._CACHED_IP_RE.findall(content))  # gencfg balancers
        return rv

    def process_gencfg_balancers(self):
        """Collect edges from the gencfg world to the endpoint sets and IPs its L7 balancers target.

        Downloads the sandbox archive of generated gencfg l7-balancer configs and
        scans every file in ``generated/l7-balancer/``.

        :returns: ``(edges, portals)``; ``portals`` is always empty.
        :raises requests.HTTPError: if the archive download fails.
        """
        gencfg_node = Node(Node.GENCFG_WORLD)
        edges = set()
        portals = set()

        # https://sandbox.yandex-team.ru/resource/2261809857/view
        url = 'https://proxy.sandbox.yandex-team.ru/resource/link/c77ccee2fadd4d77b45eca9c5f510766'
        dirpath = tempfile.mkdtemp()
        try:
            resp = requests.get(url, allow_redirects=True)
            resp.raise_for_status()
            # Keep the archive in our private temp dir (removed below) rather
            # than at a fixed, shared /tmp path.
            archive_path = os.path.join(dirpath, 'balancer_configs_l7.tar.gz')
            with open(archive_path, 'wb') as f:
                f.write(resp.content)
            # NOTE(review): extractall trusts member paths; this is acceptable
            # only because the archive comes from a known sandbox resource.
            with tarfile.open(archive_path, "r") as tar:
                tar.extractall(path=dirpath)
            for path in glob.glob(os.path.join(dirpath, 'generated', 'l7-balancer', '*')):
                with open(path) as f:
                    content = f.read()
                for es_id in self.find_target_endpoint_set_ids(content):
                    edges.add(Edge(gencfg_node, Node(Node.ENDPOINT_SET, es_id=es_id)))
                for ip in self.find_target_ips(content):
                    edges.add(Edge(gencfg_node, Node(Node.IP, ip=ip)))
        finally:
            shutil.rmtree(dirpath)

        return edges, portals

    def process_ns(self, ns_id):
        """Collect the edges and portals of one awacs namespace.

        Edges: ``EXTERNAL_USERS -> namespace`` when one of its complete L3
        balancers has an address open to "any" in puncher, plus
        ``namespace -> endpoint set / IP`` for every backend referenced in its
        complete L7 balancers' ``config.lua``.

        Portals (nodes through which other balancers reach this namespace):
        its L3 addresses, its balancers' YP endpoint sets and pod addresses,
        or, for non-YP balancers, the IPv6 addresses of their instances.

        :returns: ``(edges, portals)`` as sets of Edge and Node.
        :raises RuntimeError: if a balancer's nanny service has no config.lua.
        """
        edges = set()
        portals = set()

        ns_node = Node(Node.NAMESPACE, ns_id=ns_id)
        for l3_balancer_pb in self.awacs_client.list_l3_balancers(ns_id):
            spec_pb = l3_balancer_pb.spec
            if spec_pb.incomplete:
                continue

            # Partially external balancers (score 0.5) count as external too.
            is_l3_external, l3_addrs = self.analyze_l3_balancer(spec_pb.l3mgr_service_id)
            if is_l3_external:
                edges.add(Edge(Node(Node.EXTERNAL_USERS), ns_node))

            for l3_addr in l3_addrs:
                portals.add(Node(Node.IP, ip=l3_addr))

        for balancer_pb in self.awacs_client.list_balancers(ns_id):
            spec_pb = balancer_pb.spec
            if spec_pb.incomplete:
                continue
            nanny_service_id = spec_pb.config_transport.nanny_static_file.service_id
            content = self.get_config_lua(nanny_service_id)
            if content is None:
                raise RuntimeError('config.lua not found in runtime attrs of {}'.format(nanny_service_id))
            for es_id in self.find_target_endpoint_set_ids(content):
                edges.add(Edge(ns_node, Node(Node.ENDPOINT_SET, es_id=es_id)))
            for ip in self.find_target_ips(content):
                edges.add(Edge(ns_node, Node(Node.IP, ip=ip)))

            data = self.nanny_client.get_service_info_attrs(nanny_service_id)
            yp_cluster = data['content'].get('yp_cluster')
            if yp_cluster:
                try:
                    for es_id in self.nanny_client.list_all_endpoint_set_ids(yp_cluster, nanny_service_id):
                        portals.add(Node(Node.ENDPOINT_SET, es_id=es_id))
                except Exception:
                    print('list_all_endpoint_set_ids({}, {}) failed'.format(yp_cluster, nanny_service_id))
                    raise
                field_mask_pb = field_mask_pb2.FieldMask()
                field_mask_pb.paths.append('status')
                for pod_pb in self.nanny_client.iter_pods(yp_cluster, nanny_service_id, field_mask_pb=field_mask_pb):
                    # Scrape every address out of the pod's text representation.
                    for ip in self._POD_ADDRESS_RE.findall(six.text_type(pod_pb)):
                        portals.add(Node(Node.IP, ip=ip))
            else:
                # Non-YP services: resolve each current instance's container
                # hostname to an IPv6 address; unresolvable hosts are skipped.
                for instance in self.nanny_client.list_current_instances(nanny_service_id):
                    fqdn = instance['container_hostname']
                    try:
                        ip = _resolve_host(fqdn, ip_version=socket.AF_INET6)
                    except Exception as e:
                        print('failed to resolve', fqdn, e)
                    else:
                        portals.add(Node(Node.IP, ip=ip))

        return edges, portals

    def is_there_a_rule_from_any_to(self, ip):
        """Tell whether puncher reports at least one rule from "any" to ``ip``.

        The query asks puncher to exclude rejected rules.

        :raises requests.HTTPError: if puncher returns an error status.
        """
        # how to get puncher token: https://wiki.yandex-team.ru/noc/nocdev/puncher/api/
        url = ('https://puncher.yandex-team.ru/api/dynfw/rules?rules=exclude_rejected&'
               'sort=source&values=all&systems=&source=any&destination={}'.format(ip))
        resp = requests.get(url, headers={'Authorization': 'OAuth {}'.format(self.puncher_token)})
        resp.raise_for_status()
        # Several matching rules still mean "there is a rule".
        return resp.json()['count'] > 0

    def _list_virtual_services(self, l3mgr_service_id):
        """Return the virtual services (``objects`` of the l3mgr response) of an l3mgr service.

        The l3mgr call is retried; the last failure is re-raised.
        """
        url = '/api/v1/service/{svc_id}/vs'.format(svc_id=l3mgr_service_id)
        resp = self._with_retries('failed to call {}'.format(url), self.l3mgr_client.get, url, request_timeout=30)
        return resp['objects']

    def analyze_l3_balancer(self, l3mgr_service_id):
        """Estimate how exposed an L3 balancer is to external users.

        Every distinct virtual-service IP is checked with puncher.

        :returns: ``(score, addrs)``. ``addrs`` is the set of distinct VS IPs.
            ``score`` is 1 if the balancer has an IPv4 address and every
            address has a rule from "any", 0.5 if at least one address has
            one, and 0 otherwise.
        """
        has_ipv4_addr = False
        one_addr_is_external = False
        all_addrs_are_external = True
        vss = self._list_virtual_services(l3mgr_service_id)
        seen_addrs = set()
        for vs in vss:
            ip = vs['ip']
            # Several virtual services (e.g. different ports) may share one IP.
            if ip in seen_addrs:
                continue
            if ':' not in ip:
                has_ipv4_addr = True
            is_addr_external = self.is_there_a_rule_from_any_to(ip)
            one_addr_is_external |= is_addr_external
            all_addrs_are_external &= is_addr_external
            seen_addrs.add(ip)
        if has_ipv4_addr and all_addrs_are_external:
            return 1, seen_addrs
        elif one_addr_is_external:
            return 0.5, seen_addrs
        else:
            return 0, seen_addrs
