"""
Gathers node resource information (diskman, gpuman, etc.) that can be consumed by YP users.
"""
import logging
import traceback
import json
import platform

import grpc

from infra.diskmanager.proto import diskman_pb2, diskman_pb2_grpc
from infra.rsm.nvgpumanager.api import nvgpu_pb2, nvgpu_pb2_grpc

from infra.ya_salt.lib import pbutil
from infra.ya_salt.lib import subprocutil

from infra.rtc.nodeinfo.lib.modules import cpu
from infra.rtc.nodeinfo.lib.modules import lshw_cpu
from infra.rtc.nodeinfo.lib.modules import mem
from infra.rtc.nodeinfo.lib.modules import oops_disks2
from infra.rtc.nodeinfo.lib.modules import os_info as osi
from infra.rtc.nodeinfo.lib.modules import location_info as loc_info
from infra.rtc.nodeinfo.lib.modules import net
from infra.rtc.nodeinfo.lib.modules import numa
from infra.rtc.nodeinfo.lib.modules import overrides
from infra.rtc.nodeinfo.lib import traceutil

log = logging.getLogger('node_info')

# Timeouts must not be larger than our budget, which is ~30 seconds.
# Each external query below is bounded by its own 10s timeout: diskman and
# rtc-soxaudit are queried in query_node_info, nvgpu-manager only on GPU hosts.
# Run sequentially, the three together can consume the whole ~30s budget.
# diskman gRPC endpoint (unix socket) and ListDisks deadline in seconds.
DM_UNIX_PATH = 'unix:/run/diskman.sock'
DM_TIMEOUT = 10
# nvgpu-manager gRPC endpoint (unix socket) and ListDevices deadline in seconds.
NVGPU_UNIX_PATH = 'unix:/run/nvgpu-manager.sock'
NVGPU_TIMEOUT = 10

# Binary whose stdout is parsed as JSON into node_info.sox_info.
SOXAUDIT_PATH = '/usr/sbin/rtc-soxaudit'
# Passed to subprocutil.check_output; presumably seconds, like the gRPC timeouts.
SOXAUDIT_TIMEOUT = 10


def sox_info(node_info):
    """Run rtc-soxaudit and load its JSON output into node_info.sox_info.

    Returns an error string if the tool fails or its output cannot be parsed.
    Otherwise it returns whatever pbutil.pb_from_dict returns (presumably None
    on success or an error string; confirm against pbutil).
    """
    stdout, stderr, run_status = subprocutil.check_output([SOXAUDIT_PATH], SOXAUDIT_TIMEOUT)
    if run_status.ok:
        try:
            parsed = json.loads(stdout)
            return pbutil.pb_from_dict(node_info.sox_info, parsed)
        except Exception as exc:
            return "failed to parse sox info: {}".format(exc)
    return 'failed to get sox info: {}'.format(stderr)


def mem_info(node_info):
    """Store the total memory reported by mem.mem_info() in node_info.

    Returns None on success or the error reported by mem.mem_info(), in which
    case node_info is left untouched.
    """
    info, problem = mem.mem_info()
    if problem is None:
        node_info.mem_info.total_bytes = info.total_bytes
    return problem


def infer_lshw_cpu_info(node_info, lshw=lshw_cpu.get_cpu_info):
    """Override CPU info with lshw data on aarch64 hosts.

    On any other machine type this is a no-op. On aarch64 it overwrites
    model_name, vendor_id, cpus, cores and sockets of node_info.cpu_info.
    Other fields (e.g. mhz, flags) are left as they are.

    Args:
        node_info: message with a ``cpu_info`` sub-message to fill.
        lshw: callable returning ``(entries, err)``. Each entry is treated as
            one socket with ``model_name``, ``vendor``, ``threads`` and
            ``cores`` attributes.

    Returns:
        None on success or when skipped, otherwise an error string.
    """
    if platform.machine() not in ('aarch64', ):
        return None
    ci, err = lshw()
    if err is not None:
        return err
    # Guard against an empty result: ci[0] below would raise IndexError and
    # abort the whole node info run instead of reporting an error.
    if not ci:
        return 'no cpus found'
    cpu_info_pb = node_info.cpu_info
    # Model and vendor are taken from the first socket; we assume all are equal.
    cpu_info_pb.model_name = ci[0].model_name
    cpu_info_pb.vendor_id = ci[0].vendor
    cpu_info_pb.cpus = sum(i.threads for i in ci)
    cpu_info_pb.cores = sum(i.cores for i in ci)
    cpu_info_pb.sockets = len(ci)
    return None


