import errno
import logging
import os
import subprocess

from infra.diskmanager.lib import consts
from infra.diskmanager.lib import mount

# Module logger used by the helpers and classes in this file.
log = logging.getLogger('diskmanager.lib.disk')


def guess_md_slave_name(slave):
    """
    "Guess" the parent disk name of an md member by dropping its partition suffix.

    Examples:
    sda -> sda
    sda1 -> sda
    nvme0n1p1 -> nvme0n1
    nvme0n1 -> nvme0n1
    nvme0c33n1p2 -> nvme0c33n1
    nvme0c33n1 -> nvme0c33n1
    """
    if slave[:2] in ('sd', 'vd'):
        # sd?N / vd?N: the disk name ends right before the first digit.
        first_digit = next((pos for pos, ch in enumerate(slave) if ch.isdigit()), None)
        return slave if first_digit is None else slave[:first_digit]

    # <parent>pN scheme: scan backwards for a 'p' immediately followed by a digit.
    # Index 0 is never treated as the separator (the parent name must be non-empty).
    pos = len(slave) - 2
    while pos > 0:
        if slave[pos] == 'p' and slave[pos + 1].isdigit():
            return slave[:pos]
        pos -= 1
    return slave


def get_md_slaves(mdname):
    """
    Return the entries of /sys/block/<mdname>/slaves (the md member device names),
    in os.listdir() order.
    """
    # According to linux/block/genhd.c:register_disk() the slaves dir is always
    # present, so no existence check is done; a missing dir raises OSError.
    return os.listdir(os.path.join('/sys/block', mdname, 'slaves'))


def get_md_storage_class(mdname):
    """
    Get storage class of md device by analyzing md slave devices.
    Same implementation as in infra.rtc.nodeinfo.lib.modules.oops_disks2.

    - no members                  -> STORAGE_VIRT
    - any member name has 'nvme'  -> STORAGE_NVME
    - otherwise the first member's parent disk queue/rotational decides:
      0 -> STORAGE_SSD, 1 -> STORAGE_HDD, missing/other -> STORAGE_VIRT
    """
    members = get_md_slaves(mdname)
    if not members:
        return consts.STORAGE_VIRT

    parent = guess_md_slave_name(members[0])
    if 'nvme' in parent or any('nvme' in member for member in members):
        return consts.STORAGE_NVME

    raw = safe_read_sysfs_file('/sys/block/{}/queue/rotational'.format(parent))
    if raw is None:
        return consts.STORAGE_VIRT
    by_rotational = {0: consts.STORAGE_SSD, 1: consts.STORAGE_HDD}
    return by_rotational.get(int(raw), consts.STORAGE_VIRT)


def read_sysfs_file(path, fname=None):
    """
    Read up to 8192 bytes from a sysfs attribute and return them stripped.

    `fname`, when given, is joined onto `path`. I/O errors (e.g. a missing
    file) propagate to the caller; see safe_read_sysfs_file() for a tolerant variant.
    """
    target = path if fname is None else os.path.join(path, fname)
    with open(target, 'rb') as attr:
        payload = attr.read(8192)
    return payload.strip()


def safe_read_sysfs_file(path, fname=None, default=None):
    """
    Like read_sysfs_file(), but return `default` when the attribute does not exist.

    The file is opened directly instead of being checked with os.path.exists()
    first, so a device disappearing between the check and the open (hot-unplug)
    yields `default` instead of an exception. Other I/O errors (e.g. EACCES,
    EISDIR) are still raised.

    :param path: sysfs path (or directory when `fname` is given)
    :param fname: optional attribute name joined onto `path`
    :param default: value returned when the attribute is missing
    :return: up to 8192 bytes of the attribute, stripped, or `default`
    """
    if fname is not None:
        path = os.path.join(path, fname)
    try:
        f = open(path, 'rb')
    except (IOError, OSError) as e:
        # ENOENT/ENOTDIR are exactly the cases os.path.exists() reported as missing.
        if e.errno in (errno.ENOENT, errno.ENOTDIR):
            return default
        raise
    with f:
        return f.read(8192).strip()


