from instancectl.lib import envutil
from instancectl.sd.client import SdEndpointSetEmptyError


class SdInstancesGetter(object):
    """
    Look up endpoints and pods through a service discovery client.

    Every successful per-cluster lookup is memoized on the instance, keyed by
    ``(id, cluster_name)``, so repeated requests for the same pair do not hit
    the client again.
    """

    # Clusters queried when the caller does not pass an explicit list.
    DEFAULT_CLUSTER_NAMES = ('SAS', 'MAN', 'VLA', 'IVA', 'MYT')

    def __init__(self, sd_client, env):
        """
        :param sd_client: client used for lookups; this class calls its
            ``get_endpoints(endpoint_set_id, cluster_name)`` and
            ``get_pods(service_id, cluster_name)`` methods and reads its
            ``sd_url`` attribute when building error messages
        :param env: object whose ``orthogonal_tags`` attribute is used to
            determine the current cluster
        :type sd_client: any
        :type env: any
        """
        self.sd_client = sd_client  # type: instancectl.sd.client.ISdClient
        self.env = env
        # (endpoint_set_id, cluster_name) -> list of endpoints
        self.cluster_endpoint_set_id_to_endpoints = {}
        # (service_id, cluster_name) -> list of pods
        self.cluster_service_id_to_pods = {}

    def _get_current_cluster_name(self):
        """
        Return the cluster name extracted from ``self.env.orthogonal_tags``.

        :raises ValueError: if no cluster name could be extracted
        """
        dc = envutil.extract_dc_from_orthogonal_tags(self.env.orthogonal_tags)
        if not dc:
            raise ValueError("Could not get current cluster name")
        return dc

    def get_cluster_endpoints(self, endpoint_set_id, cluster_name):
        """
        Return the (cached) list of endpoints of an endpoint set in one cluster.

        An endpoint set reported as empty by the client
        (``SdEndpointSetEmptyError``) is cached and returned as an empty list;
        any other client error propagates and nothing is cached.

        :rtype: list
        """
        key = (endpoint_set_id, cluster_name)
        if key not in self.cluster_endpoint_set_id_to_endpoints:
            try:
                endpoints = list(self.sd_client.get_endpoints(endpoint_set_id, cluster_name))
            except SdEndpointSetEmptyError:
                endpoints = []
            self.cluster_endpoint_set_id_to_endpoints[key] = endpoints
        return self.cluster_endpoint_set_id_to_endpoints[key]

    def get_endpoints(self, endpoint_set_id, cluster_names=None):
        """
        Return endpoints of an endpoint set collected across several clusters.

        :param cluster_names: clusters to query, in order; when empty or
            ``None``, ``DEFAULT_CLUSTER_NAMES`` is used
        :rtype: list
        :raises ValueError: if no endpoints were found in any of the clusters
        """
        if not cluster_names:
            cluster_names = self.DEFAULT_CLUSTER_NAMES
        rv = []
        for cluster_name in cluster_names:
            rv.extend(self.get_cluster_endpoints(endpoint_set_id, cluster_name))
        if not rv:
            raise ValueError('EndpointSet: {}, has no endpoints in clusters: {} (sd_url: {})'.format(
                endpoint_set_id, cluster_names, self.sd_client.sd_url))
        return rv

    def get_endpoints_current_cluster(self, endpoint_set_id):
        """
        Return endpoints of an endpoint set in the current cluster.

        Unlike :meth:`get_endpoints`, an empty result is returned as is and
        does not raise.

        :raises ValueError: if the current cluster name cannot be determined
        """
        return self.get_cluster_endpoints(endpoint_set_id, cluster_name=self._get_current_cluster_name())

    def get_cluster_pods(self, service_id, cluster_name):
        """
        Return the (cached) list of pods of a service in one cluster.

        Client errors propagate to the caller and nothing is cached for them.

        :rtype: list
        """
        key = (service_id, cluster_name)
        if key not in self.cluster_service_id_to_pods:
            self.cluster_service_id_to_pods[key] = list(self.sd_client.get_pods(service_id, cluster_name))
        return self.cluster_service_id_to_pods[key]

    def get_pods(self, service_id, cluster_names=None):
        """
        Return pods of a service collected across several clusters.

        :param cluster_names: clusters to query, in order; when empty or
            ``None``, ``DEFAULT_CLUSTER_NAMES`` is used
        :rtype: list
        :raises ValueError: if no pods were found in any of the clusters
        """
        if not cluster_names:
            cluster_names = self.DEFAULT_CLUSTER_NAMES
        rv = []
        for cluster_name in cluster_names:
            rv.extend(self.get_cluster_pods(service_id, cluster_name))
        if not rv:
            # Wording kept parallel to the endpoint-set message in get_endpoints.
            raise ValueError('Service: {} has no pods in clusters: {} (sd_url: {})'.format(
                service_id, cluster_names, self.sd_client.sd_url))
        return rv

    def get_pods_current_cluster(self, service_id):
        """
        Return pods of a service in the current cluster.

        Unlike :meth:`get_pods`, an empty result is returned as is and does
        not raise.

        :raises ValueError: if the current cluster name cannot be determined
        """
        return self.get_cluster_pods(service_id, cluster_name=self._get_current_cluster_name())

