import collections
import logging

import os

log = logging.getLogger(__name__)

# mount points always probed by get_mountpoint_numbers()
WELL_KNOWN_MOUNTPOINTS = ['/', '/home', '/place', '/ssd']
# directory whose non-pmem entries are probed as additional mount points
YT_PREFIX = '/yt'
# One device entry read from sysfs by read_device():
#   major, minor  - device numbers from the sysfs 'dev' file
#   name          - sysfs entry name
#   model_serial  - hw identifier derived from udev properties, or None
#   slaves        - names listed in the device's sysfs 'slaves' directory
#   parent        - DevInfo of the enclosing device for nested entries, else None
#   type          - 'HDD' / 'SSD' / 'NVME', or None when unknown
DevInfo = collections.namedtuple(
    'DevInfo',
    ['major', 'minor', 'name', 'model_serial', 'slaves', 'parent', 'type'],
)


class OOPSDiskInfo(object):
    """One disk/filesystem entry in the format of the legacy oops_disks collector.

    Attributes (as filled by oops_disks2 in this module):
        type: 'HDD', 'SSD', 'NVME' or None when unknown
        hw_info: model/serial (or md/dm name) of the device, or None
        mount_point: path the device is mounted at; None for non-mounted devices
        slaves: names of the underlying devices (empty list by default)
        fs_size: filesystem size in bytes (f_blocks * f_frsize); None when not mounted
        name: device name as found in sysfs

    Equality is field-by-field, with compatibility relaxations (see __eq__)
    that avoid spurious diffs against oops_disks output.
    """

    def __init__(self, type=None, hw_info=None, mount_point=None, slaves=None, fs_size=None, name=None):
        self.type = type
        self.hw_info = hw_info
        self.mount_point = mount_point
        # fresh list per instance; never share a mutable default
        self.slaves = slaves if slaves is not None else []
        self.fs_size = fs_size
        self.name = name

    def __str__(self):
        return 'OOPSDiskInfo(type: {}, hw_info: {}, mount_point: {}, slaves: {}, fs_size: {}, name: {})'.format(
            self.type, self.hw_info, self.mount_point, self.slaves, self.fs_size, self.name)

    def __repr__(self):
        return self.__str__()

    # NOTE(review): no __hash__ is defined. On Python 3, defining __eq__ makes
    # instances unhashable; on Python 2 they keep identity-based hashing.
    def __eq__(self, other):
        """Compare two entries while tolerating known oops_disks quirks.

        - hw_info is ignored when either side is 'pv', to avoid diffs on hosts
          with LVM volumes.
        - type is ignored when either name starts with 'sdm', because oops_disks
          mishandles sdm devices.
        - A None name is treated as '' for the 'sdm' check instead of raising.

        Returns NotImplemented for non-OOPSDiskInfo operands, so Python falls
        back to its default comparison instead of raising AttributeError.
        """
        if not isinstance(other, OOPSDiskInfo):
            return NotImplemented
        if 'pv' not in (self.hw_info, other.hw_info) and self.hw_info != other.hw_info:
            return False
        # TODO: remove dirty hack together with oops_disks
        # dirty hack for buggy oops_disks sdm device handling
        names = (self.name or '', other.name or '')
        if not any(n.startswith('sdm') for n in names) and self.type != other.type:
            return False
        return (self.mount_point == other.mount_point and self.slaves == other.slaves
                and self.fs_size == other.fs_size and self.name == other.name)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != compared identity
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq


def get_maj_min(path):
    """Return the (major, minor) numbers of the device that holds ``path``.

    ``os.lstat`` is used deliberately: a symlink is attributed to the
    filesystem containing the link itself, not to the link's target.
    """
    device_id = os.lstat(path).st_dev  # do not follow symlinks
    return (os.major(device_id), os.minor(device_id))


def get_mountpoint_numbers(listdir=os.listdir, maj_min=get_maj_min, isdir=os.path.isdir):
    """Map (major, minor) device numbers to the shortest candidate path on that device.

    Candidates are WELL_KNOWN_MOUNTPOINTS, plus every entry of YT_PREFIX whose
    name does not start with 'pmem' (only when YT_PREFIX is a directory).
    Candidates that are not directories are ignored.

    When several candidates resolve to the same device, the shortest path wins.
    If / and /ssd are on the same fs, the whole fs is reported at /, not /ssd.
    On equal length, the earlier candidate is kept.
    """
    candidates = list(WELL_KNOWN_MOUNTPOINTS)
    if isdir(YT_PREFIX):
        # try discover all yt mounts, except pmem
        candidates += [os.path.join(YT_PREFIX, entry) for entry in listdir(YT_PREFIX)
                       if not entry.startswith('pmem')]

    by_device = {}
    for candidate in candidates:
        if not isdir(candidate):
            continue
        key = maj_min(candidate)
        current = by_device.get(key)
        if current is None or len(candidate) < len(current):
            by_device[key] = candidate
    return by_device


