from __future__ import print_function, absolute_import
import os
import traceback
import sys
from collections import deque
from .common import oldstyle_main, make_event, timestamp, truncate_description_string
from juggler.bundles import as_check, Status

# Name of this check; passed to the as_check decorator and to oldstyle_main.
CHECK_NAME = 'walle_fstab'

# Maps the KEY of a "KEY=value" fstab device spec to the subdirectory of
# /dev/disk in which the value is looked up (see FSTabEntry.__init__).
SUPPORTED_DEVICES = {
    "UUID": "by-uuid",
    "PARTUUID": "by-partuuid",
    "PATH": "by-path",
    "LABEL": "by-label",
    "PARTLABEL": "by-partlabel",
    "ID": "by-id"
}

# Length limit handed to truncate_description_string for event descriptions.
MAX_DESCRIPTION_LEN = 700


class UnsupportedDeviceType(Exception):
    """Raised for a ``KEY=value`` fstab device spec whose KEY is not supported."""


class FSTabEntry(object):
    """A single parsed /etc/fstab line, checked against a mount tree.

    Attributes set by the constructor:
        dev_type: the KEY of a ``KEY=value`` device spec, or ``'DEVFILE'``.
        name: the value of ``KEY=value``, or the basename of the spec.
        dev: for ``KEY=value`` the resolved path below /dev/disk; for an
            absolute spec its real path; otherwise the spec verbatim.
        mnt, fstype: second and third fields, verbatim.
        opts: set of the comma-separated options from the fourth field.
        dump, pass_: fifth and sixth fields as strings, ``'0'`` when omitted.
    """

    def __init__(self, line, mounts):
        """Parse one non-empty, non-comment fstab line.

        :param line: raw fstab line (whitespace-separated fields).
        :param mounts: object providing ``find_mount_info_for_path(path)``.
        :raises UnsupportedDeviceType: single-'=' spec with an unknown KEY.
        :raises IndexError: the line has fewer than four fields.
        """
        self._mounts = mounts

        components = line.split()
        spec = components[0]

        # Split on the first '=' only, so a value that itself contains '='
        # (LABEL=a=b) is still recognised. Specs with several '=' and an
        # unknown key keep being treated as plain device names, as before.
        key, sep, value = spec.partition('=')

        if sep and ('=' not in value or key in SUPPORTED_DEVICES):
            # key-value device discovery
            self.dev_type, self.name = key, value
            if self.dev_type not in SUPPORTED_DEVICES:
                raise UnsupportedDeviceType("device type {} is not supported when parsing line {}".format(
                    self.dev_type,
                    spec
                ))

            self.dev = os.path.realpath(os.path.join("/dev/disk", SUPPORTED_DEVICES[self.dev_type], self.name))

        else:
            self.dev_type = 'DEVFILE'
            self.name = os.path.basename(spec)
            # Only absolute specs are paths; names such as "tmpfs" or "proc"
            # are kept verbatim.
            self.dev = os.path.realpath(spec) if spec.startswith('/') else spec

        self.mnt = components[1]
        self.fstype = components[2]
        self.opts = set(components[3].split(','))
        # fstab(5): the fifth and sixth fields are optional and default to 0.
        self.dump = components[4] if len(components) > 4 else '0'
        self.pass_ = components[5] if len(components) > 5 else '0'

    def __str__(self):
        # Sort the options: set iteration order is arbitrary and this text
        # ends up in the event description.
        return 'FSTAB:{}({} {} {} {} {} {})'.format(
            self.dev_type, self.dev, self.mnt, self.fstype, ','.join(sorted(self.opts)), self.dump, self.pass_)

    def __repr__(self):
        return str(self)

    def dev_exists(self):
        """Return True when ``dev`` is an absolute path that exists."""
        return self.dev.startswith('/') and os.path.exists(self.dev)

    def get_maj_min(self):
        """Return ``(major, minor)`` taken from ``st_rdev`` of ``dev``."""
        stat = os.stat(self.dev)
        return os.major(stat.st_rdev), os.minor(stat.st_rdev)

    def get_maj_min_str(self):
        """Return the device numbers of ``dev`` formatted as ``'major:minor'``."""
        return '{}:{}'.format(*self.get_maj_min())

    def is_mounted(self):
        """Return True when the mount tree has a matching mount at ``mnt``.

        The mount found for ``mnt`` must sit exactly at ``mnt``. Beyond that:
        * if ``dev`` does not exist as a path, the mount point match alone is
          accepted (maybe pseudo-fs or complex fs like ZFS);
        * otherwise the mount's ``major:minor`` must equal that of ``dev``;
        * if ``dev`` reports device numbers 0:0, it is also accepted when it
          is the loop backing file of the mount or equals the mount's root
          (guess loop mount / guess bind mount).
        """
        info = self._mounts.find_mount_info_for_path(self.mnt)
        if info.mountpoint != self.mnt:
            return False

        if not self.dev_exists():
            return True

        # Stat once and reuse the numbers for every comparison below.
        major, minor = self.get_maj_min()
        if info.maj_min == '{}:{}'.format(major, minor):
            return True

        if major == 0 and minor == 0:
            if info.get_loop_backing_file() == self.dev:
                return True
            if self.dev == os.path.realpath(info.root):
                return True

        return False

    def is_equal_rw_flag(self):
        """Return True when the mount's rw/ro state agrees with the options.

        An entry without ``ro`` (or with an explicit ``rw``) expects ``rw``
        in the mount options; an entry with ``ro`` expects ``ro``.
        """
        info = self._mounts.find_mount_info_for_path(self.mnt)
        if ('rw' in self.opts or 'ro' not in self.opts) and 'rw' in info.opts:
            return True

        if 'ro' in self.opts and 'ro' in info.opts:
            return True

        return False

    def should_be_mounted(self):
        """Return True unless the entry carries the ``noauto`` option."""
        return 'noauto' not in self.opts

    # Historical misspelling, kept because existing callers use this name.
    shoud_be_mounted = should_be_mounted


