import logging

from infra.awacs.proto import api_pb2, api_stub
from saas.library.python.nanny_proto.rpc_client_base import NannyRpcClientBase
from saas.library.python.deploy_manager_api.client import DeployManagerApiClient
from infra.nanny.yp_lite_api.py_stubs import pod_reallocation_api_stub as prapi_stub
from infra.nanny.yp_lite_api.proto.pod_reallocation_api_pb2 import StartPodReallocationRequest, \
    GetPodReallocationSpecRequest
from infra.nanny.yp_lite_api.proto.pod_sets_api_pb2 import ListPodsRequest
from saas.library.python.token_store import PersistentTokenStore
from infra.nanny.yp_lite_api.py_stubs import pod_sets_api_stub as pds_stub
import infra.nanny.nanny_services_rest.nanny_services_rest.client as nanny_client
import nanny_rpc_client
from google.protobuf import json_format


def kv_search(arr, key):
    """Return the ``value`` of the first entry in *arr* whose ``key`` equals *key*.

    *arr* is any iterable of objects exposing ``key`` and ``value`` attributes
    (for example repeated key/value protobuf entries). Entries after the first
    match are not inspected. Returns ``None`` when nothing matches.
    """
    matching_values = (entry.value for entry in arr if entry.key == key)
    return next(matching_values, None)