def read_udev_info(major, minor, udev_prefix='/run/udev/data', open_func=open):
    """Read the udev property records (``E:KEY=VALUE`` lines) of block device ``major:minor``.

    :param major: device major number
    :param minor: device minor number
    :param udev_prefix: directory holding the ``b<major>:<minor>`` database files
    :param open_func: injectable replacement for :func:`open`
    :return: dict mapping property name to value. The line is stripped before
        splitting at the first '='; records without '=' are ignored.
    :raises IOError/OSError: if the database file cannot be opened
    """
    rv = {}
    path = '{}/b{}:{}'.format(udev_prefix, major, minor)
    with open_func(path, 'r') as f:
        for line in f:
            if not line.startswith('E:'):
                continue
            key, sep, value = line[2:].strip().partition('=')
            if not sep:
                # malformed property record; skip it rather than aborting the
                # whole device scan with a ValueError
                continue
            rv[key] = value
    return rv


def guess_md_parent_name(slave):
    """Guess the whole-disk name for an md member name.

    * ``sd*`` / ``vd*`` names are cut at the first digit (``sda1`` -> ``sda``).
    * Other names follow the ``<parent>p<N>`` scheme. They are cut at the
      right-most 'p' that is followed by a digit, never at index 0
      (``nvme0n1p2`` -> ``nvme0n1``).

    If no split point is found, the name is returned unchanged.
    """
    if slave.startswith(('sd', 'vd')):
        cut = next((pos for pos, ch in enumerate(slave) if ch.isdigit()), len(slave))
        return slave[:cut]

    # <parent>pN scheme: scan right to left; index 0 and the last char are never a cut point
    for pos in range(len(slave) - 2, 0, -1):
        if slave[pos] == 'p' and slave[pos + 1].isdigit():
            return slave[:pos]
    return slave


def _read_rotational_type(name, dev_path, path, slaves, open_func):
    """Return 'HDD', 'SSD' or 'NVME' for a device, or None if its rotational flag is unreadable."""
    dev_type_path = os.path.join(dev_path, 'queue', 'rotational')
    # https://github.com/torvalds/linux/blob/v5.11/drivers/md/md.c#L6020
    # md devices has unstable rotational flag, so we need to guess rotational by first slave's type
    # NOTE(review): only applied when the md device has more than one slave;
    # a single-slave md still uses its own flag -- confirm this is intended
    if name.startswith('md') and len(slaves) > 1:
        first_slave_parent = guess_md_parent_name(slaves[0])
        dev_type_path = os.path.join(path, first_slave_parent, 'queue', 'rotational')
    try:
        with open_func(dev_type_path) as f:
            rotational = int(f.read().strip())
    except (IOError, OSError, ValueError):
        # missing or unparsable flag: type unknown (get_dev_type may use the parent's)
        return None
    if rotational == 1:
        return 'HDD'
    # guess NVME type based on dev name or slave name
    if 'nvme' in dev_path or any('nvme' in s for s in slaves):
        return 'NVME'
    return 'SSD'


def _hw_info_from_udev(udev):
    """Derive the oops_disks-compatible hw identifier from udev properties, or None."""
    # oops_disks uses md device name as hw_info in case of md device
    # and device model with serial in case of physical device
    if 'ID_SERIAL' in udev:
        # preserve oops_disks bug in get_disks_hw_info() parsing symlink names
        # (everything after the first '-' is dropped); spaces become
        # underscores to match oops_disks format
        return udev['ID_SERIAL'].split('-', 1)[0].replace(' ', '_')
    for key in ('MD_NAME', 'MD_UUID', 'DM_NAME'):
        if key in udev:
            return udev[key]
    return None


def read_device(name, path='/sys/block', parent=None, open_func=open, listdir=os.listdir, isdir=os.path.isdir):
    """Read a block device and, recursively, its nested entries from sysfs.

    :param name: entry name under ``path``
    :param path: sysfs directory containing ``name``
    :param parent: DevInfo of the enclosing device for nested entries, else None
    :param open_func/listdir/isdir: injectable filesystem accessors
    :return: list of DevInfo. The device itself comes first, followed by the
        results for every entry of ``<path>/<name>`` whose name starts with
        ``name`` (presumably partitions), each with ``parent`` set to the
        device's DevInfo. Returns an empty list for pmem devices.
    :raises IOError/OSError: if the sysfs ``dev`` file cannot be read
    """
    # skip pmem devices for compatibility with oops_disks; checked before any
    # I/O so a pmem device with unreadable sysfs/udev data cannot fail the scan
    if name.startswith('pmem'):
        return []

    dev_path = os.path.join(path, name)
    with open_func(os.path.join(dev_path, 'dev')) as f:
        major, minor = map(int, f.read().strip().split(':'))
    slaves_path = os.path.join(dev_path, 'slaves')
    slaves = listdir(slaves_path) if isdir(slaves_path) else []

    try:
        udev = read_udev_info(major, minor, open_func=open_func)
    except (IOError, OSError) as e:
        # a device without a udev database entry should not abort enumeration
        # of every other device; it just gets no hw_info
        log.warning('failed to read udev info for %s (%d:%d): %s', name, major, minor, e)
        udev = {}

    dev_type = _read_rotational_type(name, dev_path, path, slaves, open_func)
    hw_info = _hw_info_from_udev(udev)

    rv = [DevInfo(major, minor, name, hw_info, slaves, parent, dev_type)]
    for child in listdir(dev_path):
        if child.startswith(name):
            rv.extend(read_device(child, dev_path, rv[0], open_func=open_func, listdir=listdir, isdir=isdir))
    return rv


