import six
import logging
import funcsigs

import sys
import platform

import nanny_rpc_client
from nanny_rpc_client.exceptions import NotFoundError

from infra.nanny.yp_lite_api.proto import pod_sets_api_pb2
from infra.nanny.yp_lite_api.proto import endpoint_sets_api_pb2

from infra.nanny.yp_lite_api.py_stubs import pod_sets_api_stub
from infra.nanny.yp_lite_api.py_stubs import endpoint_sets_api_stub

from infra.yp_service_discovery.python.resolver.resolver import Resolver
from infra.yp_service_discovery.api import api_pb2

from saas.library.python.singleton import Singleton
from saas.tools.devops.lib23.service_token import ServiceTokenStore


class YpLitePod(object):
    """
    Single pod in YP-lite, addressed by (pod_id, cluster).

    Talks to the YP-lite pod-sets RPC API to read and update pod labels.
    """

    # Proto module and RPC stub shared by all instances; populated lazily by
    # initialize_api() on first instantiation.
    pb2 = api_stub = None

    @classmethod
    def initialize_api(cls):
        """Create the shared pod-sets RPC stub if it does not exist yet."""
        if cls.api_stub is None:
            cls.pb2 = pod_sets_api_pb2
            podset_api = nanny_rpc_client.RetryingRpcClient(
                rpc_url='https://yp-lite-ui.nanny.yandex-team.ru/api/yplite/pod-sets/',
                oauth_token=ServiceTokenStore.get_token('nanny')
            )
            cls.api_stub = pod_sets_api_stub.YpLiteUIPodSetsServiceStub(podset_api)

    def __init__(self, pod_id=None, cluster=None, fqdn=None):
        """
        :param pod_id: pod id; if empty, taken from the first label of ``fqdn``.
        :param cluster: cluster name (upper-cased); if empty, taken from the
            second label of ``fqdn``, upper-cased with '-' replaced by '_'.
        :param fqdn: pod fqdn of the form ``<pod_id>.<cluster>.<...>``, used
            only to fill in whichever of pod_id / cluster is missing.
        :raises ValueError: if pod id or cluster can be neither taken from the
            arguments nor derived from ``fqdn``.
        """
        fqdn_parts = fqdn.split('.') if fqdn else []

        if pod_id:
            self.id = pod_id
        elif fqdn_parts and fqdn_parts[0]:
            self.id = fqdn_parts[0]
        else:
            raise ValueError('Cannot determine pod id: pass pod_id or a valid fqdn (got fqdn={!r})'.format(fqdn))

        if cluster:
            self.cluster = cluster.upper()
        elif len(fqdn_parts) > 1 and fqdn_parts[1]:
            self.cluster = fqdn_parts[1].upper().replace('-', '_')
        else:
            raise ValueError('Cannot determine cluster: pass cluster or a valid fqdn (got fqdn={!r})'.format(fqdn))

        self.initialize_api()

    def get_pod_labels(self):
        """Return the pod's labels as a ``{key: value}`` dict."""
        req = self.pb2.GetPodRequest(pod_id=self.id, cluster=self.cluster)
        pod = self.api_stub.get_pod(req, request_timeout=20).pod
        return {att.key: att.value for att in pod.labels.attributes}

    def update_pod_labels(self, **labels):
        """
        Merge ``labels`` into the pod's current labels and push the result.

        A label passed with value ``None`` is removed. The update request's
        version is taken from the merged ``nanny_version`` label.

        :raises KeyError: if the merged labels contain no ``nanny_version``.
        """
        # TODO: needs transaction (read-modify-write is not atomic)
        current_labels = self.get_pod_labels()
        current_labels.update(labels)
        current_labels = {k: v for k, v in current_labels.items() if v is not None}
        req = self.pb2.UpdatePodRequest(
            pod_id=self.id,
            version=current_labels['nanny_version'],
            cluster=self.cluster,
            labels=current_labels,
        )
        return self.api_stub.update_pod(req, request_timeout=20)