def cpu_info(node_info):
    """Fill node_info.cpu_info from cpu.cpu_info().

    Each entry returned by cpu.cpu_info() is treated as one logical CPU.
    Model, vendor, frequency and flags come from the first entry (all CPUs
    are assumed to be equal). The cpus/sockets/cores counts are derived from
    the whole list.

    Returns:
        None on success, otherwise an error string.
    """
    ci, err = cpu.cpu_info()
    if err is not None:
        return err
    if not ci:
        return 'no cpus found'
    # Grab first CPU to report, we assume all CPUs are equal
    first = ci[0]
    cpu_info_pb = node_info.cpu_info
    cpu_info_pb.model_name = first.model_name
    cpu_info_pb.vendor_id = first.vendor_id
    cpu_info_pb.mhz = first.mhz
    del cpu_info_pb.flags[:]
    cpu_info_pb.flags.extend(first.flags)
    nsockets = len({c.physical_id for c in ci})
    # core_id is only unique within a physical package, so the same core_id
    # repeats on every socket. Count distinct (socket, core) pairs to get the
    # total core count, matching what the lshw (aarch64) path reports.
    ncores = len({(c.physical_id, c.core_id) for c in ci})
    nthreads = len(ci)
    cpu_info_pb.cpus = nthreads
    cpu_info_pb.sockets = nsockets
    cpu_info_pb.cores = ncores
    return None


def numa_info(node_info):
    """Fill the NUMA node lists of node_info from numa.get_numa_info().

    Always returns None; any exception from numa.get_numa_info() propagates.
    """
    pb_from_numa_info(node_info, numa.get_numa_info())


def pb_from_numa_info(node_info, ni):
    """Rebuild cpu_info.numa_nodes and mem_info.numa_nodes from ``ni.nodes``.

    Both lists are cleared first, then one CPU entry (node id + cpu list) and
    one memory entry (node id + total bytes) are added per NUMA node.
    """
    cpu_nodes = node_info.cpu_info.numa_nodes
    mem_nodes = node_info.mem_info.numa_nodes
    del cpu_nodes[:]
    del mem_nodes[:]
    # Single pass: ni.nodes is not assumed to be re-iterable.
    for numa_node in ni.nodes:
        cpu_entry = cpu_nodes.add()
        cpu_entry.node = numa_node.node
        cpu_entry.cpus.extend(numa_node.cpus)
        mem_entry = mem_nodes.add()
        mem_entry.node = numa_node.node
        mem_entry.total_bytes = numa_node.mem_total


def pb_from_oops_info(node_info, disks_info):
    """Append one node_info.oops_disks entry per disk in ``disks_info``.

    Falsy string attributes are stored as '' and a falsy fs_size as 0. The
    list is not cleared here; callers do that if needed.
    """
    for disk in disks_info:
        entry = node_info.oops_disks.add()
        entry.mountPoint = disk.mount_point or ''
        entry.type = disk.type or ''
        entry.fsSize = disk.fs_size or 0
        entry.hwInfo = disk.hw_info or ''
        entry.slaves.extend(disk.slaves)
        entry.name = disk.name or ''


def oops_disks_info(node_info):
    """Replace node_info.oops_disks with the disks from oops_disks2, sorted by name.

    Returns None on success. Any exception is returned as an error string
    that includes the traceback.
    """
    try:
        found = oops_disks2.oops_disks2()
        ordered = sorted(found, key=lambda disk: disk.name)
        # Drop previously reported records before adding fresh ones.
        del node_info.oops_disks[:]
        pb_from_oops_info(node_info, ordered)
    except Exception as exc:
        return "cannot get oops_disks_info:{}\n{}".format(str(exc), traceback.format_exc())
    else:
        return None


def pb_from_os_info(node_info, os_info):
    """Copy OS details into node_info.os_info, using placeholders for falsy values."""
    target = node_info.os_info
    for field, placeholder in (
            ('version', "unknown os version"),
            ('kernel', "unknown kernel version"),
            ('type', "unknown os type"),
            ('codename', "unknown os codename"),
    ):
        setattr(target, field, getattr(os_info, field) or placeholder)


def os_info(node_info):
    """Fill node_info.os_info from osi.get_os_info().

    Returns None on success. Any exception is returned as an error string
    that includes the traceback.
    """
    try:
        pb_from_os_info(node_info, osi.get_os_info())
    except Exception as exc:
        return "cannot get os_info:{}\n{}".format(str(exc), traceback.format_exc())
    return None