class L3BalancerManager(NannyRpcClientBase):
    """awacs L3 balancer client with helpers to reallocate balancer pods in YP-lite.

    L3 balancer calls go through ``self._CLIENT``, built from the class attributes
    below by ``_init_client`` (presumably provided by NannyRpcClientBase — confirm).
    Pod reallocation talks to the YP-lite UI API and the Nanny service repo directly.
    """
    _RPC_URL = "https://awacs.yandex-team.ru/api/"
    _OAUTH_SLUG = 'awacs'
    _API_STUB = api_stub.L3BalancerServiceStub
    # Shared DeployManagerApiClient, created lazily by init_dm_client().
    _API_DM = None
    # Base URL of the YP-lite UI API; specific services are appended to it.
    _YP_LITE_API_URL = 'https://yp-lite-ui.nanny.yandex-team.ru/api/yplite/'

    @classmethod
    def init_dm_client(cls):
        """Create the class-wide DeployManagerApiClient on first use."""
        if cls._API_DM is None:
            cls._API_DM = DeployManagerApiClient()

    def __init__(self):
        # NOTE(review): super(NannyRpcClientBase, self) skips NannyRpcClientBase.__init__
        # and calls the next class in the MRO; the RPC client is then initialised by hand.
        # Confirm this is intentional before switching to a plain super().__init__().
        super(NannyRpcClientBase, self).__init__()
        self._init_client()
        self.init_dm_client()

    @staticmethod
    def _default_l3_balancer_request_data(ctype, namespace_id, l3_balancer_id='', author='', comment='', balancer_ids=None):
        """Build the dict form of a CreateL3BalancerRequest.

        :param ctype: service ctype; when *l3_balancer_id* is empty the balancer id/FQDN
            becomes ``<ctype lowercased with '_' -> '-'>.saas.yandex.net``.
        :param namespace_id: awacs namespace the balancer is created in.
        :param l3_balancer_id: explicit balancer id (also used as the FQDN).
        :param author: value for ``meta.author``.
        :param comment: value for ``meta.comment``.
        :param balancer_ids: awacs L7 balancer ids used as real servers (default: none).
        :returns: dict suitable for ``json_format.ParseDict``.
        """
        if balancer_ids is None:
            balancer_ids = []
        # The balancer id doubles as the FQDN it serves.
        fqdn = l3_balancer_id if l3_balancer_id != '' else '{}.saas.yandex.net'.format(ctype.lower().replace('_', '-'))
        return {
            'meta': {
                'id': fqdn,
                'namespace_id': namespace_id,
                'author': author,
                'comment': comment,
                'auth': {
                    'type': "STAFF",
                    'staff': {
                        'owners': {
                            "group_ids": ["29985", ],
                        }
                    }
                },
            },
            'order': {
                'abc_service_id': 664,
                'fqdn': fqdn,
                'real_servers': {
                    'type': 'BALANCERS',
                    'balancers': [{'id': balancer} for balancer in balancer_ids],
                },
                'ctl_version': '1'
            }
        }

    def create_l3_balancer(self, ctype, namespace_id, balancer_ids, l3_balancer_id='', author='', comment='', advanced_settings=None):
        """Create an L3 balancer in awacs and return the RPC response.

        *advanced_settings* is merged into the default request dict with ``dict.update``,
        i.e. a shallow merge: a top-level key such as ``'order'`` replaces the whole
        default sub-dict rather than being merged into it.
        """
        request_data = self._default_l3_balancer_request_data(ctype, namespace_id, l3_balancer_id, author, comment, balancer_ids)
        req_pb = api_pb2.CreateL3BalancerRequest()
        if advanced_settings is not None:
            request_data.update(advanced_settings)
        json_format.ParseDict(request_data, req_pb)
        return self._CLIENT.create_l3_balancer(req_pb)

    @staticmethod
    def _request_from_pod(pod, l3_balancer):
        """Build a StartPodReallocationRequest body (as a dict) mirroring an existing pod.

        The new allocation copies the pod's root and persistent volumes, CPU/memory
        requests, sysctl properties and network macro, and attaches the pod to the
        *l3_balancer* virtual service.

        :param pod: YP pod spec message (``pod.spec`` from ListPods).
        :param l3_balancer: virtual service id to attach the pod to.
        :returns: dict for ``json_format.ParseDict``.
        :raises IndexError: if the pod has no ISS instance, the ISS revision has fewer
            than two containers or no resource limit, or there is no IPv6 address request.
        """
        def to_mb(value):
            # Exact integer division: byte counts are 64-bit and may exceed float precision.
            return int(value) // 2 ** 20

        # Root-volume defaults, used when no disk request is mounted at '/'.
        storage_class = '/hdd'
        bandwidth_limit = 5
        bandwidth_guarantee = 5
        work_dir_quota = 512
        volume_attrs = []
        for volume in pod.disk_volume_requests:
            attributes = volume.labels.attributes
            mount_path = kv_search(attributes, 'mount_path')
            if mount_path == '/':
                storage_class = volume.storage_class
                bandwidth_limit = to_mb(volume.quota_policy.bandwidth_limit)
                bandwidth_guarantee = to_mb(volume.quota_policy.bandwidth_guarantee)
                work_dir_snapshot_quota = kv_search(attributes, 'work_dir_snapshot_quota')
                # Keep the default rather than failing on int(None) when the label is absent.
                if work_dir_snapshot_quota is not None:
                    work_dir_quota = to_mb(work_dir_snapshot_quota)
            else:
                volume_attrs.append({
                    'mount_point': mount_path,
                    'storage_class': volume.storage_class,
                    'bandwidth_guarantee_megabytes_per_sec': to_mb(volume.quota_policy.bandwidth_guarantee),
                    'bandwidth_limit_megabytes_per_sec': to_mb(volume.quota_policy.bandwidth_limit),
                })

        instance = pod.iss.instances[0]
        # NOTE(review): assumes the second container of the ISS revision carries the
        # root-fs quota as its first resource limit — confirm this layout.
        root_fs_quota = int(instance.instanceRevision.container[1].resource_allocation.limit[0].scalar.value)

        # Disk quotas for persistent volumes come from the ISS instance volumes, matched by
        # mount point; if several ISS volumes share a mount point the last one wins.
        quota_by_mount_point = {iss_volume.mountPoint: iss_volume.quotaBytes
                                for iss_volume in instance.entity.instance.volumes}
        for attrs in volume_attrs:
            if attrs['mount_point'] in quota_by_mount_point:
                attrs['disk_quota_megabytes'] = to_mb(quota_by_mount_point[attrs['mount_point']])

        network_macro = pod.ip6_address_requests[0].network_id
        pod_dict = json_format.MessageToDict(pod)
        resources = pod_dict['resourceRequests']
        return {
            'allocation_request': {
                'persistent_volumes': volume_attrs,
                'vcpu_guarantee': resources['vcpuGuarantee'],
                'vcpu_limit': resources['vcpuLimit'],
                'memory_guarantee_megabytes': to_mb(resources['memoryGuarantee']),
                'snapshots_count': 10,
                'root_volume_storage_class': storage_class,
                # MessageToDict omits empty repeated fields, so the key may be missing.
                'sysctl_properties': pod_dict.get('sysctlProperties', []),
                'network_macro': network_macro,
                'virtual_service_ids': [l3_balancer, ],
                'root_fs_quota_megabytes': to_mb(root_fs_quota),
                'root_bandwidth_limit_megabytes_per_sec': bandwidth_limit,
                'root_bandwidth_guarantee_megabytes_per_sec': bandwidth_guarantee,
                'work_dir_quota_megabytes': work_dir_quota,
            },
            'degrade_params': {
                'max_unavailable_pods': 1,
                'min_update_delay_seconds': 300,
            }
        }

    def _yp_lite_rpc_client(self, api_path):
        """Return a retrying RPC client for ``<YP-lite UI API>/<api_path>`` using the nanny token."""
        return nanny_rpc_client.RetryingRpcClient(
            rpc_url=self._YP_LITE_API_URL + api_path,
            oauth_token=PersistentTokenStore.get_token_from_store_env_or_file('nanny'))

    def reallocate_pods(self, balancer_nanny_service, cluster, l3_balancer_id, retrying_num=5):
        """Reallocate every pod of a Nanny service in *cluster* onto *l3_balancer_id*.

        Each pod is attempted up to *retrying_num* times; only IndexError triggers a
        retry. Pods that still fail are logged and skipped (best-effort).

        :returns: list of StartPodReallocation responses for the pods that succeeded.
        """
        logger = logging.getLogger(__name__)
        pods_stub = pds_stub.YpLiteUIPodSetsServiceStub(self._yp_lite_rpc_client('pod-sets/'))
        list_request = ListPodsRequest(service_id=balancer_nanny_service, cluster=cluster)
        pod_specs = [pod.spec for pod in pods_stub.list_pods(list_request).pods]
        reallocated_pods = []
        for spec in pod_specs:
            # NOTE(review): IndexError usually comes from _request_from_pod on this same
            # spec, which is deterministic, so retrying rarely helps — confirm the intent.
            for attempt in range(1, retrying_num + 1):
                try:
                    reallocated_pods.append(self._reallocate_pod(spec, balancer_nanny_service, cluster, l3_balancer_id))
                    break
                except IndexError:
                    logger.warning('Pod reallocation attempt %d/%d for service %s failed with IndexError',
                                   attempt, retrying_num, balancer_nanny_service, exc_info=True)
            else:
                logger.error('Giving up on reallocating a pod of service %s after %d attempt(s)',
                             balancer_nanny_service, retrying_num)
        return reallocated_pods

    def _reallocate_pod(self, pod_spec, balancer_nanny_service, cluster, l3_balancer_id):
        """Start a reallocation of one pod onto *l3_balancer_id* and return the RPC response.

        The request is pinned to the service's current runtime snapshot and, when the
        service already has a reallocation spec, references it as the previous one.
        *cluster* is not used here.
        """
        yp_stub = prapi_stub.YpLiteReallocationServiceStub(self._yp_lite_rpc_client('pod-reallocation/'))

        start_pod_reallocation_request = StartPodReallocationRequest(service_id=balancer_nanny_service)
        # NOTE(review): this uses PersistentTokenStore.get_token while the YP-lite clients use
        # get_token_from_store_env_or_file — confirm both resolve the same token.
        ncl = nanny_client.ServiceRepoClient('https://nanny.yandex-team.ru',
                                             token=PersistentTokenStore.get_token('nanny'))
        start_pod_reallocation_request.snapshot_id = ncl.get_runtime_attrs(balancer_nanny_service)['_id']
        json_format.ParseDict(self._request_from_pod(pod_spec, l3_balancer_id), start_pod_reallocation_request)

        get_reallocation_request = GetPodReallocationSpecRequest(service_id=balancer_nanny_service)
        try:
            start_pod_reallocation_request.previous_reallocation_id = yp_stub.get_pod_reallocation_spec(get_reallocation_request).spec.id
        except nanny_rpc_client.exceptions.BadRequestError:
            # Presumably returned when the service has no reallocation spec yet; start without one.
            pass

        return yp_stub.start_pod_reallocation(start_pod_reallocation_request)

    def list_l3_balancers(self, namespace_id):
        """List the L3 balancers in awacs namespace *namespace_id*."""
        # NOTE(review): builds api_pb2.ListBalancersRequest for the list_l3_balancers RPC;
        # check whether a dedicated ListL3BalancersRequest message should be used instead.
        req_pb = api_pb2.ListBalancersRequest()
        req_pb.namespace_id = namespace_id
        return self._CLIENT.list_l3_balancers(req_pb)

    def get_l3_balancer(self, namespace_id, balancer_id):
        """Fetch L3 balancer *balancer_id* from awacs namespace *namespace_id*."""
        req_pb = api_pb2.GetL3BalancerRequest()
        req_pb.namespace_id = namespace_id
        req_pb.id = balancer_id
        return self._CLIENT.get_l3_balancer(req_pb)
