import google.protobuf.field_mask_pb2
import nanny_rpc_client
import requests
from infra.nanny.yp_lite_api.proto import pod_sets_api_pb2, endpoint_sets_api_pb2
from infra.nanny.yp_lite_api.py_stubs import pod_sets_api_stub, endpoint_sets_api_stub
from nanny_repo import repo_api_pb2, repo_api_stub
from six.moves.urllib import parse as urlparse


# Timeout handed to every RPC client and to every plain HTTP request made by
# NannyClient (``requests`` interprets the value as seconds).
DEFAULT_REQUEST_TIMEOUT = 90


class NannyClient(object):
    """
    Thin client that talks to three back ends:

    * the Nanny HTTP API under ``/v2/`` (via a ``requests`` session),
    * the Nanny repo RPC API under ``/api/repo/``,
    * the YP Lite UI pod-set / endpoint-set RPC APIs under ``/api/yplite/``.

    All of them are authorized with the same OAuth token.
    """

    def __init__(self, nanny_url, yp_lite_ui_url, oauth_token):
        """
        :param nanny_url: base URL the ``/v2/`` and ``/api/repo/`` paths are joined to
        :param yp_lite_ui_url: base URL the ``/api/yplite/...`` paths are joined to
        :param oauth_token: token sent as ``Authorization: OAuth <token>``
        """
        self.nanny_api_url = urlparse.urljoin(nanny_url, '/v2/')
        self.oauth_token = oauth_token
        self.s = requests.Session()
        # Update (rather than replace) the session headers so the session keeps
        # its case-insensitive header mapping and its default headers.
        self.s.headers.update({
            'Authorization': 'OAuth {}'.format(self.oauth_token),
        })

        self.repo_client = nanny_rpc_client.SessionedRpcClient(
            urlparse.urljoin(nanny_url, '/api/repo/'),
            request_timeout=DEFAULT_REQUEST_TIMEOUT,
            oauth_token=oauth_token)
        self.repo_stub = repo_api_stub.RepoServiceStub(self.repo_client)

        self.pod_sets_client = nanny_rpc_client.SessionedRpcClient(
            urlparse.urljoin(yp_lite_ui_url, '/api/yplite/pod-sets/'),
            request_timeout=DEFAULT_REQUEST_TIMEOUT,
            oauth_token=oauth_token)
        self.endpoint_sets_client = nanny_rpc_client.SessionedRpcClient(
            urlparse.urljoin(yp_lite_ui_url, '/api/yplite/endpoint-sets/'),
            request_timeout=DEFAULT_REQUEST_TIMEOUT,
            oauth_token=oauth_token)
        self.pod_sets_stub = pod_sets_api_stub.YpLiteUIPodSetsServiceStub(self.pod_sets_client)
        self.endpoint_sets_stub = endpoint_sets_api_stub.YpLiteUIEndpointSetsServiceStub(self.endpoint_sets_client)

    def _get_json(self, path, params=None):
        """
        GET ``<nanny_api_url><path>`` and return the decoded JSON body.

        The Authorization header comes from the session; the request is bounded
        by DEFAULT_REQUEST_TIMEOUT so a stalled connection cannot hang forever.

        :param path: path relative to the ``/v2/`` API root
        :param params: optional query-string parameters
        :raises requests.HTTPError: on a 4xx/5xx response
        """
        resp = self.s.get(self.nanny_api_url + path, params=params, timeout=DEFAULT_REQUEST_TIMEOUT)
        resp.raise_for_status()
        return resp.json()

    def get_replication_policy(self, service_id):
        """
        Return the ``policy`` field of the repo API response for the policy
        whose id equals ``service_id``.
        """
        replication_policy_request = repo_api_pb2.GetReplicationPolicyRequest(policy_id=service_id)
        return self.repo_stub.get_replication_policy(replication_policy_request).policy

    def update_pod_by_script(self, cluster, service_id, pod_id, pod_version):
        """
        Run the SPLIT_OFF_AWACS_VOLUME update script against a single pod.

        :type cluster: six.text_type
        :type service_id: six.text_type
        :type pod_id: six.text_type
        :type pod_version: six.text_type
        :return: the response message returned by the pod sets stub
        """
        req_pb = pod_sets_api_pb2.UpdatePodByScriptRequest(
            service_id=service_id,
            cluster=cluster,
            pod_id=pod_id,
            version=pod_version,
            script=pod_sets_api_pb2.UpdatePodByScriptRequest.SPLIT_OFF_AWACS_VOLUME
        )
        return self.pod_sets_stub.update_pod_by_script(req_pb)

    def get_pod(self, cluster, pod_id):
        """
        Fetch a single pod.

        :type cluster: six.text_type
        :type pod_id: six.text_type
        :return: the ``pod`` field of the GetPod response
        """
        req_pb = pod_sets_api_pb2.GetPodRequest(cluster=cluster, pod_id=pod_id)
        return self.pod_sets_stub.get_pod(req_pb).pod

    def iter_pods(self, cluster, service_id, field_mask_pb=None):
        """
        Lazily yield every pod of a service, fetching them page by page.

        :type service_id: six.text_type
        :type cluster: six.text_type
        :type field_mask_pb: google.protobuf.field_mask_pb2.FieldMask
        :param field_mask_pb: optional mask copied into the list request
        """
        limit = 10  # page size of a single ListPods request
        req_pb = pod_sets_api_pb2.ListPodsRequest(
            service_id=service_id,
            cluster=cluster,
            limit=limit
        )
        if field_mask_pb is not None:
            req_pb.field_mask.CopyFrom(field_mask_pb)
        while True:
            resp_pb = self.pod_sets_stub.list_pods(req_pb)
            # An empty page marks the end of the listing.
            if not resp_pb.pods:
                break
            for pod_pb in resp_pb.pods:
                yield pod_pb
            req_pb.offset += limit

    def list_all_pod_ids(self, cluster, service_id):
        """
        Return the ids of all pods of a service.

        :type service_id: six.text_type
        :type cluster: six.text_type
        :rtype: list[six.text_type]
        """
        # Only ``meta.id`` is read below, so request just that field.
        field_mask_pb = google.protobuf.field_mask_pb2.FieldMask()
        field_mask_pb.paths.append('meta.id')
        pods = self.iter_pods(cluster=cluster, service_id=service_id, field_mask_pb=field_mask_pb)
        return [pod_pb.meta.id for pod_pb in pods]

    def iter_endpoint_sets(self, cluster, service_id):
        """
        Yield the endpoint sets of a service, obtained with a single list request.

        :type service_id: six.text_type
        :type cluster: six.text_type
        """
        req_pb = endpoint_sets_api_pb2.ListEndpointSetsRequest(
            service_id=service_id,
            cluster=cluster,
        )
        resp_pb = self.endpoint_sets_stub.list_endpoint_sets(req_pb)
        for endpoint_set_pb in resp_pb.endpoint_sets:
            yield endpoint_set_pb

    def list_all_endpoint_set_ids(self, cluster, service_id):
        """
        Return the ids of all endpoint sets of a service.

        :type service_id: six.text_type
        :type cluster: six.text_type
        :rtype: list[six.text_type]
        """
        endpoint_sets = self.iter_endpoint_sets(cluster=cluster, service_id=service_id)
        return [es_pb.meta.id for es_pb in endpoint_sets]

    def get_service(self, service_id):
        """
        Return the decoded JSON of ``services/<service_id>``.

        :raises requests.HTTPError: on a 4xx/5xx response
        """
        return self._get_json('services/' + service_id)

    def get_service_info_attrs(self, service_id):
        """
        Return the decoded JSON of ``services/<service_id>/info_attrs/``.

        :raises requests.HTTPError: on a 4xx/5xx response
        """
        return self._get_json('services/' + service_id + '/info_attrs/')

    def get_service_runtime_attrs(self, service_id, exclude_runtime_attrs=False):
        """
        Return the decoded JSON of ``services/<service_id>/runtime_attrs/``.

        :param exclude_runtime_attrs: when truthy, ``exclude_runtime_attrs=1``
            is added to the query string
        :raises requests.HTTPError: on a 4xx/5xx response
        """
        params = {}
        if exclude_runtime_attrs:
            params['exclude_runtime_attrs'] = 1
        return self._get_json('services/' + service_id + '/runtime_attrs/', params=params)

    def get_state(self, service_id):
        """
        Return the decoded JSON of ``services/<service_id>/state/``.

        :raises requests.HTTPError: on a 4xx/5xx response
        """
        return self._get_json('services/' + service_id + '/state/')

    def list_current_instances(self, service_id):
        """
        Return the ``result`` entry of ``services/<service_id>/current_state/instances/``.

        :raises requests.HTTPError: on a 4xx/5xx response
        """
        path = 'services/{service_id}/current_state/instances/'.format(service_id=service_id)
        return self._get_json(path)['result']

    def update_runtime_attrs_content(self, service_id, snapshot_id, snapshot_priority, runtime_attrs_content,
                                     comment):
        """
        PUT new runtime attrs content for a service and return the decoded JSON reply.

        :param snapshot_id: sent as ``snapshot_id``
        :param snapshot_priority: sent as ``meta_info.scheduling_config.scheduling_priority``
        :param runtime_attrs_content: sent as ``content``
        :param comment: sent as ``comment``
        :raises requests.HTTPError: on a 4xx/5xx response
        """
        payload = {
            'snapshot_id': snapshot_id,
            'content': runtime_attrs_content,
            'comment': comment,
            'meta_info': {
                'scheduling_config': {
                    'scheduling_priority': snapshot_priority,
                },
            }
        }
        resp = self.s.put(
            self.nanny_api_url + 'services/' + service_id + '/runtime_attrs/',
            json=payload,
            timeout=DEFAULT_REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp.json()