def pb_from_location_info(node_info, location_info):
    """Copy location details into node_info.location_info, using placeholders for falsy values."""
    target = node_info.location_info
    for field, placeholder in (
            ('country', "unknown country"),
            ('city', "unknown city"),
            ('building', "unknown building"),
            ('line', "unknown dc line"),
            ('rack', "unknown rack"),
    ):
        setattr(target, field, getattr(location_info, field) or placeholder)


def location_info(node_info):
    """Fill node_info.location_info from loc_info.get_location_info().

    Returns None on success, or an error string when loading fails or the
    loaded data cannot be copied into the message.
    """
    info, load_err = loc_info.get_location_info()
    if load_err is not None:
        return "failed to load location_info: {}".format(load_err)
    try:
        pb_from_location_info(node_info, info)
    except Exception as exc:
        return 'failed to parse location info: {}'.format(exc)
    else:
        return None


def pb_from_network_info(node_info, net_info):
    """Copy network details from ``net_info`` into node_info.network_info."""
    target = node_info.network_info
    # (message field, source attribute). The prefixes are named mtn_* in the
    # message but bb_/fb_* on the source object.
    for dst_field, src_attr in (
            ('interface_name', 'interface_name'),
            ('bandwidth', 'bandwidth'),
            ('bb_fqdn', 'bb_fqdn'),
            ('fb_fqdn', 'fb_fqdn'),
            ('mtn_prefix', 'bb_prefix'),
            ('mtn_fb_prefix', 'fb_prefix'),
            ('bb_ipv6_addr', 'bb_ipv6_addr'),
            ('fb_ipv6_addr', 'fb_ipv6_addr'),
    ):
        setattr(target, dst_field, getattr(net_info, src_attr))


def network_info(node_info, azure=False):
    """Fill node_info.network_info from net.get_network_info(azure=azure).

    Returns None on success. Any exception is returned as an error string
    that includes the traceback.
    """
    try:
        pb_from_network_info(node_info, net.get_network_info(azure=azure))
    except Exception as exc:
        return "cannot get network_info:{}\n{}".format(str(exc), traceback.format_exc())
    return None


def _disks_info(node_info):
    """Replace node_info.disks with the disks reported by diskman.ListDisks().

    Returns None on success or an error string. On failure node_info.disks is
    left untouched, so previously reported disks are kept.
    """
    try:
        with grpc.insecure_channel(DM_UNIX_PATH) as ch:
            stub = diskman_pb2_grpc.DiskManagerStub(ch)
            resp = stub.ListDisks(diskman_pb2.ListDisksRequest(), timeout=DM_TIMEOUT)
    except Exception as e:
        # Should we delete all disks? Maybe after some timeout?
        return 'diskman.ListDisks() failed: {}'.format(e)
    # For now - gather all disks
    del node_info.disks[:]
    node_info.disks.extend(resp.disks)
    return None


def query_node_info(node_info, azure=False):
    """Fill node_info with disk, memory, CPU, oops-disk, location, network,
    sox and NUMA information, then apply overrides.

    Args:
        node_info: message to fill in place.
        azure: forwarded to network_info.

    Returns:
        None on success, otherwise the error string of the first failing
        step; later steps are skipped. A sox failure is only logged and does
        not stop the run.
    """
    t = traceutil.Trace('node')
    try:
        t.step('diskman')
        err = _disks_info(node_info)
        if err is not None:
            return 'diskman info failed: {}'.format(err)
        t.step('mem')
        err = mem_info(node_info)
        if err is not None:
            return 'mem info failed: {}'.format(err)
        t.step('cpu')
        err = cpu_info(node_info)
        if err is not None:
            return 'cpu info failed: {}'.format(err)
        # Runs after cpu_info so that lshw data, when available, overrides it.
        t.step('lshw_cpu')
        err = infer_lshw_cpu_info(node_info)
        if err is not None:
            return 'lshw cpu info failed: {}'.format(err)
        t.step('oops')
        err = oops_disks_info(node_info)
        if err is not None:
            return err
        t.step('location')
        err = location_info(node_info)
        if err is not None:
            return err
        t.step('network')
        err = network_info(node_info, azure=azure)
        if err is not None:
            return err
        t.step('sox')
        err = sox_info(node_info)
        if err is not None:
            # Best effort: sox info must not fail the whole node info run.
            log.error('Failed to run sox info: {}'.format(err))
        t.step('numa')
        numa_info(node_info)
        t.step('overrides')
        err = overrides.apply_overrides(node_info)
        if err is not None:
            return err
        return None
    finally:
        # Log slow runs on failure paths too. A timed-out external call, such
        # as diskman's 10s deadline, is exactly the case worth seeing in the log.
        t.log_if_long(20000)