def write_sysfs_file(data, path, fname=None):
    """
    Write `data` to a sysfs attribute (opened in binary write mode).

    `fname`, when given, is joined onto `path`. The write is logged at debug
    level in `echo data > path` form. Returns whatever file.write() returns.
    """
    target = path if fname is None else os.path.join(path, fname)
    log.debug('echo %s > %s' % (data, target))
    with open(target, 'wb') as attr:
        return attr.write(data)


def _check_nvme_multipath_enabled():
    """
    Return True when the nvme_core `multipath` module parameter reads 'Y'.

    Any failure to read the parameter is logged (with traceback) and treated
    as "disabled".
    """
    multipath_on = False
    try:
        with open('/sys/module/nvme_core/parameters/multipath', 'r') as param:
            multipath_on = (param.read(1) == 'Y')
    except Exception as err:
        log.exception("Can not read nvme multipath state: %s" % err)
    return multipath_on


class Partition(object):
    """
    Snapshot of one partition, or of a whole partitionless device when
    `virtual` is True, built from /sys/class/block/<name> and the udev database.
    """

    def __init__(self, name, virtual=False):
        sys_path = os.path.join(Disk._SYS_CLASS_BLOCK, name)
        major, minor = Disk.dev_by_syspath(sys_path)
        info = Disk.get_udev_info(major, minor)
        self._udev_info = info

        self.name = name
        self.device_path = os.path.join(Disk._DEVFS, name)
        self.usec_init = long(info.get('USEC_INITIALIZED', '0'))

        # Partition-table entry properties
        self.type = info.get('ID_PART_ENTRY_TYPE', '')
        self.uuid = info.get('ID_PART_ENTRY_UUID', '')
        # Filesystem properties
        self.fs_label = info.get('ID_FS_LABEL', '')
        self.fs_uuid = info.get('ID_FS_UUID', '')
        self.fs_type = info.get('ID_FS_TYPE', '')
        self.fs_owner = info.get('ID_FS_USAGE', '').lower()

        # Only devices carrying a filesystem are looked up in the mount table
        self.mnt_path = (mount.find_mount_path(major, minor)
                         if self.fs_owner == 'filesystem' else None)

        if not virtual:
            # sysfs 'start' is multiplied by 512 to get a byte offset
            self.start = long(read_sysfs_file(sys_path, 'start')) * 512
        else:
            # A virtual partition starts at offset 0 and has no partition-table
            # entry, so the filesystem UUID doubles as its uuid.
            self.start = 0
            self.uuid = self.fs_uuid
        # sysfs 'size' is multiplied by 512 to get bytes
        self.size = long(read_sysfs_file(sys_path, 'size')) * 512


class IOScheduler(object):
    """
    Block-device I/O scheduler: its name (as in <disk>/queue/scheduler) and
    its tunables (files under <disk>/queue/iosched/).
    """

    # Tunables read back by query() for each scheduler; other schedulers get none.
    _KNOWN_OPTIONS = {
        'cfq': ('group_idle',),
        'kyber': ('read_lat_nsec', 'write_lat_nsec'),
    }

    def __init__(self, name, options=None):
        """
        :param name: scheduler name, e.g. 'none', 'kyber'
        :param options: optional mapping of tunable name -> value; it is copied
            so the instance never aliases the caller's dict
        """
        self.name = name
        self.options = dict(options or {})

    @staticmethod
    def parse_scheduler(schedulers):
        """
        Return the active scheduler from the contents of queue/scheduler.

        The kernel brackets the active entry: "mq-deadline [kyber] none" -> "kyber".
        A lone entry is returned with its brackets removed when present
        ("[none]" -> "none", "none" -> "none"). Returns 'none' if several entries
        are listed but none is marked active.
        """
        entries = schedulers.split()
        for entry in entries:
            if entry.startswith('[') and entry.endswith(']'):
                return entry[1:-1]
        if len(entries) == 1:
            return entries[0]
        return 'none'

    def query(self, disk_path):
        """
        Load the active scheduler and its known tunables from `disk_path` in sysfs.

        The options dict is rebuilt from scratch so tunables of a previously
        queried scheduler (e.g. cfq's group_idle) do not linger after a switch.
        """
        schedulers = read_sysfs_file(disk_path, 'queue/scheduler')
        self.name = IOScheduler.parse_scheduler(schedulers)
        s_path = os.path.join(disk_path, 'queue/iosched')
        options = {}
        for opt in IOScheduler._KNOWN_OPTIONS.get(self.name, ()):
            options[opt] = read_sysfs_file(s_path, opt)
        self.options = options

    def equal(self, s):
        """Return True when `s` has the same name and exactly the same options."""
        return self.name == s.name and self.options == s.options

    def copy(self):
        """Return an independent copy (options dict is duplicated)."""
        return IOScheduler(self.name, dict(self.options))

    def apply(self, disk_path):
        """Write the scheduler name, then each option, into sysfs under `disk_path`."""
        write_sysfs_file(self.name, disk_path, 'queue/scheduler')
        s_path = os.path.join(disk_path, 'queue/iosched')
        for k, v in self.options.items():
            write_sysfs_file(v, s_path, k)