def enumerate_devices(listdir=os.listdir, rd=read_device, isfile=os.path.isfile):
    """Collect DevInfo entries for the block devices in /sys/block, keyed by (major, minor).

    Entries without a ``dev`` file and loop/nbd devices are not read at all.
    Any returned DevInfo whose model_serial mentions 'IPMI' is dropped.
    """
    sys_block = '/sys/block'

    def _wanted(entry):
        # skip devices without dev file containing maj:min
        # skip loop and nbd devices
        return isfile(os.path.join(sys_block, entry, 'dev')) and not entry.startswith(('loop', 'nbd'))

    def _not_ipmi(info):
        # skip IPMI devices
        return not (info.model_serial and 'IPMI' in info.model_serial)

    devices = {}
    for entry in (e for e in listdir(sys_block) if _wanted(e)):
        devices.update(((info.major, info.minor), info) for info in rd(entry) if _not_ipmi(info))
    return devices


def get_dev_type(dev_info):
    """Return the first non-empty ``type`` found walking from ``dev_info`` up its ``parent`` chain.

    An entry whose own type is unknown (e.g. read_device could not read its
    rotational flag) thus inherits the type of its parent. Returns None when
    no entry in the chain has a type, or when ``dev_info`` itself is falsy.
    """
    if not dev_info:
        return None
    return dev_info.type or get_dev_type(dev_info.parent)


# add non-mounted devices from /dev/disk/by-uuid for compatibility with oops_disks
def mix_in_additional_devices(existing, additional, path='/dev/disk/by-uuid', listdir=os.listdir,
                              readlink=os.readlink, isdir=os.path.isdir):
    """Add to ``additional`` every device linked from ``path`` that is not already in ``existing``.

    :param existing: iterable of objects with a ``name`` attribute (OOPSDiskInfo here)
    :param additional: set of device names; updated in place
    :param path: directory of symlinks whose target basenames are device names
    :param listdir/readlink/isdir: injectable filesystem accessors
    """
    if not isdir(path):
        # nothing to add when the directory does not exist (it may be missing,
        # e.g. in containers or without udev) -- previously this raised
        return
    known = {d.name for d in existing}
    for link in listdir(path):
        link_path = os.path.join(path, link)
        try:
            target = readlink(link_path)
        except OSError as e:
            # the link may vanish between listdir and readlink, or not be a symlink
            log.warning('failed to resolve %s: %s', link_path, e)
            continue
        dev = os.path.basename(target)
        if dev not in known:
            additional.add(dev)


def _to_disk_info(dev, mount_point=None, fs_size=None):
    """Build an OOPSDiskInfo from a DevInfo, inheriting the type from parents when unknown."""
    return OOPSDiskInfo(
        type=get_dev_type(dev),
        hw_info=dev.model_serial,
        mount_point=mount_point,
        slaves=dev.slaves,
        fs_size=fs_size,
        name=dev.name,
    )


def oops_disks2(statvfs=os.statvfs, mounts=get_mountpoint_numbers, devs=enumerate_devices,
                add_devs=mix_in_additional_devices):
    """Describe mounted filesystems and their related devices in oops_disks format.

    The list contains one entry per mount returned by ``mounts()`` whose device
    is known to ``devs()``; fs_size is f_blocks * f_frsize from ``statvfs``.
    It is followed by non-mounted entries (mount_point and fs_size None) for:
    the slaves of those devices, the devices ``add_devs`` adds, and their
    slaves, transitively. Each name is added at most once.

    :param statvfs/mounts/devs/add_devs: injectable collaborators
    :return: list of OOPSDiskInfo
    :raises OSError: if ``statvfs`` fails for a mount point
    """
    rv = []
    mount_by_dev = mounts()
    devices = devs()
    additional_devices = set()
    for (major, minor), mount in mount_by_dev.items():
        dev = devices.get((major, minor))
        if dev is None:
            log.warning('failed to get device for mount: %s', mount)
            continue
        vfs = statvfs(mount)
        rv.append(_to_disk_info(dev, mount_point=mount, fs_size=vfs.f_blocks * vfs.f_frsize))
        additional_devices.update(dev.slaves)

    add_devs(rv, additional_devices)

    dev_by_name = {d.name: d for d in devices.values()}
    devices_seen = set()
    while additional_devices:
        dev_name = additional_devices.pop()
        if dev_name in devices_seen:
            continue
        devices_seen.add(dev_name)
        dev = dev_by_name.get(dev_name)
        if dev:
            # also add slaves of additional devices
            additional_devices.update(dev.slaves)
            rv.append(_to_disk_info(dev))
    return rv
