import sys

# Import requests, first checking via pkg_resources that the distribution is
# installed. Any failure here is swallowed, leaving ``requests`` undefined (the
# resolvers below would then fail with NameError on their first HTTP call),
# unless ``sys.is_standalone_binary`` is set, in which case it is re-raised.
try:
    __import__('pkg_resources').require('requests')
    import requests

    # Silence urllib3 warnings emitted through requests (presumably
    # InsecureRequestWarning and similar -- confirm).
    requests.packages.urllib3.disable_warnings()
except BaseException:
    # NOTE(review): BaseException also swallows KeyboardInterrupt/SystemExit
    # raised during the import.
    if getattr(sys, 'is_standalone_binary', False):
        raise

import json
from collections import defaultdict


class HQResolverError(Exception):
    """Error raised by the resolvers in this module.

    The message has the form ``HQ '<source>' <error>``; both parts are also
    kept as the ``source`` and ``error`` attributes.
    """

    def __init__(self, source, error):
        super(HQResolverError, self).__init__("HQ '{}' {}".format(source, error))
        # Where the error came from: an API name, a URL or a service description.
        self.source = source
        # Error detail: a message string or, for HQ RPC failures, an HTTP status code.
        self.error = error

    def __reduce__(self):
        # BaseException pickles as cls(*self.args), but args only holds the
        # formatted message, which does not fit this two-argument __init__.
        return type(self), (self.source, self.error), self.__dict__


class HQServiceNotFoundError(HQResolverError):
    """Raised when a service is unknown (HTTP 404) or resolves to no instances/pods."""


class ServiceResolvePolicy(object):
    """Per-service resolve settings fetched from Nanny.

    Holds the fields returned by Nanny's ``get_active_revision_id`` helper:
    active revision, instances source, YP clusters, pod set id and HQ clusters.
    """

    NANNY_ENV_PROD = "prod"
    NANNY_ENV_DEV = "dev"
    NANNY_ENV_ADMIN = "admin"

    # Service tag -> Nanny environment it selects.
    NANNY_ENV_TAG_TO_KEY = {
        'nanny_env_prod': NANNY_ENV_PROD,
        'nanny_env_admin': NANNY_ENV_ADMIN,
        'nanny_env_dev': NANNY_ENV_DEV
    }

    # Nanny environment -> base URL of that Nanny installation.
    NANNY_ENV_TO_NANNY_URL = {
        NANNY_ENV_PROD: 'https://nanny.yandex-team.ru',
        NANNY_ENV_DEV: 'https://dev-nanny.yandex-team.ru',
        NANNY_ENV_ADMIN: 'https://admin-nanny.yandex-team.ru',
    }

    # ``current_instances_source`` value meaning "resolve via YP service discovery".
    SD_INSTANCE_SOURCE = "SD"

    @classmethod
    def get(cls, service_id, timeouts=(5, 10, 15), nanny_env=NANNY_ENV_PROD):
        """Fetch the resolve policy of *service_id* from Nanny.

        :param service_id: Nanny service id.
        :param timeouts: per-attempt request timeouts in seconds; after a
            timeout or connection error the request is retried with the next one.
        :param nanny_env: one of the ``NANNY_ENV_*`` constants.
        :raises HQServiceNotFoundError: Nanny answered 404.
        :raises HQResolverError: unknown *nanny_env*, every attempt failed, or
            Nanny answered with another non-200 status.
        """
        # The URL does not change between attempts: build it once, and report an
        # unknown environment as a resolver error rather than a bare KeyError.
        try:
            nanny_url = cls.NANNY_ENV_TO_NANNY_URL[nanny_env]
        except KeyError:
            raise HQResolverError('Nanny API', 'unknown nanny env: {}'.format(nanny_env))
        url = nanny_url + '/v2/services/_helpers/get_active_revision_id/{}/'.format(service_id)

        # Non-None until one attempt gets a response (also covers empty timeouts).
        exc = 'no request attempts (empty timeouts)'
        for timeout in timeouts:
            try:
                r = requests.get(url, timeout=timeout)
                exc = None
                break
            except requests.exceptions.Timeout:
                exc = 'timeout'
            except requests.exceptions.ConnectionError as e:
                exc = str(e)

        if exc is not None:
            raise HQResolverError('Nanny API', exc)

        if r.status_code != requests.codes.ok:
            if r.status_code == requests.codes.not_found:
                raise HQServiceNotFoundError('ServiceResolvePolicy get: ' + service_id, 'Not found')
            try:
                msg = r.json()['msg']
            except Exception:
                # Body is not JSON or has no 'msg'.
                msg = 'Unknown error'
            raise HQResolverError('Nanny API', msg)

        resp = r.json()
        return cls(
            service_id=service_id,
            active_revision_id=resp['active_revision_id'],
            instances_source=resp['current_instances_source'],
            yp_clusters=resp['yp_clusters'],
            pod_set_id=resp['pod_set_id'],
            hq_clusters=resp['hq_clusters']
        )

    @classmethod
    def get_nanny_env(cls, tags):
        """Return the Nanny environment selected by a ``nanny_env_*`` tag.

        The first such tag (in *tags* iteration order) is removed from *tags*
        in place, presumably so later tag-based filtering does not see it. It
        then selects the environment. Without such a tag (or without tags)
        prod is returned.

        :raises HQResolverError: a ``nanny_env_*`` tag is not a known one.
        """
        if not tags:
            return cls.NANNY_ENV_PROD

        for t in tags:
            if t.startswith('nanny_env_'):
                if t not in cls.NANNY_ENV_TAG_TO_KEY:
                    raise HQResolverError('Nanny API', 'wrong tag: {}, valid: {}'.format(
                        t, ", ".join(cls.NANNY_ENV_TAG_TO_KEY.keys())))
                # Safe despite mutating while iterating: we return right away.
                tags.remove(t)
                return cls.NANNY_ENV_TAG_TO_KEY[t]

        return cls.NANNY_ENV_PROD

    def __init__(self, service_id, active_revision_id, instances_source, yp_clusters, pod_set_id, hq_clusters):
        self.service_id = service_id
        self.active_revision_id = active_revision_id
        # Nanny's ``current_instances_source``; compared against SD_INSTANCE_SOURCE.
        self.instances_source = instances_source
        self.yp_clusters = yp_clusters
        self.pod_set_id = pod_set_id
        self.hq_clusters = hq_clusters

    def is_use_sd(self):
        """Return True when instances must be resolved via YP service discovery."""
        return self.instances_source == self.SD_INSTANCE_SOURCE