class MountNode(object):
    """One path component in the mount tree.

    ``info`` holds the mount information attached to this exact path, or
    None for a node that only exists as part of a longer path. ``children``
    maps a child's name to its node.
    """

    def __init__(self, name, parent=None, info=None):
        self.name = name
        self.parent = parent
        self.info = info
        self.children = {}

    def add_child(self, child):
        """Register ``child`` under its own name and make this node its parent."""
        child.parent = self
        self.children[child.name] = child

    def get_child(self, name):
        """Return the child called ``name``; raises KeyError when absent."""
        children = self.children
        return children[name]

    def has_child(self, name):
        """Tell whether a child called ``name`` is registered."""
        known = self.children
        return name in known


class MountInfo(object):
    """The first six whitespace-separated fields of a mountinfo line.

    Fields, in order: mount id, parent mount id, ``major:minor``, root,
    mount point and the comma-separated mount options. Anything after the
    sixth field is ignored.
    """

    def __init__(self, line):
        fields = line.split()
        self.mount_id = int(fields[0])
        self.parent_mount_id = int(fields[1])
        self.maj_min, self.root, self.mountpoint = fields[2], fields[3], fields[4]
        self.opts = fields[5].split(',')

    def get_loop_backing_file(self, open_func=open):
        """Return the stripped content of this device's loop ``backing_file``.

        The file is read from /sys/dev/block/<maj:min>/loop/backing_file via
        ``open_func``; None is returned when that raises IOError.
        """
        sysfs_path = '/sys/dev/block/{}/loop/backing_file'.format(self.maj_min)
        try:
            with open_func(sysfs_path, 'r') as handle:
                content = handle.read()
        except IOError:
            return None
        return content.strip()

    def __str__(self):
        shown = (
            self.mount_id,
            self.parent_mount_id,
            self.maj_min,
            self.root,
            self.mountpoint,
            ','.join(self.opts),
        )
        return 'MountInfo: {} {} {} {} {} {}'.format(*shown)

    def __repr__(self):
        return str(self)


