import ctypes
import gi
from classes.stats import qos
from gi.repository import Gst
from classes import logger as log
from pynvml import nvmlInit, nvmlSystemGetDriverVersion, nvmlDeviceGetCount, \
                   nvmlDeviceGetHandleByIndex, nvmlDeviceGetName, \
                   nvmlShutdown, nvmlDeviceGetUtilizationRates


# Values returned by the `encoder_status` property of the GPU classes below.
# Returned by BaseGpu, and by NvidiaGpu when the encoder is disabled or the
# encoder element cannot be found.
ENCODER_STATUS_NOT_SUPPORTED = 'NOT_SUPPORTED'
# Returned by NvidiaGpu when the driver major version is below its minimum.
ENCODER_STATUS_DRIVER_OUTDATED = 'NEED_DRIVER_UPGRADE'
# Returned by NvidiaGpu when the driver is recent enough and the encoder
# element was found.
ENCODER_STATUS_OKAY = 'OKAY'


class GpuUtilization(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance's attribute namespace is the dictionary itself, so
    ``obj.gpu`` and ``obj['gpu']`` read and write the same value.
    """

    def __init__(self, **kwargs):
        dict.__init__(self, **kwargs)
        # Alias the attribute namespace to the mapping so keys double as attributes.
        self.__dict__ = self


class BaseGpu:
    """Vendor-neutral GPU description whose queries all return inert defaults.

    Subclasses override the properties and methods for which they can
    provide real values.
    """

    def __init__(self, vendor):
        """Store the vendor label and record the result of init_vendor_api().

        :param vendor: vendor label kept on the instance as ``vendor``.
        """
        self.vendor = vendor
        api_ready = self.init_vendor_api()
        self.initialized = api_ready
        self.disable_encoder = False

    @property
    def driver_version(self):
        """Driver version string; the base class reports '0.0.0'."""
        return '0.0.0'

    @property
    def device_count(self):
        """Number of devices; the base class reports 0."""
        return 0

    @property
    def encoder_status(self):
        """Encoder status constant; the base class reports not supported."""
        return ENCODER_STATUS_NOT_SUPPORTED

    def init_vendor_api(self):
        """Initialise the vendor API; the base class has none and returns False."""
        return False

    def get_device_by_index(self, index):
        """Return the device handle at `index`; the base class returns None."""
        return None

    def get_device_name_by_index(self, index):
        """Return the device name at `index`; the base class returns ""."""
        return ""

    def get_devices_name(self):
        """Return a list with the name of every device, in index order."""
        names = []
        for device_index in range(self.device_count):
            names.append(self.get_device_name_by_index(device_index))
        return names

    def get_utilization(self):
        """Return a GpuUtilization with gpu and memory both set to 0."""
        return GpuUtilization(gpu=0, memory=0)

    def __str__(self):
        template = "GPU vendor: {}, driver version: {}, encoder status: {}, device count: {}, devices name: {}"
        fields = (self.vendor, self.driver_version, self.encoder_status,
                  self.device_count, self.get_devices_name())
        return template.format(*fields)


class UnknownGpu(BaseGpu):
    """GPU with the vendor label "unknown"; keeps every BaseGpu default."""

    def __init__(self):
        BaseGpu.__init__(self, vendor="unknown")


class AmdGpu(BaseGpu):
    """GPU with the vendor label "amd"; keeps every BaseGpu default."""

    def __init__(self):
        BaseGpu.__init__(self, vendor="amd")


class NvidiaGpu(BaseGpu):
    """NVIDIA GPU information obtained through NVML (the ``pynvml`` bindings).

    Every query is best-effort: when an NVML call fails, the same placeholder
    value that BaseGpu reports is returned instead of raising.
    """

    # Minimum driver major version accepted by `encoder_status`.
    nvenc_minimum_driver_version = 376

    def __init__(self):
        BaseGpu.__init__(self, "nvidia")

    def __del__(self):
        # `initialized` can be missing if construction failed before it was set.
        if not getattr(self, 'initialized', False):
            return
        try:
            nvmlShutdown()
        except Exception:
            # Finalizers must not raise; NVML (or the module globals) may
            # already be gone during interpreter teardown.
            pass

    @staticmethod
    def _to_text(value):
        """Return `value` as text, decoding UTF-8 if it is a bytes object.

        pynvml returns bytes in some releases and str in others; calling
        ``.decode`` unconditionally fails on the latter.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    @property
    def driver_version(self):
        """Driver version string reported by NVML, or "0.0.0" on failure."""
        try:
            return self._to_text(nvmlSystemGetDriverVersion())
        except Exception:
            return "0.0.0"

    @property
    def device_count(self):
        """Number of devices reported by NVML, or 0 on failure."""
        try:
            return nvmlDeviceGetCount()
        except Exception:
            return 0

    @property
    def encoder_status(self):
        """Return one of the ENCODER_STATUS_* constants.

        - not supported when `disable_encoder` is set or the ``d3dnvh264enc``
          GStreamer element factory cannot be found;
        - driver outdated when the driver major version is below
          `nvenc_minimum_driver_version` (an unparsable version counts as 0);
        - okay otherwise.
        """
        if self.disable_encoder:
            return ENCODER_STATUS_NOT_SUPPORTED

        try:
            driver_version_major = int(self.driver_version.split('.')[0])
        except (ValueError, AttributeError):
            # Treat an unparsable version like the "0.0.0" failure placeholder.
            driver_version_major = 0

        # Decide on the driver first; the element lookup is only needed when
        # the driver is recent enough.
        if driver_version_major < self.nvenc_minimum_driver_version:
            return ENCODER_STATUS_DRIVER_OUTDATED

        # checking by nvenc dll found is broken
        # (original note; presumably why the GStreamer element factory is
        # probed instead -- TODO confirm)
        try:
            nvenc_element_found = bool(Gst.ElementFactory.find('d3dnvh264enc'))
        except Exception:
            nvenc_element_found = False

        if nvenc_element_found:
            return ENCODER_STATUS_OKAY
        return ENCODER_STATUS_NOT_SUPPORTED

    def init_vendor_api(self):
        """Initialise NVML; return True on success, False on any failure."""
        try:
            nvmlInit()
            return True
        except Exception:  # most likely, not nvidia
            return False

    def get_device_by_index(self, index):
        """Return the NVML handle for device `index`, or None on failure."""
        try:
            return nvmlDeviceGetHandleByIndex(index)
        except Exception:
            return None

    def get_device_name_by_index(self, index):
        """Return the name of device `index`, or "" if it cannot be obtained."""
        device_handle = self.get_device_by_index(index)
        if device_handle is None:
            return ""
        try:
            return self._to_text(nvmlDeviceGetName(device_handle))
        except Exception:
            return ""

    def get_utilization(self):
        """Return NVML utilization rates for the first usable device.

        Falls back to ``GpuUtilization(gpu=0, memory=0)`` when no device handle
        can be obtained or the NVML query fails.
        """
        try:
            # TODO change up how we send it up to the backend for supporting multi-adapters.
            # this is to filter out bad nvidia devices and take the first available ones
            for i in range(self.device_count):
                device_handle = self.get_device_by_index(i)
                # found good device, query it and stop looking
                if device_handle:
                    return nvmlDeviceGetUtilizationRates(device_handle)
        except Exception:
            # Best effort: fall through to the zeroed placeholder below.
            pass
        return GpuUtilization(gpu=0, memory=0)


def _get_gpu_by_test():
    """Return the first vendor GPU object whose API initialised.

    NvidiaGpu is tried first, then AmdGpu; an UnknownGpu is returned when
    neither reports `initialized`.
    """
    for candidate_cls in (NvidiaGpu, AmdGpu):
        candidate = candidate_cls()
        if candidate.initialized:
            return candidate
    return UnknownGpu()


# Module-level GPU object, selected once when this module is imported.
gpu = _get_gpu_by_test()

