#!/usr/bin/python

import os
import json
import urllib
import base64
import re
import requests_unixsocket
from yt import yson
from yp_proto.yp.client.api.proto import cluster_api_pb2
from infra.qyp.proto_lib import vmset_pb2, vmagent_pb2


def mac2ll(mac):
    """
    Build an IPv6 link-local address (modified EUI-64) from a MAC address.

    :param mac: colon-separated MAC address with hexadecimal octets, e.g. ``52:54:00:00:1c:57``
    :type mac: str
    :return: address of the form ``fe80::xxxx:xxff:fexx:xxxx`` (groups are not zero-compressed)
    :rtype: str
    :raises ValueError: if the first octet is not a hexadecimal number
    """
    octets = mac.split(":")
    # Modified EUI-64 inverts the universal/local bit (0x02) of the first octet.
    # The octet is hexadecimal, so it has to be parsed and re-emitted as hex.
    octets[0] = format(int(octets[0], 16) ^ 0x02, '02x')
    # ff:fe goes between the OUI half and the device half of the MAC
    octets[3:3] = ['ff', 'fe']
    groups = tuple("".join(octets[i:i + 2]) for i in range(0, 8, 2))
    return "fe80::%s:%s:%s:%s" % groups


class PodSpec(object):
    """
    Accessor over the pod spec JSON document fetched from the pod agent unix socket.

    Derived values (ISS payload, properties, labels, annotations, GPU bus ids)
    are computed on first access and cached on the instance.
    """
    DEFAULT_SOCKET_PATH = '/var/run/iss/pod.socket'
    # base64 of the YSON serialization of 'vm'; ip6 allocation label values are matched against it
    IP6_ADDRESS_ALLOCATION_LABEL_VALUE = base64.b64encode(yson.dumps('vm'))
    # The three captured groups are rendered as '<g1>:<g2>.<g3>' by gpu_bus_ids
    GPU_RESOURCE_ID_PATTERN = re.compile(r'gpu-\w*-(\w*)-(\w*)-(\w*)-.*')

    @classmethod
    def build(cls, pod_agent_socket=None):  # type: (str) -> PodSpec
        """
        Fetch ``/pod_spec`` over the given unix socket and wrap the decoded JSON.

        :param pod_agent_socket: path to the unix socket; DEFAULT_SOCKET_PATH when empty
        :rtype: PodSpec
        """
        pod_agent_socket = pod_agent_socket or cls.DEFAULT_SOCKET_PATH
        session = requests_unixsocket.Session()
        # The socket path is embedded in the URL host part, so it must be percent-encoded ('/' included)
        url = 'http+unix://{}/pod_spec'.format(urllib.quote_plus(pod_agent_socket))
        response = session.get(url)
        return cls(json.loads(response.content))

    def __init__(self, pod_spec_data):
        """
        :param pod_spec_data: decoded pod spec JSON document
        :type pod_spec_data: dict
        """
        self._pod_spec_data = pod_spec_data
        self._iss_payload = None
        self._properties = None
        self._dynamic_properties = None
        self._vmagent_port = None
        self._porto_properties = None
        self._pod_labels = None
        self._pod_annotations = None
        self._tags = None
        self._volumes = None
        self._resource_requests = None
        self._gpu_bus_ids = None

    @staticmethod
    def _decode_attributes(raw_attributes):
        """Turn a list of ``{'key': ..., 'value': <base64 yson>}`` items into a dict of decoded values."""
        return {item['key']: yson.loads(base64.b64decode(item['value'])) for item in raw_attributes}

    @classmethod
    def _is_vm_label(cls, label):
        """Tell whether a label attribute carries the 'vm' value."""
        value = label.get('value')
        if value is None:
            return False
        if value == cls.IP6_ADDRESS_ALLOCATION_LABEL_VALUE:
            return True
        # Also compare the decoded value, so the match does not depend on which
        # YSON encoding produced the stored bytes.
        try:
            return yson.loads(base64.b64decode(value)) == 'vm'
        except Exception:
            # A value that cannot be decoded is not the 'vm' marker
            return False

    def _address_for_vlan(self, vlan_id):
        """Return the address of the first 'vm' ip6 allocation in the given vlan, or None."""
        for allocation in self.ip6_address_allocations:
            if allocation.get('vlanId') == vlan_id:
                return allocation.get('address')
        return None

    @property
    def iss_payload(self):  # type: () -> cluster_api_pb2.HostConfiguration
        """HostConfiguration message parsed from the base64 'issPayload' field."""
        if self._iss_payload is None:
            iss_payload_string = base64.b64decode(self._pod_spec_data['issPayload'])

            self._iss_payload = cluster_api_pb2.HostConfiguration()
            self._iss_payload.ParseFromString(iss_payload_string)
        return self._iss_payload

    @property
    def properties(self):
        """Properties of the first ISS instance; dynamic properties override static ones."""
        if self._properties is None:
            instance = self.iss_payload.instances[0]
            merged = dict(instance.properties.items())
            merged.update(instance.dynamicProperties.items())
            self._properties = merged
        return self._properties

    @property
    def vmagent_port(self):  # type: () -> int
        """Integer value of the 'PORT' property."""
        if self._vmagent_port is None:
            self._vmagent_port = int(self.properties['PORT'])
        return self._vmagent_port

    @property
    def tags(self):
        """Raw value of the 'tags' property (whitespace-separated tags)."""
        if self._tags is None:
            self._tags = self.properties['tags']
        return self._tags

    @property
    def host_dc(self):
        """Suffix of the first tag starting with 'a_dc_', or '' when there is none."""
        return next((tag.split('a_dc_')[1] for tag in self.tags.split() if tag.startswith('a_dc_')), '')

    @property
    def ip6_address_allocations(self):
        """Yield the ip6 address allocations that have a label whose value is 'vm'."""
        for allocation in self._pod_spec_data.get('ip6AddressAllocations') or []:
            labels = allocation.get('labels', {}).get('attributes', [])
            if any(self._is_vm_label(label) for label in labels):
                yield allocation

    @property
    def io_bandwidth_limits(self):
        """Sum of quotaPolicy.bandwidthLimit of disk volume requests, grouped by storageClass."""
        limits = {}
        for request in self._pod_spec_data.get('diskVolumeRequests', []):
            storage_class = request.get('storageClass')
            bandwidth = int(request.get('quotaPolicy', {}).get('bandwidthLimit', 0))
            limits[storage_class] = limits.get(storage_class, 0) + bandwidth
        return limits

    @property
    def vm_ip(self):
        """Address of the 'vm' allocation in the 'backbone' vlan, or None."""
        return self._address_for_vlan('backbone')

    @property
    def vm_aux_ip(self):
        """Address of the 'vm' allocation in the 'fastbone' vlan, or None."""
        return self._address_for_vlan('fastbone')

    @property
    def pod_labels(self):
        """Decoded pod labels from podDynamicAttributes; an empty dict when there are none."""
        if self._pod_labels is None:
            dynamic_attributes = self._pod_spec_data.get('podDynamicAttributes', {})
            pod_labels_raw = dynamic_attributes.get('labels', {}).get('attributes') or []
            self._pod_labels = self._decode_attributes(pod_labels_raw)
        return self._pod_labels

    @property
    def pod_annotations(self):
        """Decoded pod annotations from podDynamicAttributes; an empty dict when there are none."""
        if self._pod_annotations is None:
            dynamic_attributes = self._pod_spec_data.get('podDynamicAttributes', {})
            pod_annotations_raw = dynamic_attributes.get('annotations', {}).get('attributes') or []
            self._pod_annotations = self._decode_attributes(pod_annotations_raw)
        return self._pod_annotations

    @property
    def vm_hostname(self):
        """Value of dns.persistentFqdn, or None when absent."""
        return self._pod_spec_data.get('dns', {}).get('persistentFqdn')

    @property
    def vm(self):  # type: () -> vmset_pb2.VM
        """
        VM message parsed from the 'qyp_vm_spec' annotation, with
        io_limits_per_storage filled from io_bandwidth_limits.

        :raises KeyError: if the 'qyp_vm_spec' annotation is missing
        """
        v = vmset_pb2.VM.FromString(self.pod_annotations['qyp_vm_spec'])
        for key, val in self.io_bandwidth_limits.items():
            v.spec.qemu.io_limits_per_storage[key] = val
        return v

    @property
    def ssh_authorized_keys(self):
        """Value of the 'qyp_ssh_authorized_keys' annotation, or None."""
        return self.pod_annotations.get('qyp_ssh_authorized_keys', None)

    @property
    def gpu_bus_ids(self):
        """List of '<a>:<b>.<c>' strings built from the resourceId of every GPU allocation."""
        if self._gpu_bus_ids is None:
            self._gpu_bus_ids = []
            for gpu_allocation in self._pod_spec_data.get('gpuAllocations', []):
                bus_id = self.GPU_RESOURCE_ID_PATTERN.match(gpu_allocation['resourceId']).groups()
                self._gpu_bus_ids.append('{}:{}.{}'.format(*bus_id))
        return self._gpu_bus_ids