class SDPod(object):
    """One pod from a YP service discovery ``resolve_pods`` response."""

    @classmethod
    def from_response(cls, yp_cluster, pod_data):
        """Build an SDPod from one entry of ``pod_set.pods`` in an SD response.

        Only ``id`` is required; missing ``node_id`` / ``dns.persistent_fqdn``
        become None and missing revision summaries become empty dicts.
        """
        pod_id = pod_data['id']
        node_id = pod_data.get('node_id')
        fqdn = pod_data.get('dns', {}).get('persistent_fqdn')
        wanted = pod_data.get('iss_conf_summaries', {})
        running = pod_data.get('agent', {}).get('iss_summary', {}).get('state_summaries', {})
        return cls(
            current_revisions=running,
            target_revisions=wanted,
            persistent_fqdn=fqdn,
            node_id=node_id,
            pod_id=pod_id,
            yp_cluster=yp_cluster,
        )

    def __init__(self, yp_cluster, pod_id, node_id, persistent_fqdn, target_revisions, current_revisions):
        # Where the pod lives and how it is addressed.
        self.yp_cluster, self.pod_id, self.node_id = yp_cluster, pod_id, node_id
        self.persistent_fqdn = persistent_fqdn
        # Revision id -> summary: what the pod should run vs. what its agent reports.
        self.target_revisions, self.current_revisions = target_revisions, current_revisions