def query_os_info(node_info):
    """Fill node_info.os_info; returns None or an error string (see os_info)."""
    err = os_info(node_info)
    return err


def query_gpu_info(gpu_info):
    """Replace gpu_info.devices with the devices listed by nvgpu-manager.

    The device list is cleared before the RPC, so a failed query leaves it
    empty. Returns None on success or an error string.
    NOTE(review): unlike the diskman query, stale entries are not kept on
    failure; confirm this is intended.
    """
    del gpu_info.devices[:]
    try:
        with grpc.insecure_channel(NVGPU_UNIX_PATH) as channel:
            manager = nvgpu_pb2_grpc.NvGpuManagerStub(channel)
            reply = manager.ListDevices(nvgpu_pb2.Empty(), timeout=NVGPU_TIMEOUT)
    except Exception as exc:
        return 'nvgpumanager.ListDevices() failed: {}'.format(exc)
    gpu_info.devices.extend(reply.devices)
    return None


def query_boot_info(node_info, path='/proc/sys/kernel/random/boot_id'):
    """Set node_info.boot_id to the contents of ``path`` with trailing whitespace removed.

    Errors are deliberately not handled: failing to read the boot id raises.
    """
    with open(path) as boot_file:
        raw_boot_id = boot_file.read()
    node_info.boot_id = raw_boot_id.rstrip()


class NodeInfo(object):
    """Collects boot, OS, node and (optionally) GPU information into a node_info message."""

    def __init__(self, query_gpu=False, azure=False):
        # Whether to query nvgpu-manager for GPU devices.
        self._query_gpu = query_gpu
        # Forwarded to query_node_info (and from there to network_info).
        self._azure = azure

    def run(self, node_info):
        """Fill node_info in place and set its ``ok`` condition.

        Errors from the OS, node and GPU queries are collected and joined
        with '; ' into a false ``ok`` condition. A failure to read the boot id
        is not caught and propagates as an exception.
        """
        errs = []
        t = traceutil.Trace('run')
        t.step('boot')
        query_boot_info(node_info)
        t.step('os')
        err = query_os_info(node_info)
        if err is not None:
            errs.append(err)
        # Give the node query its own step. Otherwise its (potentially long)
        # duration is lumped into the 'os' step of this trace.
        t.step('node')
        err = query_node_info(node_info, self._azure)
        if err is not None:
            errs.append(err)
        if self._query_gpu:
            t.step('gpu')
            err = query_gpu_info(node_info.gpu_info)
            if err is not None:
                errs.append(err)
        t.log_if_long(20000)
        if errs:
            pbutil.false_cond(node_info.ok, '; '.join(errs))
        else:
            pbutil.true_cond(node_info.ok)


def gpu_present(walle_tags):
    """Return True if any walle tag marks the host as having a GPU.

    A GPU tag looks like ``rtc.gpu-<model>``; ``rtc.gpu-none`` explicitly
    means "no GPU". With labels, the equivalent selector would be:
        match:
            key: rtc.gpu
            op: NEQ
            value: none
    """
    return any(
        tag.startswith('rtc.gpu-') and tag != 'rtc.gpu-none'
        for tag in walle_tags
    )


def running_in_azure(walle_tags):
    """Return whether "azure" is among the host's walle tags."""
    is_azure = "azure" in walle_tags
    return is_azure


def run_node_info(walle_tags, ni_status):
    """Collect node information into ``ni_status``.

    Optional probes are selected from the host's walle tags: the GPU manager
    is queried only on GPU-tagged hosts, and the "azure" tag is forwarded to
    the network probe.
    """
    log.info('Running node info component...')
    # Do not query the GPU manager where it is not installed. Packages or some
    # other mechanism could tell us that, but walle tags are a reasonable
    # signal for now.
    # NOTE(review): nothing here skips collection while initial host setup
    # after a redeploy is still in progress (when some utilities/services may
    # legitimately be broken). If that is required, the caller must enforce
    # it; confirm.
    collector = NodeInfo(
        query_gpu=gpu_present(walle_tags),
        azure=running_in_azure(walle_tags),
    )
    collector.run(ni_status)