class VmagentContext(object):
    """
    Paths, ports and host facts used to run a VM, derived from the VM spec.

    Files Tree:
        - WORKDIR (default: iss_hook_start ~)
            - DEFAULT_CONFIG_FILE (pod resource)
            - MONITOR_SOCKET_FILE
            - dump.json
            - vmagent
                - vmagent
                    - SCRIPTS_FOLDER
                        -
                    - vmagent (binary file)

        - MAIN_STORAGE (default: /qemu-persistent)
            - IMAGE_FOLDER
                - IMAGE_FILE
                - DELTA_FILE
            - LOGS_FOLDER
                - SERIAL_LOG_FILE
                - VMAGENT_LOG_FILE
                - QDM_LOG_FILE
            - EXTRAS_FOLDER
                - qdm_cli (binary file)
                - progress_file (for upload|download operations)
            - CLOUD_INIT_CONFIGS_FOLDER
            - CURRENT_STATE_FILE
            - LAST_STATUS_FILE
            - QEMU_LAUNCHER_FILE

    """
    # Subtracted from the requested memory limit / volume capacity when building the VM config
    INFRA_BLOCKED_MEMORY = 1024 ** 3
    INFRA_BLOCKED_DISK_SPACE = 1024 ** 3

    VMAGENT_VERSION = '___VMAGENT_VERSION___'
    MAIN_QEMU_VOLUME_NAME = '/qemu-persistent'
    DEFAULT_MAIN_STORAGE_PATH = '/qemu-persistent'

    EXTRAS_FOLDER_NAME = 'extras'
    LOGS_FOLDER_NAME = 'logs'
    CLOUD_INIT_CONFIGS_FOLDER_NAME = 'cloud_init_configs'

    DEFAULT_CONFIG_FILE_NAME = 'vm.config'
    MONITOR_SOCKET_FILE_NAME = 'mon.sock'
    VNC_SOCKET_FILE_NAME = 'vnc.sock'
    CURRENT_CONFIG_FILE_NAME = 'current.state'
    LAST_STATUS_FILE_NAME = 'last_status'
    SERIAL_LOG_FILE_NAME = 'serial.log'
    VMAGENT_LOG_FILE_NAME = 'vmagent.log'

    QEMU_LAUNCHER_FILE_NAME = 'qemu_launcher.sh'
    # Marker file: when it exists in the main storage, REBUILD_QEMU_LAUNCHER is False
    REBUILD_QEMU_LAUNCHER_FILE_NAME = 'do_not_rebuild_qemu_launcher'
    DEFAULT_QEMU_SYSTEM_CMD_BIN_PATH = '/usr/local/bin/qemu-system-x86_64'
    DEFAULT_QEMU_IMG_CMD_BIN_PATH = '/usr/local/bin/qemu-img'

    DEFAULT_VMAGENT_PORT = 7255
    # Thresholds used by _is_numa_enabled
    MAX_SINGLE_NUMA_CPU = 64
    MAX_SINGLE_NUMA_VFIO = 9

    @staticmethod
    def validate_volumes(volumes):
        """
        Check a list of volumes and return it unchanged.

        :raises ValueError: unless there is exactly one main volume, it has
            resource_url and image_type set, and all mount paths are unique
        """
        main_volumes = [v for v in volumes if v.is_main]
        if len(main_volumes) != 1:
            raise ValueError('Only 1 main volume required')
        main_volume = main_volumes[0]
        if main_volume.resource_url is None:
            raise ValueError('Main volume should have resource for init')

        if main_volume.image_type is None:
            raise ValueError('Main volume should have image_type')

        if len({v.mount_path for v in volumes}) != len(volumes):
            raise ValueError('Volume.mount_path should be unique')
        return volumes

    @classmethod
    def build_from_pod_spec(cls, pod_spec=None, pod_agent_socket=None, hostname=None):
        """
        Create a context from a PodSpec.

        :param pod_spec: ready PodSpec; built from ``pod_agent_socket`` when not given
        :param pod_agent_socket: socket path passed to PodSpec.build
        :param hostname: node hostname used when the PORTO_HOST environment variable is not set
        """
        pod_spec = pod_spec or PodSpec.build(pod_agent_socket)
        return cls(
            vm=pod_spec.vm,
            vm_ip=pod_spec.vm_ip,
            vm_aux_ip=pod_spec.vm_aux_ip,
            vm_hostname=pod_spec.vm_hostname,
            node_hostname=os.environ.get('PORTO_HOST', hostname),
            cluster=pod_spec.host_dc,
            vmagent_port=pod_spec.vmagent_port,
            ssh_authorized_keys=pod_spec.ssh_authorized_keys,
            gpu_bus_ids=pod_spec.gpu_bus_ids,
            iss_properties=pod_spec.properties,
        )

    def __init__(self,
                 vm,  # type: vmset_pb2.VM
                 vm_ip,
                 vm_aux_ip,
                 vm_hostname,
                 node_hostname,
                 cluster,
                 gpu_bus_ids,
                 iss_properties,
                 ssh_authorized_keys=None,
                 vmagent_port=None,
                 workdir=None,
                 windows_ready=None,
                 ):
        """
        :param vm: VM message the VM config is built from
        :param iss_properties: mapping that may override the qemu binary paths
        :param ssh_authorized_keys: defaults to an empty list
        :param vmagent_port: defaults to DEFAULT_VMAGENT_PORT
        :param workdir: defaults to the current working directory
        :param windows_ready: explicit WINDOWS_READY value; probed from the filesystem when None
        """
        self._workdir = workdir or os.getcwd()
        self._vm = vm
        self._windows_ready = windows_ready
        self._vmagent_port = int(vmagent_port or self.DEFAULT_VMAGENT_PORT)
        self._vm_ip = vm_ip
        self._vm_aux_ip = vm_aux_ip
        self._vm_hostname = vm_hostname
        self._node_hostname = node_hostname
        self._cluster = cluster
        self._ssh_authorized_keys = ssh_authorized_keys if ssh_authorized_keys is not None else []

        self._vm_config = self._build_vm_config()
        # The first volume of the spec is the main one (see _build_vm_config)
        self._main_volume = self._vm_config.volumes[0]
        self._extra_volumes = self._vm_config.volumes[1:]
        self._gpu_bus_ids = gpu_bus_ids
        self._iss_properties = iss_properties
        # Host facts below are looked up lazily by the matching properties
        self._numa_nodes = None
        self._vfio_devices = None
        self._vfio_numa_mapping = None
        self._use_numa = None

    def _build_vm_config(self):
        """Translate the qemu part of the VM spec into a vmagent VMConfig message."""
        qemu = self._vm.spec.qemu
        config = vmagent_pb2.VMConfig()
        config.mem = qemu.resource_requests.memory_limit - self.INFRA_BLOCKED_MEMORY
        config.vcpu = int(qemu.resource_requests.vcpu_limit / 1000)
        config.autorun = qemu.autorun
        config.type = qemu.vm_type
        config.audio = qemu.qemu_options.audio

        for order, qemu_volume in enumerate(qemu.volumes):
            vm_volume = config.volumes.add()  # type: vmagent_pb2.VMVolume
            vm_volume.is_main = (order == 0)
            vm_volume.order = order
            vm_volume.name = qemu_volume.name
            vm_volume.mount_path = qemu_volume.pod_mount_path
            vm_volume.available_size = qemu_volume.capacity - self.INFRA_BLOCKED_DISK_SPACE
            vm_volume.image_type = qemu_volume.image_type
            vm_volume.resource_url = qemu_volume.resource_url
            vm_volume.vm_mount_path = qemu_volume.vm_mount_path
            vm_volume.req_id = qemu_volume.req_id

        return config

    @property
    def VM(self):
        return self._vm

    @property
    def VM_ID(self):
        return self._vm.meta.id

    @property
    def VM_CONFIG(self):
        return self._vm_config

    @property
    def MAIN_VOLUME(self):
        return self._main_volume

    @property
    def EXTRA_VOLUMES(self):
        return self._extra_volumes

    @property
    def SSH_AUTHORIZED_KEYS(self):
        return self._ssh_authorized_keys

    @property
    def REBUILD_QEMU_LAUNCHER(self):
        """False only when the cloud-init folder exists and the marker file is present in the main storage."""
        if not os.path.exists(self.CLOUD_INIT_CONFIGS_FOLDER_PATH):
            return True
        marker_path = os.path.join(self.MAIN_STORAGE_PATH, self.REBUILD_QEMU_LAUNCHER_FILE_NAME)
        return not os.path.exists(marker_path)

    @property
    def CLOUD_INIT_CONFIGS_FOLDER_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.CLOUD_INIT_CONFIGS_FOLDER_NAME)

    @property
    def CLUSTER(self):
        return self._cluster.upper()

    @property
    def NODE_HOSTNAME(self):
        return self._node_hostname

    @property
    def VM_HOSTNAME(self):
        return self._vm_hostname

    @property
    def VNC_SOCKET_FILE_PATH(self):
        return os.path.join(self.WORKDIR, self.VNC_SOCKET_FILE_NAME)

    @property
    def VMAGENT_PORT(self):  # type: () -> int
        return self._vmagent_port

    @property
    def MON_PORT(self):
        return self.VMAGENT_PORT + 1

    @property
    def SERIAL_PORT(self):
        return self.VMAGENT_PORT + 2

    @property
    def VNC_PORT(self):
        return self.VMAGENT_PORT + 3

    @property
    def VM_IP(self):
        return self._vm_ip

    @property
    def VM_MAC(self):
        return "52:54:00:12:34:56"

    @property
    def VM_AUX_IP(self):
        return self._vm_aux_ip

    @property
    def LLADDR(self):
        """MAC address whose last two octets are the high and low bytes of the vmagent port."""
        port = int(self.VMAGENT_PORT)
        return "52:54:00:00:{}:{}".format(format(port >> 8, '02x'), format(port & 0xff, '02x'))

    @property
    def VMLLADDR(self):
        # NOTE(review): this value has seven octets, unlike VM_MAC -- confirm with consumers before changing
        return '52:54:00:00:12:34:56'

    @property
    def TAP_DEV(self):
        return "tap{}".format(self.VMAGENT_PORT)

    @property
    def TAP_LL(self):
        return mac2ll(self.LLADDR)

    @property
    def MAIN_STORAGE_PATH(self):
        return self._main_volume.mount_path

    @property
    def WORKDIR(self):
        return self._workdir or os.getcwd()

    @property
    def MONITOR_PATH(self):
        return os.path.join(self.WORKDIR, self.MONITOR_SOCKET_FILE_NAME)

    @property
    def CURRENT_CONFIG_FILE_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.CURRENT_CONFIG_FILE_NAME)

    @property
    def LAST_STATUS_FILE_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.LAST_STATUS_FILE_NAME)

    @property
    def EXTRAS_FOLDER_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.EXTRAS_FOLDER_NAME)

    @property
    def LOGS_FOLDER_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.LOGS_FOLDER_NAME)

    @property
    def SERIAL_LOG_FILE_PATH(self):
        return os.path.join(self.LOGS_FOLDER_PATH, self.SERIAL_LOG_FILE_NAME)

    @property
    def VMAGENT_LOG_FILE_PATH(self):
        return os.path.join(self.LOGS_FOLDER_PATH, self.VMAGENT_LOG_FILE_NAME)

    @property
    def WINDOWS_READY(self):
        """Explicit value passed to the constructor, else whether all Cloudbase-Init files exist under /opt."""
        if self._windows_ready is not None:
            return self._windows_ready

        required_files = (
            '/opt/CloudbaseInitSetup_0_9_11_x64.msi',
            '/opt/CloudbaseInitSetup_0_9_11_x86.msi',
            '/opt/cloudbase-init.conf',
            '/opt/cloudbase-init-unattend.conf',
        )
        return all(os.access(path, os.F_OK) for path in required_files)

    @property
    def QEMU_LAUNCHER_FILE_PATH(self):
        return os.path.join(self.MAIN_STORAGE_PATH, self.QEMU_LAUNCHER_FILE_NAME)

    @property
    def QEMU_SYSTEM_CMD_BIN_PATH(self):
        """qemu-system binary path: ISS property override or the default, with env variables expanded."""
        binary_path = self._iss_properties.get('QEMU_SYSTEM_CMD_BIN_PATH', self.DEFAULT_QEMU_SYSTEM_CMD_BIN_PATH)
        return os.path.expandvars(binary_path)

    @property
    def QEMU_IMG_CMD_BIN_PATH(self):
        """qemu-img binary path: ISS property override or the default, with env variables expanded."""
        binary_path = self._iss_properties.get('QEMU_IMG_CMD_BIN_PATH', self.DEFAULT_QEMU_IMG_CMD_BIN_PATH)
        return os.path.expandvars(binary_path)

    @property
    def GPU_BUS_IDS(self):
        return self._gpu_bus_ids

    @property
    def NUMA_NODES(self):
        if self._numa_nodes is None:
            self._numa_nodes = self._get_numa_nodes()
        return self._numa_nodes

    @property
    def VFIO_DEVICES(self):
        if self._vfio_devices is None:
            self._vfio_devices = self._get_vfio_devices()
        return self._vfio_devices

    @property
    def VFIO_NUMA_MAPPING(self):
        if self._vfio_numa_mapping is None:
            self._vfio_numa_mapping = self._get_vfio_numa_mapping()
        return self._vfio_numa_mapping

    @property
    def USE_NUMA(self):
        if self._use_numa is None:
            self._use_numa = self._is_numa_enabled()
        return self._use_numa

    @staticmethod
    def _parse_node_list(text):
        """
        Expand a list such as ``0``, ``0-3`` or ``0-1,4-5`` into a list of ints.

        :type text: str
        :rtype: list[int]
        """
        nodes = []
        for chunk in text.strip().split(','):
            chunk = chunk.strip()
            if not chunk:
                continue
            # A chunk is either a single id ('0') or an inclusive range ('0-3')
            start, _, end = chunk.partition('-')
            nodes.extend(range(int(start), int(end or start) + 1))
        return nodes

    def _get_numa_nodes(self):
        """
        Read the ids of online NUMA nodes from sysfs.

        :rtype: list[int]
        """
        with open('/sys/devices/system/node/online') as numa_nodes_file:
            nodes = numa_nodes_file.read()
        # The file holds a single id on one-node hosts, so a plain split('-') is not enough
        return self._parse_node_list(nodes)

    def _get_vfio_devices(self):
        """
        List the entries of /dev/vfio; an empty list when the directory cannot be listed.

        :rtype: list[str]
        """
        try:
            return os.listdir('/dev/vfio')
        except OSError:
            return []

    def _get_vfio_numa_mapping(self):
        """
        Group device ids of the VFIO iommu groups by the NUMA node reported in sysfs.

        :rtype: dict[int, list[str]]
        """
        vfio_numa_mapping = {}
        for iommu_group in self.VFIO_DEVICES:
            # The 'vfio' entry is skipped; every other entry is used as an iommu group id
            if iommu_group == 'vfio':
                continue
            dev_id = os.listdir('/sys/kernel/iommu_groups/{}/devices'.format(iommu_group))[0]
            with open('/sys/bus/pci/devices/{}/numa_node'.format(dev_id)) as numa_node_file:
                numa_node = int(numa_node_file.read())
            # Drop the part before the first ':' of the device id
            formatted_dev_id = ':'.join(dev_id.split(':')[1:])
            vfio_numa_mapping.setdefault(numa_node, []).append(formatted_dev_id)
        return vfio_numa_mapping

    def _is_numa_enabled(self):
        """
        True when vcpu exceeds MAX_SINGLE_NUMA_CPU or the number of /dev/vfio
        entries reaches MAX_SINGLE_NUMA_VFIO.

        :rtype: bool
        """
        if self._vm_config.vcpu > self.MAX_SINGLE_NUMA_CPU:
            return True
        return len(self.VFIO_DEVICES) >= self.MAX_SINGLE_NUMA_VFIO
