import uuid
import grpc

from sepelib.util.retry import Retry, RetrySleeper
from infra.yp_service_discovery.api import api_pb2
from infra.yp_service_discovery.python.resolver import resolver as sd_resolver


class SdClientError(Exception):
    """Base class for errors raised by service discovery clients in this module."""


class ISdClient(object):
    """
    Interface of a service discovery client.

    Both methods are placeholders that raise NotImplementedError.
    """

    def get_endpoints(self, service_id, cluster_name):
        """
        Return endpoints for the given service in the given cluster.

        :type service_id: str
        :type cluster_name: str
        :rtype: list[api_pb2.TEndpoint]
        :raises: NotImplementedError
        """
        # NotImplemented is the binary-operator sentinel, not an exception:
        # raising it fails with TypeError. NotImplementedError is the proper signal.
        raise NotImplementedError

    def get_pods(self, service_id, cluster_name):
        """
        Return pods for the given service in the given cluster.

        :type service_id: str
        :type cluster_name: str
        :rtype: list[api_pb2.TPod]
        :raises: NotImplementedError
        """
        raise NotImplementedError


class SdEndpointSetEmptyError(SdClientError):
    """Raised when a resolved endpoint set has the EMPTY resolve status."""


class SdGrpcClient(ISdClient):
    """
    ISdClient implementation that sends requests through sd_resolver.Resolver.

    proto: https://a.yandex-team.ru/arc/trunk/arcadia/infra/yp_service_discovery/api/api.proto
    """

    def __init__(self, slot, sd_url=None):
        """
        :param slot: value embedded into the resolver client name ("instancectl:<slot>")
        :param sd_url: service discovery address; a falsy value selects sd_resolver.SD_GRPC_ADDRESS
        :type sd_url: str | None
        """
        resolver_client_name = "instancectl:{}".format(slot)
        self.sd_url = sd_url if sd_url else sd_resolver.SD_GRPC_ADDRESS
        self.resolver = sd_resolver.Resolver(
            client_name=resolver_client_name,
            timeout=5,
            grpc_address=self.sd_url,
        )
        # Calls are retried on transport errors and on empty endpoint sets.
        sleeper = RetrySleeper(max_tries=3, delay=1, backoff=3)
        retriable_errors = (grpc.RpcError, SdEndpointSetEmptyError)
        self._retry = Retry(retry_sleeper=sleeper, retry_exceptions=retriable_errors)

    def get_endpoints(self, endpoint_set_id, cluster):
        """
        Resolve the endpoints of an endpoint set, retrying via self._retry.

        The cluster name is lower-cased before it is put into the request.

        :type endpoint_set_id: str
        :type cluster: str
        :rtype: list[api_pb2.TEndpoint]
        :raises: (grpc.RpcError, SdEndpointSetEmptyError)
        """
        request = api_pb2.TReqResolveEndpoints()
        request.cluster_name = cluster.lower()
        request.endpoint_set_id = endpoint_set_id
        # Request id: "<cluster name>-<random uuid4>".
        request.ruid = "{}-{}".format(request.cluster_name, uuid.uuid4())

        def make_request():
            result = self.resolver.resolve_endpoints(request)
            if result.resolve_status != api_pb2.EResolveStatus.EMPTY:
                return result
            # An EMPTY status becomes an exception so that self._retry tries again.
            message = (
                'EndpointSet: {r.endpoint_set_id} in cluster: {r.cluster_name} '
                'has empty endpoints list'
            ).format(r=request)
            raise SdEndpointSetEmptyError(message)

        response = self._retry(make_request)
        endpoint_set = response.endpoint_set
        return endpoint_set.endpoints

    def get_pods(self, service_id, cluster):
        """
        Resolve the pods of a pod set, retrying via self._retry.

        The pod set id is the service id with every "_" replaced by "-";
        the cluster name is lower-cased before it is put into the request.

        :type cluster: str
        :type service_id: str
        :rtype: list[api_pb2.TPod]
        """
        dashed_service_id = service_id.replace("_", "-")

        request = api_pb2.TReqResolvePods()
        request.cluster_name = cluster.lower()
        request.pod_set_id = dashed_service_id
        # Request id: "<cluster name>-<random uuid4>".
        request.ruid = "{}-{}".format(request.cluster_name, uuid.uuid4())

        response = self._retry(self.resolver.resolve_pods, request)
        pod_set = response.pod_set
        return pod_set.pods