class SDResolver(object):
    """Resolve hosts, slots and instances of a Nanny service via YP service discovery.

    For every YP cluster of the service, the pods of its pod set are fetched
    from the SD ``resolve_pods`` JSON endpoint (``API_SD_URL``).
    """

    API_SD_URL = "http://sd.yandex.net:8080/resolve_pods/json"
    # Pseudo revision id standing for the service's active revision.
    ACTIVE_CONFIG_ALIAS = 'ACTIVE'

    # Location itags -> YP clusters to query instead of all of the service's clusters.
    ITAG2YPCLUSTERS = {
        'a_geo_sas': ['sas'],
        'a_geo_man': ['man'],
        'a_geo_vla': ['vla'],
        'a_geo_msk': ['iva', 'myt'],
        'a_dc_sas': ['sas'],
        'a_dc_man': ['man'],
        'a_dc_vla': ['vla'],
        'a_dc_iva': ['iva'],
        'a_dc_myt': ['myt']
    }

    def __init__(self):
        super(SDResolver, self).__init__()
        # Per-attempt SD request timeouts in seconds; after a timeout or a
        # connection error the request is retried with the next value.
        self.timeouts = [5, 10, 15]

    def get_service_resolve_policy(self, service_id_or_policy, tags=None):
        """Return the resolve policy for a service.

        A ServiceResolvePolicy is returned unchanged. Anything else is treated
        as a Nanny service id, and its policy is fetched from the Nanny
        environment selected by *tags*. ``ServiceResolvePolicy.get_nanny_env``
        removes the ``nanny_env_*`` tag from *tags* in place.
        """
        if isinstance(service_id_or_policy, ServiceResolvePolicy):
            return service_id_or_policy
        nanny_env = ServiceResolvePolicy.get_nanny_env(tags)
        return ServiceResolvePolicy.get(service_id_or_policy, nanny_env=nanny_env)

    def get_hosts(self, service_id_or_policy, revision_id=None, tags=None):
        """Return the node ids of the service's pods (a list, one entry per pod)."""
        _, pods = self._policy_and_pods(service_id_or_policy, revision_id, tags)
        return [p.node_id for p in pods]

    def get_slots(self, service_id_or_policy, revision_id=None, hosts=None, tags=None):
        """Return ``(node, 'pod_id@node', configuration)`` tuples.

        *configuration* is ``'service_id#revision_id'`` when *revision_id* is
        given, else ``''``. A non-empty *hosts* keeps only pods on those nodes.
        """
        policy, pods = self._policy_and_pods(service_id_or_policy, revision_id, tags)
        configuration = '{}#{}'.format(policy.service_id, revision_id) if revision_id else ''
        slots = []
        for pod in pods:
            host, port = pod.node_id, pod.pod_id
            if hosts and host not in hosts:
                continue
            if host and port:
                slots.append((host, '{}@{}'.format(port, host), configuration))
        return slots

    def get_instances(self, service_id_or_policy, revision_id=None, hosts=None, tags=None):
        """Return ``{node: {(None, 'node:pod_id@revision'), ...}}``.

        * ``revision_id == 'ACTIVE'`` and the policy knows the active revision:
          only pods whose current revisions include it.
        * any other truthy *revision_id* (including 'ACTIVE' without a known
          active revision): each returned pod, labelled with *revision_id*.
        * no *revision_id*: one entry per current revision of each pod.

        A non-empty *hosts* keeps only pods on those nodes.
        """
        policy, pods = self._policy_and_pods(service_id_or_policy, revision_id, tags)
        result = defaultdict(set)
        for pod in pods:
            host, port = pod.node_id, pod.pod_id
            if hosts and host not in hosts:
                continue
            if revision_id == self.ACTIVE_CONFIG_ALIAS and policy.active_revision_id:
                if policy.active_revision_id in pod.current_revisions:
                    result[host].add((None, '{}:{}@{}'.format(host, port, policy.active_revision_id)))
            elif revision_id:
                # instances for the specified revision only
                result[host].add((None, '{}:{}@{}'.format(host, port, revision_id)))
            else:
                # instances for all known revisions
                for rev_id in pod.current_revisions:
                    result[host].add((None, '{}:{}@{}'.format(host, port, rev_id)))
        return result

    def get_configurations(self, service_id_or_policy):
        """Return ``{revision_id: {'Active': n, 'Installed': n, 'Total': n}}`` over all pods."""
        policy = self.get_service_resolve_policy(service_id_or_policy)
        pods = self._get_pods(policy.service_id, None, policy.pod_set_id, policy.yp_clusters)
        return self._count_configurations(pods)

    def get_mtns(self, service_id_or_policy, revision_id=None, tags=None):
        """Return the persistent FQDNs of the service's pods (a list, one entry per pod)."""
        _, pods = self._policy_and_pods(service_id_or_policy, revision_id, tags)
        return [p.persistent_fqdn for p in pods]

    def _policy_and_pods(self, service_id_or_policy, revision_id, tags):
        """Resolve the policy, narrow its YP clusters by *tags* and fetch the matching pods."""
        policy = self.get_service_resolve_policy(service_id_or_policy, tags)
        yp_clusters, _ = self._filter_yp_clusters_and_tags(policy.yp_clusters, tags)
        pods = self._get_pods(policy.service_id, revision_id, policy.pod_set_id, yp_clusters)
        return policy, pods

    @staticmethod
    def _count_configurations(pods):
        """Count pods per current revision id.

        'Active' counts pods where the revision is ready, 'Installed' counts
        pods where it is installed, and 'Total' counts every pod reporting it.
        """
        configs = defaultdict(lambda: defaultdict(int))
        for pod in pods:
            # .items() rather than the Python-2-only .iteritems().
            for rev_id, rev in pod.current_revisions.items():
                if rev.get('ready', {}).get('status') == 'True':
                    configs[rev_id]['Active'] += 1
                if rev.get('installed', {}).get('status') == 'True':
                    configs[rev_id]['Installed'] += 1
                configs[rev_id]['Total'] += 1
        return configs

    def _filter_yp_clusters_and_tags(self, yp_clusters, tags):
        """Return ``(clusters, None)`` to query for the given tags.

        Without tags, *yp_clusters* is returned as is. Otherwise the first tag
        found in ITAG2YPCLUSTERS selects the clusters. A fresh list is returned,
        so callers cannot mutate the class-level mapping. The second element is
        always None.

        :raises HQResolverError: tags were given but none is a known location tag.
        """
        if not tags:
            return yp_clusters, None

        for tag in tags:
            tag_clusters = self.ITAG2YPCLUSTERS.get(tag)
            if tag_clusters:
                return list(tag_clusters), None

        raise HQResolverError('SDResolver', 'does not support filter by tags: {}'.format(tags))

    def _get_pods(self, service_id, revision_id, pod_set_id, yp_clusters):
        """Fetch the pods of *pod_set_id* from every cluster in *yp_clusters*.

        Pods without a node id are skipped. A *revision_id* other than 'ACTIVE'
        keeps only pods whose current revisions contain it; 'ACTIVE' means no
        revision filter here.

        :raises HQServiceNotFoundError: no pod matched.
        """
        if revision_id == self.ACTIVE_CONFIG_ALIAS:
            revision_id = None
        pods = []
        for yp_cluster in yp_clusters:
            result = self._do_yp_sd_request(service_id, json_data={
                'cluster_name': yp_cluster,
                'pod_set_id': pod_set_id,
                'client_name': 'skynet',
            })
            for pod_data in result.get('pod_set', {}).get('pods', []):
                sd_pod = SDPod.from_response(yp_cluster, pod_data)
                if not sd_pod.node_id:
                    continue
                if revision_id and revision_id not in sd_pod.current_revisions:
                    continue
                pods.append(sd_pod)

        if not pods:
            raise HQServiceNotFoundError('Service (YP SD): ' + service_id, 'Not found')
        return pods

    def _do_yp_sd_request(self, service_id, json_data):
        """POST *json_data* to the SD endpoint and return the decoded JSON reply.

        Each entry of ``self.timeouts`` is one attempt; the next attempt is made
        after a timeout or connection error.

        :raises HQServiceNotFoundError: SD answered 404.
        :raises HQResolverError: every attempt failed, or SD answered with
            another non-200 status.
        """
        payload = json.dumps(json_data)
        # Non-None until one attempt gets a response (also covers empty timeouts).
        exc = 'no request attempts (empty timeouts)'
        for timeout in self.timeouts:
            try:
                r = requests.post(
                    self.API_SD_URL,
                    data=payload,
                    timeout=timeout,
                    headers={'Content-Type': 'application/json', 'Accept-Encoding': 'gzip'},
                )
                exc = None
                break
            except requests.exceptions.Timeout:
                exc = 'timeout'
            except requests.exceptions.ConnectionError as e:
                exc = str(e)

        if exc is not None:
            raise HQResolverError('YP SD', exc)

        if r.status_code != requests.codes.ok:
            if r.status_code == requests.codes.not_found:
                raise HQServiceNotFoundError('Service _do_request: ' + service_id, 'Not found')
            try:
                msg = r.json()['msg']
            except Exception:
                # Body is not JSON or has no 'msg'.
                msg = 'Unknown error'
            raise HQResolverError('YP SD', msg)

        return r.json()