class Endpointset(six.with_metaclass(Singleton)):
    """
    YP-lite endpoint set identified by (name, cluster).

    Wraps the YP-lite endpoint-sets RPC API (get / update / remove) and the
    YP service-discovery resolver. Instances are created through the
    ``Singleton`` metaclass, for which ``_get_instance_id`` and
    ``_extra_actions`` are customisation hooks.
    """

    LOGGER = logging.getLogger(__name__)

    # Raw RPC client and typed stub, shared by all instances; created lazily
    # by initialize_endpointset_api().
    _ENDPOINTSET_API = None
    endpointset_api = None

    # Service-discovery resolver, built once when the class body executes.
    _client_name = "{}@{}".format(sys.argv[0], platform.node())
    _SD_RESOLVER = Resolver(client_name=_client_name, timeout=60)

    @classmethod
    def _get_instance_id(cls, args, kwargs):
        """
        Singleton interface: return the key identifying an instance.

        :returns: ``(name, cluster)`` exactly as passed to the constructor
            (cluster is not case-normalised here).
        """
        signature = funcsigs.signature(cls.__init__)
        bound_params = signature.bind(cls, *args, **kwargs)
        return bound_params.arguments['name'], bound_params.arguments['cluster']

    @classmethod
    def _extra_actions(cls, instance, args, kwargs):
        """
        Singleton customisation: re-sync ``instance`` from the API when the
        constructor call passed a truthy ``do_sync``.
        """
        signature = funcsigs.signature(cls.__init__)
        bound_params = signature.bind(cls, *args, **kwargs)
        # BoundArguments.arguments holds every explicitly passed argument, so
        # test the value rather than mere presence: do_sync=False must not sync.
        if bound_params.arguments.get('do_sync'):
            instance.load_from_api()

    @classmethod
    def initialize_endpointset_api(cls):
        """Create the shared endpoint-sets RPC client and stub on first use."""
        if cls._ENDPOINTSET_API is None:
            cls.LOGGER.debug('Initializing endpointset api')
            cls._ENDPOINTSET_API = nanny_rpc_client.RetryingRpcClient(
                rpc_url='https://yp-lite-ui.nanny.yandex-team.ru/api/yplite/endpoint-sets/',
                oauth_token=ServiceTokenStore.get_token('nanny')
            )
            cls.endpointset_api = endpoint_sets_api_stub.YpLiteUIEndpointSetsServiceStub(cls._ENDPOINTSET_API)

    def __repr__(self):
        return 'Endpointset({}@{})'.format(self.cluster, self.name)

    def __str__(self):
        return '{}@{}'.format(self.cluster.lower().replace('_', '-'), self.name)

    @classmethod
    def get_endpointset(cls, name, cluster):
        """
        Fetch an endpoint set from the API.

        :returns: populated ``Endpointset`` or ``None`` if the API reports
            NotFoundError.
        """
        cls.initialize_endpointset_api()
        req = endpoint_sets_api_pb2.GetEndpointSetRequest(id=name, cluster=cluster)
        try:
            res = cls.endpointset_api.get_endpoint_set(req, request_timeout=20)
            cls.LOGGER.debug('GetEndpointSet req: %s response: %s', req, res)
            return cls.from_proto_msg(cluster, res.endpoint_set)
        except NotFoundError:
            return None

    @classmethod
    def remove_endpointset(cls, endpointset_name, cluster):
        """
        Remove an endpoint set, using its current version for the request.

        :returns: the API response, or ``None`` (with a warning) if the
            endpoint set does not exist.
        """
        # Look the endpoint set up once: it both proves existence and
        # provides the version the remove request must carry.
        existing = cls.get_endpointset(endpointset_name, cluster)
        if existing is None:
            cls.LOGGER.warning("Endpointset %s doesn't exist in %s", endpointset_name, cluster)
            return None

        req = endpoint_sets_api_pb2.RemoveEndpointSetRequest()
        req.id = endpointset_name
        req.version = existing.version
        req.cluster = cluster

        return cls.endpointset_api.remove_endpoint_set(req, request_timeout=20)

    @classmethod
    def from_proto_msg(cls, cluster, proto_msg):
        """Build (or fetch via the Singleton) an instance and fill it from ``proto_msg``."""
        es = cls(proto_msg.meta.id, cluster)
        es._apply_proto_msg(proto_msg)
        return es

    def _apply_proto_msg(self, proto_msg):
        """Copy meta/spec fields of an endpoint-set proto message onto ``self``."""
        self.nanny_service_id = proto_msg.meta.service_id
        self.version = proto_msg.meta.version
        self.protocol = proto_msg.spec.protocol
        self.port = proto_msg.spec.port
        self.pod_filter = proto_msg.spec.pod_filter
        self.description = proto_msg.spec.description

    def __init__(self, name, cluster, do_sync=False):
        # type: (six.string_types, six.string_types, bool) -> None
        """
        :param name: endpoint set id.
        :param cluster: YP cluster; stored upper-cased.
        :param do_sync: load the current state from the API immediately.
        :raises Exception: if ``do_sync`` is set and loading fails (the
            original error is attached as the cause on Python 3).
        """
        self.initialize_endpointset_api()
        self.name = name
        self.cluster = cluster.upper()

        self.nanny_service_id = None
        self.version = None
        self.protocol = None
        self.port = None
        self.pod_filter = None
        self.description = None

        if do_sync:
            try:
                self.load_from_api()
            except Exception as err:
                self.LOGGER.exception("Can't sync endpointset %s in cluster %s", self.name, self.cluster)
                six.raise_from(
                    Exception("Can't sync endpointset {} in cluster {}".format(self.name, self.cluster)),
                    err,
                )

    def load_from_api(self):
        """
        Refresh this instance's fields from the API.

        Fields are written onto ``self`` directly, so this works regardless of
        whether the Singleton would hand back the same object for
        ``(name, cluster)``.

        :returns: ``self``.
        """
        req = endpoint_sets_api_pb2.GetEndpointSetRequest(id=self.name, cluster=self.cluster)
        res = self.endpointset_api.get_endpoint_set(req, request_timeout=20)
        self._apply_proto_msg(res.endpoint_set)
        return self

    def push_to_api(self):
        """Send this instance's spec to the API using its current ``version``."""
        req = endpoint_sets_api_pb2.UpdateEndpointSetRequest()
        req.cluster = self.cluster
        req.id = self.name
        req.version = self.version

        req.spec.pod_filter = self.pod_filter
        req.spec.protocol = self.protocol
        req.spec.port = self.port
        req.spec.description = self.description

        return self.endpointset_api.update_endpoint_set(req, request_timeout=20)

    def update(self, protocol=None, port=None, pod_filter=None, description=None):
        """
        Reload from the API, override the given (non-None) spec fields, and
        push the result back.
        """
        self.load_from_api()

        if protocol is not None:
            self.protocol = protocol
        if port is not None:
            self.port = port
        if pod_filter is not None:
            self.pod_filter = pod_filter
        if description is not None:
            self.description = description
        self.LOGGER.debug('Updating endpointset %s in cluster %s', self.name, self.cluster)

        self.push_to_api()

    def resolve(self):
        """
        Resolve the endpoint set through YP service discovery.

        :returns: list of endpoint messages from the resolver response.
        """
        self.LOGGER.debug('Resolving endpointset %s@%s', self.cluster, self.name)
        request = api_pb2.TReqResolveEndpoints()
        request.cluster_name = self.cluster.lower()
        request.endpoint_set_id = self.name
        request.client_name = '1'
        request.label_selectors.append('/')

        result = self._SD_RESOLVER.resolve_endpoints(request)
        return list(result.endpoint_set.endpoints)
