import json
import logging
import os
import sys
import traceback

from .abstract import PlainModule, Warnings, iso_date, run_command

LOG = logging.getLogger(__name__)


def get_mem_value(value):
    """Convert a memory-size string such as ``'8119 MiB'`` into bytes.

    Recognised suffixes are ``MiB``/``GiB`` (powers of 1024) and ``MB``/``GB``
    (powers of 1000). Whitespace around *value* is ignored.

    :param value: a field value taken from ``nvidia-smi -q`` output.
    :return: the size in bytes as an integer when *value* ends with a
        recognised suffix preceded by a parseable number; otherwise *value*
        with surrounding whitespace removed (non-memory fields such as product
        names pass through unchanged).
    """
    # str.strip() returns a new string; rebind so the whitespace is really gone.
    s = value.strip()
    for suffix, multiplier in (
        ('MiB', 1024 ** 2),
        ('GiB', 1024 ** 3),
        ('MB', 1000 ** 2),
        ('GB', 1000 ** 3),
    ):
        if s.endswith(suffix):
            try:
                return int(float(s[:-len(suffix)]) * multiplier)
            except (ValueError, OverflowError):
                # e.g. 'N/A MiB' (ValueError) or 'inf MiB' (OverflowError):
                # keep the raw text rather than failing the whole parse.
                return s
    return s


def get_cuda_info(warnings):  # type: (Warnings) -> dict
    """Query CUDA device properties through pycuda in a child interpreter.

    The inline probe script is executed with ``/usr/bin/python -c``. It prints
    a JSON object mapping each device index to a dict with ``capability``
    (``'major.minor'``), ``mp_count`` and, when ``/dev/nvidia<N>`` exists,
    ``path``. Running it out of process presumably keeps pycuda import/driver
    failures away from the agent itself.

    :param warnings: collector whose ``log`` method records failures.
    :return: the decoded JSON object (its keys are strings, because JSON
        object keys always are), or ``{}`` when the child exits non-zero or
        any exception occurs.
    """
    probe_script = '''
import pycuda.driver as drv
import json
import os

res = {}
drv.init()
for n in range(drv.Device.count()):
    r = {}
    dev = drv.Device(n)
    cap = '%d.%d' % dev.compute_capability()
    r['capability'] = cap
    r['mp_count'] = dev.get_attribute(drv.device_attribute.MULTIPROCESSOR_COUNT)
    if os.path.exists('/dev/nvidia%s' % n):
        r['path'] = '/dev/nvidia%s' % n
    res[n] = r
print(json.dumps(res))
'''

    try:
        proc = run_command(['/usr/bin/python', '-c', probe_script])
        if proc.returncode == 0:
            return json.loads(proc.out)
        warnings.log("get_cuda_info failed with: %s", proc.returncode)
    except Exception:
        # Best-effort probe: record the traceback and report "no CUDA info".
        warnings.log("cuda information get failed: %s", traceback.format_exc())
    return {}


class AgentModule(PlainModule):
    """Inventory module reporting NVIDIA GPU adapters.

    Adapter details are parsed from ``nvidia-smi -q`` output and, where a
    device's minor number matches, enriched with data from ``get_cuda_info``.
    """
    util = '/usr/bin/nvidia-smi'
    key = 'adapters'

    # Maps dotted "<section path>.<field>" names built by parse_gpu to output
    # keys. parse_gpu does not reset the section path after a "name : value"
    # line, so fields following a sub-section keep that sub-section's prefix.
    # That is presumably why keys such as 'Driver Model.Serial Number' carry it.
    good_keys = {
        'Product Name': 'name',
        'Product Brand': 'brand',
        'Driver Model.Serial Number': 'driverSn',
        'Driver Model.GPU UUID': 'driverUUID',
        'Driver Model.Minor Number': 'driverMinorNumber',
        'Driver Model.VBIOS Version': 'driverVbiosVersion',
        'FB Memory Usage.Total': 'memorySize',
    }

    def get_value(self, test_text=None, gpu_test=None):
        """Collect GPU adapter information.

        :param test_text: optional pre-captured ``nvidia-smi -q`` output lines;
            when given, the utility is not executed.
        :param gpu_test: optional CUDA info mapping (device index -> dict with
            'capability', 'path', 'mp_count'); when truthy it is used instead
            of calling ``get_cuda_info``.
        :return: ``self.format_answer('gpu', result)``, where ``result`` holds
            the adapter list and a ``changeTime``, or ``None`` when no adapter
            data could be obtained.
        """
        result = {}
        if test_text:
            result[self.key] = self.parse_gpu(test_text)
        elif os.path.isfile(self.util):
            # Pass an argument list, matching the call style of get_cuda_info.
            out = run_command([self.util, '-q']).out
            result[self.key] = self.parse_gpu(out.strip().split('\n'))

        # Without nvidia-smi there is no adapter list. Merge into an empty one
        # instead of raising KeyError when the CUDA probe still returns data.
        adapters = result.get(self.key, [])
        cuda_info = gpu_test if gpu_test else get_cuda_info(self.warnings)
        for k in cuda_info or ():
            # JSON-decoded keys are strings; gpu_test may use ints.
            minor = str(k)
            for adapter in adapters:
                if adapter.get('driverMinorNumber') == minor:
                    info = cuda_info[k]
                    adapter['cudaVersion'] = info.get('capability')
                    adapter['devicePath'] = info.get('path')
                    adapter['multiprocessorCount'] = info.get('mp_count')
                    break

        if not result:
            return None
        result['changeTime'] = iso_date()
        return self.format_answer('gpu', result)

    def parse_gpu(self, lines):
        """Parse ``nvidia-smi -q`` output lines into a list of adapter dicts.

        A line beginning with ``GPU`` (no indentation) starts a new adapter
        record, whose ``busId`` is the line's second space-separated token.
        Other non-blank lines are handled as follows:

        - Lines without ``:`` are section headers. Consecutive headers are
          joined with ``.``, and the first header after a ``name : value``
          line starts a fresh path.
        - ``name : value`` lines are looked up as ``"<path>.<name>"`` (or just
          ``name`` when no header has been seen since the ``GPU`` line) in
          ``good_keys``. Matches are stored after ``get_mem_value`` conversion.

        :param lines: iterable of output lines.
        :return: list of dicts, one per ``GPU`` line encountered.
        """
        res = []
        gpu = {}
        new_key = True  # the next header line starts a fresh section path
        key = ''        # current dotted section path
        for line in lines:
            if not line.strip():
                continue
            if line.startswith('GPU'):
                # New adapter block: flush the previous one.
                if gpu:
                    res.append(gpu)
                    gpu = {}
                gpu['busId'] = line.split(' ', 2)[1]
                new_key = True
                key = ''
                continue
            if ':' in line:
                name, value = [x.strip() for x in line.split(':', 1)]
                path = key + '.' + name if key else name
                if path in self.good_keys:
                    gpu[self.good_keys[path]] = get_mem_value(value)
                new_key = True
            else:
                if new_key:
                    key = ''
                    new_key = False
                key = key + '.' + line.strip() if key else line.strip()
        if gpu:
            res.append(gpu)
        return res


if __name__ == '__main__':
    logging.basicConfig(level='INFO')
    # A single parenthesised argument prints identically under Python 2 and is
    # required syntax under Python 3 (the bare print statement is a SyntaxError
    # there, which made the whole module unimportable).
    print(json.dumps(AgentModule(sys.platform).get_value(), indent=4))
