# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import socket

from retry import retry

from infra.yp_service_discovery.python.resolver.resolver import Resolver
from infra.yp_service_discovery.api import api_pb2


log = logging.getLogger(__name__)


class YpEndpoints(object):
    """
    Wrapper around the Endpoints resolver of YP Service Discovery.
    - https://a.yandex-team.ru/arc/trunk/arcadia/infra/yp_service_discovery/python/resolver
    - https://wiki.yandex-team.ru/yp/discovery/

    Example - get the hosts of all backends of the rzd_proxy service (from Ya.Deploy):
    yp_endpoints = YpEndpoints(client_name='myservice', endpoint_set_id='rzd_proxy.DeployUnit1')
    hosts = yp_endpoints.get_hosts()
    """

    # Cluster names queried when the caller does not pass its own ``datacenters``.
    ALL_DATACENTERS = ('sas', 'vla', 'man', 'myt', 'iva')

    def __init__(self, client_name, endpoint_set_id, datacenters=ALL_DATACENTERS, timeout=5):
        """
        :param client_name: caller identifier; stored as "<client_name>:<local hostname>"
        :param endpoint_set_id: id of the endpoint set to resolve
        :param datacenters: iterable of cluster names to query, defaults to ALL_DATACENTERS
        :param timeout: forwarded to Resolver as ``timeout`` (presumably seconds - TODO confirm)
        """
        # The Resolver instance is created on first access of the ``resolver`` property.
        self._resolver = None

        self.timeout = timeout
        self.client_name = '{}:{}'.format(client_name, socket.gethostname())
        self.datacenters = datacenters
        self.endpoint_set_id = endpoint_set_id

    @property
    def resolver(self):
        """Return the Resolver for this client, creating and caching it on first use."""
        if self._resolver is not None:
            return self._resolver

        self._resolver = Resolver(self.client_name, timeout=self.timeout)
        return self._resolver

    def get_hosts(self):
        """Return a flat list with the fqdn of every ready endpoint in all configured datacenters."""
        ready_hosts = []
        for resolved in self.get_endpoints().values():
            ready_hosts.extend(
                candidate.fqdn
                for candidate in resolved.endpoint_set.endpoints
                if candidate.ready
            )

        return ready_hosts

    def get_ip6_addresses_by_dc(self):
        """Return a dict mapping each datacenter to the ip6_address list of its ready endpoints."""
        addresses_by_dc = {}
        for dc, resolved in self.get_endpoints().items():
            ready_endpoints = [candidate for candidate in resolved.endpoint_set.endpoints if candidate.ready]
            addresses_by_dc[dc] = [candidate.ip6_address for candidate in ready_endpoints]

        return addresses_by_dc

    def get_endpoints(self):
        """Resolve every datacenter in ``self.datacenters`` and return the responses keyed by datacenter."""
        return {dc: self.get_endpoints_for_dc(dc) for dc in self.datacenters}

    @retry(tries=3, delay=1)
    def get_endpoints_for_dc(self, dc):
        """
        Resolve the endpoint set in a single datacenter and return the resolver response.

        The call is wrapped in ``retry(tries=3, delay=1)``.

        :param dc: cluster name put into the request as ``cluster_name``
        """
        set_id = self.endpoint_set_id

        req = api_pb2.TReqResolveEndpoints()
        req.endpoint_set_id = set_id
        req.cluster_name = dc

        log.info('Resolving endpoints for %s %s', set_id, dc)
        response = self.resolver.resolve_endpoints(req)
        found = len(response.endpoint_set.endpoints)
        log.info('Found %s endpoint for %s %s', found, set_id, dc)

        return response