class HQResolver(object):
    """Resolve hosts, slots and instances of a Nanny service through HQ.

    HQ clusters are listed with the federated ``FindClusters`` RPC and each one
    is queried with ``FindInstances``. When the service's Nanny resolve policy
    says its instances come from YP service discovery, the public methods
    delegate to an SDResolver instead.
    """

    API_NANNY_URL = 'https://nanny.yandex-team.ru/v2/services/_helpers/get_active_revision_id/{}/'
    API_HQ_FIND_CLUSTERS = 'http://federated.yandex-team.ru/rpc/federated/FindClusters/'
    # Formatted with a cluster endpoint URL returned by FindClusters.
    API_HQ_FIND_INSTANCES = '{}rpc/instances/FindInstances/'
    # Pseudo revision id selecting ready ("active") revisions.
    ACTIVE_CONFIG_ALIAS = 'ACTIVE'

    # Location itags -> the HQ cluster to query for them.
    ITAG2DC = {
        'a_geo_sas': 'sas_prod',
        'a_geo_man': 'man_prod',
        'a_geo_vla': 'vla_prod',
        'a_geo_msk': 'msk_prod',
        'a_dc_sas': 'sas_prod',
        'a_dc_man': 'man_prod',
        'a_dc_vla': 'vla_prod',
        'a_dc_iva': 'msk_prod',
        'a_dc_myt': 'msk_prod'
    }

    # Domain suffixes cut from pod ids of instance ids that have no "host:" part.
    DOMAIN_TO_CUT = ['.iva.yp-c.yandex.net', '.myt.yp-c.yandex.net']

    def __init__(self, use_service_resolve_policy=None):
        """:param use_service_resolve_policy: ``False`` disables the Nanny
        resolve-policy lookup, so HQ is always queried directly."""
        super(HQResolver, self).__init__()
        self._use_service_resolve_policy = use_service_resolve_policy
        self._sd_resolver = SDResolver()
        # Per-attempt request timeouts in seconds for HQ and Nanny calls.
        self.timeouts = [5, 10, 15]

    def get_service_resolve_policy(self, service_id, tags=None):
        """Return the service's ServiceResolvePolicy, or None when the lookup is disabled."""
        # get_nanny_env strips the nanny_env_* tag from *tags* in place. It runs
        # before the opt-out check, so the tag is stripped either way.
        nanny_env = ServiceResolvePolicy.get_nanny_env(tags)
        if self._use_service_resolve_policy is False:
            return None
        return ServiceResolvePolicy.get(service_id, nanny_env=nanny_env)

    def get_hosts(self, service_id, revision_id=None, tags=None):
        """Return the node names of the service's instances (a list, one per instance)."""
        resolve_policy = self.get_service_resolve_policy(service_id, tags=tags)
        if resolve_policy and resolve_policy.is_use_sd():
            return self._sd_resolver.get_hosts(resolve_policy, revision_id, tags)
        instances = self._do_request(service_id, revision_id, tags)
        return [i['spec']['nodeName'] for i in instances]

    def get_slots(self, service_id, revision_id=None, hosts=None, tags=None):
        """Return ``(host, 'port@host', configuration)`` tuples.

        *configuration* is ``'service_id#revision_id'`` when *revision_id* is
        given, else ``''``. A non-empty *hosts* keeps only instances on those nodes.
        """
        resolve_policy = self.get_service_resolve_policy(service_id, tags=tags)
        if resolve_policy and resolve_policy.is_use_sd():
            # Pass by keyword: SDResolver.get_slots takes ``hosts`` before ``tags``.
            return self._sd_resolver.get_slots(resolve_policy, revision_id, hosts=hosts, tags=tags)
        instances = self._do_request(service_id, revision_id, tags)
        configuration = '{}#{}'.format(service_id, revision_id) if revision_id else ''
        slots = []

        for i in instances:
            hostname = i['spec']['nodeName']
            if hosts and hostname not in hosts:
                continue
            host, port, _ = self._split_instance_id(i['meta']['id'])
            host = host or hostname
            if host and port:
                slots.append((hostname, '{}@{}'.format(port, host), configuration))

        return slots

    def get_instances(self, service_id, revision_id=None, hosts=None, tags=None):
        """Return ``{node: {(None, 'host:port@revision'), ...}}``.

        * ``revision_id == 'ACTIVE'``: one entry per revision whose status is ready.
        * other truthy *revision_id*: one entry per instance, for that revision.
        * no *revision_id*: one entry per revision in each instance's spec.

        A non-empty *hosts* keeps only instances on those nodes.
        """
        resolve_policy = self.get_service_resolve_policy(service_id, tags=tags)
        if resolve_policy and resolve_policy.is_use_sd():
            return self._sd_resolver.get_instances(resolve_policy, revision_id, hosts, tags)

        instances = self._do_request(service_id, revision_id, tags)
        result = defaultdict(set)

        for i in instances:
            hostname = i['spec']['nodeName']
            if hosts and hostname not in hosts:
                continue
            host, port, _ = self._split_instance_id(i['meta']['id'])
            host = host or hostname
            if revision_id == self.ACTIVE_CONFIG_ALIAS:
                for rev in i['status']['revision']:
                    if rev['ready']['status'] == 'True':
                        result[hostname].add((None, '{}:{}@{}'.format(host, port, rev['id'])))
            elif revision_id:
                # instances for the specified revision only
                result[hostname].add((None, '{}:{}@{}'.format(host, port, revision_id)))
            else:
                # instances for all known revisions
                for rev in i['spec']['revision']:
                    result[hostname].add((None, '{}:{}@{}'.format(host, port, rev['id'])))

        return result

    def get_configurations(self, service_id):
        """Return ``{revision_id: {'Active': n, 'Installed': n, 'Total': n}}`` over all instances."""
        resolve_policy = self.get_service_resolve_policy(service_id)
        if resolve_policy and resolve_policy.is_use_sd():
            return self._sd_resolver.get_configurations(resolve_policy)
        instances = self._do_request(service_id, None, None)
        configs = defaultdict(lambda: defaultdict(int))

        for i in instances:
            for rev in i['status']['revision']:
                if rev['ready']['status'] == 'True':
                    configs[rev['id']]['Active'] += 1
                if rev['installed']['status'] == 'True':
                    configs[rev['id']]['Installed'] += 1
                configs[rev['id']]['Total'] += 1

        return configs

    def get_mtns(self, service_id, revision_id=None, tags=None):
        """Return the MTN FQDNs of the service's instances (a list, one per instance)."""
        resolve_policy = self.get_service_resolve_policy(service_id, tags=tags)
        if resolve_policy and resolve_policy.is_use_sd():
            return self._sd_resolver.get_mtns(resolve_policy, revision_id, tags)
        instances = self._do_request(service_id, revision_id, tags)
        return [i['spec']['hostname'] for i in instances]

    def _do_request(self, service_id, revision_id, tags):
        """Collect the service's instances from the relevant HQ clusters.

        'ACTIVE' becomes a ``ready_only`` query with no revision filter. The
        clusters come from a location tag in *tags* or, when there is neither
        a revision nor a location tag, from the service's HQ clusters in Nanny.
        If both are unavailable, every cluster is queried.

        :raises HQServiceNotFoundError: no instance was found.
        """
        ready_only = False
        if revision_id == self.ACTIVE_CONFIG_ALIAS:
            ready_only = True
            revision_id = None

        clusters = self._find_clusters()
        service_dcs = self._get_service_dc_by_tag(tags)

        # Cluster filtering via Nanny is supported for the current config only.
        if not revision_id and not service_dcs:
            try:
                service_dcs = self._get_service_dcs(service_id)
            except HQResolverError:
                # Best effort: on a Nanny failure, query every HQ cluster.
                # NOTE(review): HQServiceNotFoundError subclasses HQResolverError,
                # so a Nanny 404 is swallowed here as well -- confirm intended.
                pass

        req_filter = {
            'filter': {
                'service_id': service_id,
                'ready_only': ready_only
            }
        }
        instances = []
        for c_name, c_url in clusters.items():
            if service_dcs is not None and c_name not in service_dcs:
                # the service has no instances in this cluster, skip it
                continue
            instances.extend(self._find_instances(c_url, req_filter, revision_id, tags))

        if not instances:
            raise HQServiceNotFoundError('Service: ' + service_id, 'Not found')

        return instances

    def _find_clusters(self):
        """Return ``{cluster_name: endpoint_url}`` for all HQ clusters."""
        response = self._do_post_request(self.API_HQ_FIND_CLUSTERS, '{}')
        return {i['meta']['name']: i['spec']['endpoint']['url'] for i in response['value']}

    def _find_instances(self, base_url, req_filter, revision_id, tags):
        """Query one HQ cluster's FindInstances RPC and filter the result by revision/tags."""
        response = self._do_post_request(self.API_HQ_FIND_INSTANCES.format(base_url), json.dumps(req_filter))
        # Assumption: HQ may omit 'instance' when nothing matched; treat it as no instances.
        return self._filter_instances(response.get('instance') or [], revision_id, tags)

    @staticmethod
    def _filter_instances(instances, revision_id, tags):
        """Keep instances that have a spec revision matching the criteria.

        A revision matches when it has id *revision_id* (if given) and carries
        every tag in *tags* (if given). Each instance is kept at most once, even
        when several of its revisions match. *tags* may be any iterable.
        """
        if not revision_id and not tags:
            return instances

        required_tags = set(tags) if tags else set()
        selected = []
        for inst in instances:
            for rev in inst['spec']['revision']:
                if revision_id and rev['id'] != revision_id:
                    continue
                if required_tags.issubset(rev['tags']):
                    selected.append(inst)
                    break
        return selected

    def _do_post_request(self, url, data):
        """POST JSON *data* to *url* and return the decoded JSON reply.

        Each entry of ``self.timeouts`` is one attempt; the next attempt is made
        after a timeout or connection error.

        :raises HQResolverError: every attempt failed, or the status was not 200.
        """
        # Non-None until one attempt gets a response (also covers empty timeouts).
        exc = 'no request attempts (empty timeouts)'
        for timeout in self.timeouts:
            try:
                r = requests.post(
                    url,
                    data=data,
                    timeout=timeout,
                    headers={'Content-Type': 'application/json', 'Accept-Encoding': 'gzip'}
                )
                exc = None
                break
            except requests.exceptions.Timeout:
                exc = 'timeout'
            except requests.exceptions.ConnectionError as e:
                exc = str(e)

        if exc is not None:
            raise HQResolverError(url, exc)

        if r.status_code != requests.codes.ok:
            raise HQResolverError(url, r.status_code)

        return r.json()

    def _get_service_dc_by_tag(self, tags):
        """Return ``[hq_cluster]`` for the first location tag in *tags*, else None."""
        for tag in tags or ():
            cluster = self.ITAG2DC.get(tag)
            if cluster:
                return [cluster]
        return None

    def _get_service_dcs(self, service_id):
        """Return the service's ``hq_clusters`` as reported by (prod) Nanny.

        :raises HQServiceNotFoundError: Nanny answered 404.
        :raises HQResolverError: every attempt failed, or Nanny answered with
            another non-200 status.
        """
        # Non-None until one attempt gets a response (also covers empty timeouts).
        exc = 'no request attempts (empty timeouts)'
        for timeout in self.timeouts:
            try:
                r = requests.get(self.API_NANNY_URL.format(service_id), timeout=timeout)
                exc = None
                break
            except requests.exceptions.Timeout:
                exc = 'timeout'
            except requests.exceptions.ConnectionError as e:
                exc = str(e)

        if exc is not None:
            raise HQResolverError('Nanny API', exc)

        if r.status_code != requests.codes.ok:
            if r.status_code == requests.codes.not_found:
                raise HQServiceNotFoundError('Service: ' + service_id, 'Not found')
            try:
                msg = r.json()['msg']
            except Exception:
                # Body is not JSON or has no 'msg'.
                msg = 'Unknown error'
            raise HQResolverError('Nanny API', msg)

        return r.json()['hq_clusters']

    def _split_instance_id(self, instance_id):
        """Split an HQ instance id into ``(host, port, service_id)``.

        ``'host:port@service'`` yields all three parts. An id without ``':'`` is
        read as ``'pod@...'``: the pod id, minus any DOMAIN_TO_CUT suffix, is
        returned as *port*, and host/service_id are None. Parts that cannot be
        parsed stay None.
        """
        host, port, service_id = None, None, None
        try:
            if ':' in instance_id:
                host, port_service_id = instance_id.split(':')
                port, service_id = port_service_id.split('@')
            else:
                port, _ = instance_id.split('@')
                port = self._cut_domain(port)
        except (AttributeError, TypeError, ValueError):
            # malformed id: callers skip instances without a port
            pass

        return host, port, service_id

    def _cut_domain(self, port):
        """Strip the first matching DOMAIN_TO_CUT suffix from *port*."""
        for domain in self.DOMAIN_TO_CUT:
            if port.endswith(domain):
                return port[:-len(domain)]

        return port