class Disk(object):
    """
    Snapshot of a whole block device under /sys/block (identity, size,
    partitions, I/O scheduler, storage class), plus static helpers that run
    filesystem / partitioning tools via subprocess.
    """
    _DEVFS = '/dev'
    _SYS_BLOCK = '/sys/block'
    _SYS_CLASS_BLOCK = '/sys/class/block'
    # GPT partition type GUIDs by symbolic name; add_gpt_part() passes them to `sgdisk -t`
    PART_TYPES = {
        'bios': '21686148-6449-6E6F-744E-656564454649',
        'linux_swap': '0657FD6D-A4AB-43C4-84E5-0933C84B4F4F',
        'linux_filesystem': '0FC63DAF-8483-4772-8E79-3D69D8477DE4',
        'linux_server data': '3B8F8425-20E0-4F3B-907F-1A25A76F98E8',
        'linux_root_x86': '44479540-F297-41B2-9AF7-D131D5F0458A',
        'linux_root_x8664': '4F68BCE3-E8CD-4DB1-96E7-FBCAF984B709',
        'linux_lvm': 'E6D6D379-F507-44C2-A23C-238F2A3DF928'
    }

    def __init__(self, name):
        """
        Collect everything known about /sys/block/<name>.

        Exceptions from the partition / rotational / scheduler scan are logged
        and re-raised with their original traceback; other sysfs read errors
        propagate directly.
        """
        d_path = os.path.join(Disk._SYS_BLOCK, name)
        self.major, self.minor = Disk.dev_by_syspath(d_path)
        self._udev_info = Disk.get_udev_info(self.major, self.minor)
        self.name = name
        self.sysfs_path = d_path
        self.device_path = os.path.join(Disk._DEVFS, name)
        self.usec_init = long(self._udev_info.get('USEC_INITIALIZED', '0'))
        self.serial = self._udev_info.get('ID_SERIAL', '')
        self.wwn = self._udev_info.get('ID_WWN', '')
        self.model = self._udev_info.get('ID_MODEL', 'Unknown')
        self.wwid = safe_read_sysfs_file(d_path, 'wwid')
        # Use uuid only if nguid is not available, to prevent kernel's warning spam
        self.hw_uuid = safe_read_sysfs_file(d_path, 'nguid')
        if not self.hw_uuid:
            self.hw_uuid = safe_read_sysfs_file(d_path, 'uuid')
        self.type = self._udev_info.get('ID_TYPE', 'Unknown')
        self.bus = self._udev_info.get('ID_BUS', 'Unknown')
        self.fs_type = self._udev_info.get('ID_FS_TYPE', '')
        self.fs_owner = self._udev_info.get('ID_FS_USAGE', '').lower()
        self.part_type = self._udev_info.get('ID_PART_TABLE_TYPE', '')
        self.uuid = self._udev_info.get('ID_PART_TABLE_UUID', '')
        # sysfs 'size' is multiplied by 512 to get bytes
        size = read_sysfs_file(d_path, 'size')
        self.size = long(size) * 512
        self.scheduler = IOScheduler("none")
        # Number of LVM physical volumes: the whole disk and/or its partitions
        self.pv_count = 0
        if self.fs_type == 'lvm2_member':
            self.pv_count = 1

        if name.startswith('nvme'):
            self._update_nvme_multipath_names()
        self.parts = []
        try:
            for p_name in Disk.list_disk_parts(name):
                # list_disk_parts() reports a loop device as its own partition;
                # such an entry has no 'start' attribute, hence virtual.
                is_virtual = p_name.startswith("loop")
                p = Partition(p_name, is_virtual)
                self.parts.append(p)
                if p.fs_type == 'lvm2_member':
                    self.pv_count = self.pv_count + 1
            # Partitionless disk
            if self.fs_owner == 'filesystem':
                self.virt_part = Partition(self.name, virtual=True)
            else:
                self.virt_part = None

            rot_val = read_sysfs_file(d_path, 'queue/rotational')
            if rot_val == '1':
                self._rotational = True
            else:
                self._rotational = False
            # For nvme multipath, the scheduler is queried on the per-path devices
            if name.startswith('nvme') and self.nvme_multipath_names:
                for nvme_mp_name in self.nvme_multipath_names:
                    self.scheduler.query(os.path.join(Disk._SYS_BLOCK, nvme_mp_name))
            else:
                self.scheduler.query(self.sysfs_path)

        except Exception as e:
            log.error("Disk %s. Exception: %s" % (name, str(e)))
            # Bare raise keeps the original traceback (`raise e` resets it in Python 2)
            raise

        self.discard_support = False
        if name.startswith('nvme'):
            self.storage_class = consts.STORAGE_NVME
            self.discard_support = True

        elif name.startswith('sd'):
            if not self._rotational:
                self.discard_support = True
                self.storage_class = consts.STORAGE_SSD
            else:
                self.storage_class = consts.STORAGE_HDD
        elif name.startswith('md'):
            self.storage_class = get_md_storage_class(name)
            if self.storage_class in consts.STORAGE_NON_ROTATIONAL:
                self.discard_support = True
            else:
                self.discard_support = False
        else:
            self.discard_support = False
            self.storage_class = consts.STORAGE_VIRT

        self.infer_id()

    # Sane device has valid wwid, wwn or serial number, otherwise generate semi-unique one
    def infer_id(self):
        """
        Set self.id. Priority: nvme device/serial (nvme only), hw uuid/nguid,
        wwid, 'wwn-' + wwn, udev serial, else 'virt-<name>_t<usec_init>'.
        """
        if self.storage_class == consts.STORAGE_NVME:
            nvme_id = safe_read_sysfs_file(self.sysfs_path, 'device/serial')
            if nvme_id:
                self.id = nvme_id
                return

        if self.hw_uuid:
            self.id = self.hw_uuid
        elif self.wwid:
            self.id = self.wwid
        elif self.wwn:
            self.id = 'wwn-' + self.wwn
        elif self.serial:
            self.id = self.serial
        else:
            self.id = 'virt-' + self.name + '_t' + str(self.usec_init)

    def set_scheduler(self, sched):
        """
        Apply a copy of `sched` to this disk; loop devices always get 'none'.

        For nvme disks with multipath path devices the scheduler is written to
        each path device instead of the disk itself. Errors are logged and re-raised.
        """
        try:
            if not self.name.startswith("loop"):
                log.info("Set ioscheduler for {} old:{}, new:{}, options:{}".format(self.name, self.scheduler.name, sched.name, sched.options))
                self.scheduler = sched.copy()
            else:
                log.info("Disk {} is a loop device. Set ioscheduler to 'none'".format(self.name))
                self.scheduler = IOScheduler("none")
            if self.storage_class == consts.STORAGE_NVME and self.nvme_multipath_names:
                for name in self.nvme_multipath_names:
                    self.scheduler.apply(os.path.join(Disk._SYS_BLOCK, name))
            else:
                self.scheduler.apply(self.sysfs_path)
        except Exception as e:
            log.error("Got exception on try updating ioscheduler for {}. Exception: {}".format(self.name, str(e)))
            raise

    def _update_nvme_multipath_names(self):
        """
        Set self.nvme_multipath_names to the /sys/block entries named
        <prefix>c*n<nsid>, where <prefix> is this disk's name minus its 'n<nsid>'
        suffix (e.g. nvme0n1 -> nvme0c33n1). Read errors are logged and yield [].
        """
        names = []
        try:
            with open(os.path.join(self.sysfs_path, 'nsid'), 'r') as fd:
                nsid = fd.read(10).strip()
                ns_subname = 'n%s' % nsid
                nvme_prefix = self.name[:-len(ns_subname)]
                for name in os.listdir(Disk._SYS_BLOCK):
                    if name.startswith(nvme_prefix + 'c') and name.endswith(ns_subname):
                        names.append(name)
        except Exception as e:
            log.exception("Cannot read nsid of %s: %s" % (self.name, e))
        self.nvme_multipath_names = names

    @staticmethod
    def list_disks():
        """Return sorted names of /sys/block symlinks that have a 'dev' file."""
        rc = []
        for d in os.listdir(Disk._SYS_BLOCK):
            d_path = os.path.join(Disk._SYS_BLOCK, d)
            if not os.path.islink(d_path):
                continue
            # Fake devices do not have dev-file, ignore it
            # https://st.yandex-team.ru/DISKMAN-50
            if not os.path.exists(os.path.join(Disk._SYS_BLOCK, d, 'dev')):
                continue
            rc.append(d)
        rc.sort()
        return rc

    @staticmethod
    def list_disk_parts(name):
        """
        Return sorted partition names of disk `name`: subdirectories of
        /sys/block/<name> whose names start with `name`. If the disk dir has a
        'loop' entry, the disk name itself is included. Returns [] if the disk
        dir is missing.
        """
        rc = []
        d_path = os.path.join(Disk._SYS_BLOCK, name)
        if not os.path.exists(d_path):
            log.debug("list_disk_parts for device %s. d_path %s does not exist. Returning empty partitions list" % (name, d_path,))
            return []
        for p in os.listdir(d_path):
            if p.startswith(name):
                p_path = os.path.join(d_path, p)
                if os.path.isdir(p_path):
                    rc.append(p)
            elif Disk._is_loop_device(p):
                rc.append(name)
        rc.sort()
        return rc

    @staticmethod
    def _is_loop_device(file_in_sys_class_block):
        return file_in_sys_class_block == "loop"

    @staticmethod
    def find_device(path):
        """
        Map a device path or name to '/dev/<disk>'.

        For a whole disk returns its /dev path, or the bare name if that /dev
        node is missing. For a partition returns its parent disk's /dev path.
        Returns '' when nothing matches.
        """
        name = os.path.basename(path)
        p = os.path.join(Disk._SYS_BLOCK, name)
        if os.path.exists(p):
            # return here full dev path if exists since _raw_disks dict
            # key format is '/dev/<device>'
            dev_path = Disk._get_dev_path(name)
            if not os.path.exists(dev_path):
                return name
            return dev_path

        p = os.path.join(Disk._SYS_CLASS_BLOCK, name)
        if not os.path.exists(p):
            return ''

        p = os.path.join(Disk._SYS_CLASS_BLOCK, name, '..')
        rpath = os.path.realpath(p)
        if not os.path.exists(rpath) or not os.path.isdir(rpath):
            return ''

        bname = os.path.basename(rpath)
        dev_path = Disk._get_dev_path(bname)
        if not os.path.exists(dev_path):
            return ''
        return dev_path

    @staticmethod
    def _get_dev_path(dev_name):
        """Return '/dev/<dev_name>' if it exists, else ''."""
        dev_path = os.path.join(Disk._DEVFS, dev_name)
        if not os.path.exists(dev_path):
            log.debug("get_dev_path. Device name %s path %s does not exist" % (dev_name, dev_path,))
            return ''
        return dev_path

    @staticmethod
    def dev_by_syspath(path):
        """Return [major, minor] (strings) parsed from <path>/dev."""
        dev = read_sysfs_file(os.path.join(path, 'dev'))
        return dev.split(':')

    @staticmethod
    def get_udev_info(major, minor):
        """
        Parse the udev database entry /run/udev/data/b<major>:<minor>.

        'E:KEY=VALUE' lines become rc[KEY] = VALUE; the 'I:' line is stored
        (stripped) as rc['USEC_INITIALIZED']. Returns {} if the entry cannot be read.
        """
        _UDEV_DATA = '/run/udev/data'
        rc = {}
        try:
            udev_path = os.path.join(_UDEV_DATA, 'b' + major + ':' + minor)
            with open(udev_path, 'rb') as f:
                for line in f:
                    if line.startswith('E:'):
                        # Split on the first '=' only: values (e.g. filesystem
                        # labels) may themselves contain '='.
                        k, _, v = line[2:].strip().partition('=')
                        rc[k] = v
                    elif line.startswith('I:'):
                        rc['USEC_INITIALIZED'] = line[2:].strip()
        except IOError:
            return {}
        return rc

    @staticmethod
    def make_fs(device, fstype, opts=None, ext_opts=None, reserved_space=None):
        """
        Create an ext4 or xfs filesystem on `device`.

        :param opts: ext4 features joined into '-O'
        :param ext_opts: ext4 extended options joined into '-E'
        :param reserved_space: ext4 reserved-blocks percentage ('-m')
        :raises Exception: for an unknown fstype
        :raises subprocess.CalledProcessError: if mkfs fails
        """
        if fstype == 'ext4':
            # Build a fresh list so neither the caller's list nor a default
            # argument is ever mutated.
            args = []
            if opts:
                args.extend(['-O', ','.join(opts)])
            if ext_opts:
                args.extend(['-E', ','.join(ext_opts)])
            if reserved_space is not None:
                args.extend(['-m', str(reserved_space)])

            cmd = ['mkfs.ext4', '-q', '-F'] + args
        elif fstype == 'xfs':
            # NOTE(review): opts / ext_opts / reserved_space are ignored for xfs.
            cmd = ['mkfs.xfs', '-q', '-f']
        else:
            raise Exception('Unknown fstype %s' % fstype)

        cmd.append(device)
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def mount_fs(device,  path, fstype=None, fs_opt=None):
        """Mount `device` at `path` as ext4/xfs with optional '-o fs_opt'; other fstypes raise."""
        cmd = ['mount', device, path]
        if fstype == 'ext4':
            cmd.extend(['-t', 'ext4'])
            if fs_opt:
                cmd.extend(['-o', fs_opt])
        elif fstype == 'xfs':
            cmd.extend(['-t', 'xfs'])
            if fs_opt:
                cmd.extend(['-o', fs_opt])
        else:
            raise Exception('Unknown fstype %s' % fstype)
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def umount_fs(path):
        cmd = ['umount', path]
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def remount_fs(path, fs_opt):
        """
        Remount `path` with `fs_opt` via the diskmanager remount helper.

        A non-zero exit status is logged but not raised; OSError from
        launching the helper propagates.
        """
        cmd = ['/opt/diskmanager/utils/remount', path, fs_opt]
        log.debug("exec : %s" % cmd)
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        # communicate() drains stdout; wait() alone can deadlock once the
        # helper fills the pipe buffer.
        p.communicate()
        if p.returncode != 0:
            log.warning("remount of %s with '%s' exited with code %s" % (path, fs_opt, p.returncode))

    @staticmethod
    def check_fs(device, fstype, force_check=True, force_fix=False):
        """Run fsck.ext4 (-f if force_check, -y if force_fix); other fstypes raise."""
        if fstype == 'ext4':
            cmd = ['fsck.ext4']
            if force_check:
                cmd.append('-f')
            if force_fix:
                cmd.append('-y')
        else:
            raise Exception('Unknown fstype %s' % fstype)
        cmd.append(device)
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def make_gpt(device, guid=''):
        """Create a new GPT on `device` with disk GUID `guid` ('R' = random when empty)."""
        if not guid:
            # Random guid
            guid = 'R'
        cmd = ['sgdisk', '-o', '-g', '-U', guid, device]
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def add_gpt_part(device, idx, start, size, type_=None, name=''):
        """
        Add GPT partition number `idx` via sgdisk. `type_` defaults to the
        linux_filesystem GUID and `name` to 'part-<idx>'.
        """
        if type_ is None:
            type_ = Disk.PART_TYPES['linux_filesystem']
        if not name:
            name = 'part-%d' % idx
        # NOTE(review): no shell is involved, so the double quotes around the
        # name reach sgdisk literally and presumably become part of the stored
        # partition name. Kept as-is since existing disks may be matched on it.
        cmd = ['sgdisk',
               '-n', '{}:{}:{}'.format(idx, start, size),
               '-c', '{}:"{}"'.format(idx, name),
               '-t', '{}:{}'.format(idx, type_),
               device]
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)

    @staticmethod
    def grub_install(device):
        cmd = ['grub-install', device]
        log.debug("exec : %s" % cmd)
        subprocess.check_call(cmd)