class MountTree(object):
    """Tree of mount points built from /proc/<pid>/mountinfo.

    Every mount is placed at the node addressed by the components of its
    mount point, so the mount covering a path can be found by walking the
    path down from ``root`` and then back up to the nearest node with info.
    """

    def __init__(self, pid='1', open_func=open):
        """Read and index the mountinfo file of process ``pid``.

        :param pid: process id (as a string) whose mountinfo is read.
        :param open_func: ``open``-compatible callable, injectable for tests.
        :raises ValueError: the mountinfo file contains no mounts.
        """
        info_by_parent = {}
        with open_func('/proc/{}/mountinfo'.format(pid), 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                info = MountInfo(line)
                info_by_parent.setdefault(info.parent_mount_id, []).append(info)

        if not info_by_parent:
            raise ValueError('no mounts found in /proc/{}/mountinfo'.format(pid))

        # Prior to kernel 5.4 tree root mount_id == 0
        # On kernel >= 5.4 tree root mount_id == 1
        # Either way, it is the minimal parent_mount_id in mountinfo
        root_index = min(info_by_parent)
        root_info = info_by_parent[root_index][0]
        self.root = MountNode('', info=root_info)

        # Breadth-first walk from the root mount, so a parent mount is always
        # inserted before the mounts stacked on top of it. `seen` guards
        # against an entry that lists itself as its own parent, which would
        # otherwise be re-queued forever.
        seen = {root_info.mount_id}
        # The root may have no child mounts at all, hence .get().
        mounts_queue = deque(info_by_parent.get(root_info.mount_id, []))

        while mounts_queue:
            cur = mounts_queue.popleft()
            if cur.mount_id in seen:
                continue
            seen.add(cur.mount_id)
            mounts_queue.extend(info_by_parent.get(cur.mount_id, []))
            self._add_node_by_info(cur)

    def _add_node_by_info(self, child_info):
        """Attach ``child_info`` at the node for its mount point.

        Missing intermediate nodes are created without info; info already
        present at the target node is overwritten.
        """
        cur_node = self.root
        for name in child_info.mountpoint.split('/')[1:]:
            if not cur_node.has_child(name):
                cur_node.add_child(MountNode(name, cur_node))
            cur_node = cur_node.get_child(name)

        cur_node.info = child_info

    def find_mount_info_for_path(self, path):
        """Return the info of the deepest known mount at or above ``path``.

        Falls back to the root mount's info when no node on the way up
        carries any.
        """
        # Descend as far as the tree knows the path...
        cur_node = self.root
        for name in path.split('/')[1:]:
            if not cur_node.has_child(name):
                break
            cur_node = cur_node.get_child(name)

        # ...then climb back up to the nearest node that is a mount point.
        while cur_node is not None:
            if cur_node.info:
                return cur_node.info
            cur_node = cur_node.parent

        return self.root.info


def create_fstab_entry(line, mounts):
    """Build the FSTabEntry for one fstab ``line``, bound to ``mounts``."""
    entry = FSTabEntry(line, mounts)
    return entry


def create_minfo_entry(line):
    """Build the MountInfo for one mountinfo ``line``."""
    minfo = MountInfo(line)
    return minfo


def check_and_format_description(unmounted, ro):
    """Turn the two problem lists into a ``(status, description)`` pair.

    :param unmounted: entries reported as "is not mounted".
    :param ro: entries reported as "has inconsistent flags".
    :return: ``(Status.OK, "Ok")`` when both lists are empty, otherwise
        ``Status.CRIT`` with the comma-joined messages passed through
        truncate_description_string with MAX_DESCRIPTION_LEN.
    """
    if not (unmounted or ro):
        return Status.OK, "Ok"

    problems = ["{} is not mounted".format(entry) for entry in unmounted]
    problems.extend("{} has inconsistent flags".format(entry) for entry in ro)

    summary = ', '.join(problems)
    return Status.CRIT, truncate_description_string(summary, MAX_DESCRIPTION_LEN)


def read_fstab(mounts, open_func=open):
    """Parse /etc/fstab into a list of entries, in file order.

    Blank lines and lines whose first non-blank character is ``#`` are
    skipped; every other line goes through create_fstab_entry.

    :param mounts: mount tree handed to each created entry.
    :param open_func: ``open``-compatible callable, injectable for tests.
    """
    with open_func('/etc/fstab', 'r') as fstab_file:
        stripped_lines = (raw.strip() for raw in fstab_file)
        return [
            create_fstab_entry(text, mounts)
            for text in stripped_lines
            if text != '' and not text.startswith("#")
        ]


def run_check(open_func=open):
    """Compare fstab with the current mounts and summarise the differences.

    An entry that is not mounted is reported unless it is not expected to be
    mounted; a mounted entry is reported when its rw/ro flag disagrees.

    :param open_func: ``open``-compatible callable used for all file reads.
    :return: the result of check_and_format_description for both lists.
    """
    tree = MountTree(open_func=open_func)
    entries = read_fstab(tree, open_func)

    missing = []
    flag_mismatches = []
    for entry in entries:
        if entry.is_mounted():
            if not entry.is_equal_rw_flag():
                flag_mismatches.append(entry)
        elif entry.shoud_be_mounted():
            missing.append(entry)

    return check_and_format_description(missing, flag_mismatches)


@as_check(name=CHECK_NAME)
def juggler_check(open_func=open):
    """Run the check and wrap its outcome into an event.

    Any Exception raised by run_check is reported as Status.WARN, with the
    formatted traceback (truncated to MAX_DESCRIPTION_LEN) as the reason.

    :param open_func: ``open``-compatible callable forwarded to run_check.
    :return: the result of make_event(status, metadata), where metadata
        carries the "timestamp" and "reason" keys.
    """
    try:
        status, reason = run_check(open_func=open_func)
    except Exception:
        status = Status.WARN
        reason = truncate_description_string(
            "Can't get check data: {}".format(traceback.format_exc()),
            MAX_DESCRIPTION_LEN
        )

    return make_event(status, {"timestamp": timestamp(), "reason": reason})


# When executed as a script, run the check once and pass the check name and
# its result to oldstyle_main (imported from .common).
if __name__ == '__main__':
    oldstyle_main(CHECK_NAME, juggler_check())
