#!/usr/bin/env python3

from __future__ import print_function
import subprocess
import re
import time
from library.python.ssh_client import SshClient


class KernelMenu(object):
    """Inspect and change which kernel a (possibly remote) Ubuntu host boots.

    Every command runs through ``SshClient``. State read from the host
    (running kernel, installed kernels, GRUB menu, /etc/default/grub and
    grubenv) is cached on the instance by :meth:`load` and re-read after each
    mutating operation.
    """

    etc_path = "/etc/default/grub"
    cfg_path = "/boot/grub/grub.cfg"
    env_path = "/boot/grub/grubenv"

    # KEY=value lines of /etc/default/grub.
    config_re = re.compile(r'^(?P<key>[A-Z_]+)=(?P<val>.*)$')
    # `set default="..."` lines of grub.cfg.
    default_re = re.compile(r'\s*set default="(?P<default>.*)"')
    # Top-level `submenu '...' [$menuentry_id_option '...'] {` lines.
    submenu_re = re.compile(r'submenu (?:\'|\")(?P<name>[^\"\']+)(?:\'|\")'
                            r'(?: \$menuentry_id_option \'(?P<id>.*)\')? {')
    # "Ubuntu, with Linux <kernel>" menuentries; `flavor` is an optional
    # parenthesised suffix after the version (such entries are skipped by load()).
    entry_re = re.compile(r'\s*menuentry (?:\'|\")'
                          r'(?P<name>Ubuntu, with Linux (?P<kernel>\S+)'
                          r'(?: \((?P<flavor>.*)\))?)'
                          r'(?:\'|\")'
                          r' --class ubuntu'
                          r' --class gnu-linux'
                          r' --class gnu'
                          r' --class os'
                          r'(?: \$menuentry_id_option \'(?P<id>gnulinux-.*)\')?'
                          r' {')

    def __init__(self, hostname=None, ssh_port=None,
                 username=None, private_key=None, sudo=None):
        """Connect to *hostname* and load its current boot configuration.

        :param private_key: private key contents, passed through to SshClient.
        :param sudo: ``True``/``False`` to force/disable the ``sudo`` prefix for
            privileged commands, a list to use as a custom prefix, or ``None``
            to use ``sudo`` unless ``whoami`` reports root.
        """
        self.remote = SshClient(hostname=hostname, port=ssh_port,
                                username=username, private_key=private_key,
                                connect_timeout=30, keepalive=5)

        if sudo is None:
            sudo = self.call_output(['whoami']) != 'root\n'
        if sudo is True:
            self.sudo = ['sudo']
        elif sudo is False:
            self.sudo = []
        else:
            self.sudo = sudo

        self.load()

    def call_output(self, args, **kwargs):
        """Run *args* on the host unprivileged and return its stdout as text."""
        return self.remote.check_output(args, stdin=subprocess.DEVNULL, encoding='utf-8', **kwargs)

    def call_root(self, args, stdin=subprocess.DEVNULL, **kwargs):
        """Echo *args* shell-quoted, then run them on the host with the sudo prefix."""
        print("+", "'" + "' '".join([a.replace("'", "'\\''") for a in args]) + "'")
        return self.remote.check_call(self.sudo + args, stdin=stdin, encoding='utf-8', **kwargs)

    def load(self):
        """(Re)read boot state from the host into instance attributes.

        Sets ``current_kernel``, ``current_cmdline``, ``installed_kernels``,
        ``config`` (parsed /etc/default/grub), ``kernels`` (kernel -> menu entry
        title), ``entries`` (entry title / id / numeric index -> kernel),
        ``default_entry``, ``vars`` (grubenv), ``saved_entry``, ``next_entry``,
        ``prev_saved_entry``, ``bootonce_entry`` and the matching ``*_kernel``
        attributes. Unreadable GRUB files are treated as empty.
        """
        self.current_kernel = self.call_output(['cat', "/proc/sys/kernel/osrelease"]).strip()
        self.current_cmdline = self.call_output(['cat', "/proc/cmdline"]).strip()

        self.installed_kernels = self.call_output(['ls', '-1', '/lib/modules']).splitlines()

        try:
            self.cfg = self.call_output(['cat', self.cfg_path])
        except Exception:
            self.cfg = ""

        try:
            self.etc_cfg = self.call_output(['cat', self.etc_path])
        except Exception:
            self.etc_cfg = ""

        self.config = {}
        for line in self.etc_cfg.splitlines():
            m = self.config_re.match(line)
            if m is not None:
                self.config[m.group('key')] = m.group('val').strip('"')

        self.kernels = {}
        self.entries = {}
        self.default_entry = None

        # Entries inside a submenu are keyed as "<submenu>><entry>" by title,
        # by id and by position.
        submenu = ""
        submenu_id = None
        submenu_index = ""

        # NOTE(review): only lines matching entry_re advance `index`, and the
        # submenu prefix is never reset, so numeric keys may disagree with
        # GRUB's own numbering when other menuentries precede/follow -- confirm.
        index = 0

        for line in self.cfg.splitlines():
            m = self.default_re.match(line)
            if m is not None:
                default = m.group('default')
                if default == "${saved_entry}":
                    self.default_entry = 'saved'
                elif default != "${next_entry}":
                    self.default_entry = default
            m = self.submenu_re.match(line)
            if m is not None:
                submenu = m.group('name') + '>'
                submenu_id = m.group('id')
                if submenu_id is not None:
                    submenu_id = submenu_id + '>'
                submenu_index = str(index) + '>'
                index = 0
            m = self.entry_re.match(line)
            if m is not None:
                entry = submenu + m.group('name')

                entry_id = m.group('id')
                if entry_id is not None and submenu_id is not None:
                    entry_id = submenu_id + entry_id

                entry_index = submenu_index + str(index)
                index += 1

                kernel = m.group('kernel')
                flavor = m.group('flavor')
                if flavor is not None:
                    continue

                self.kernels[kernel] = entry
                self.entries[entry] = kernel
                if entry_id is not None:
                    self.entries[entry_id] = kernel
                self.entries[entry_index] = kernel

        try:
            self.env = self.call_output(['grub-editenv', self.env_path, 'list'])
        except Exception:
            self.env = ""

        self.vars = {}
        for line in self.env.splitlines():
            var, sep, val = line.partition('=')
            if not sep:
                # Skip blank/malformed lines instead of failing the whole load.
                continue
            self.vars[var] = val.strip("'")

        self.saved_entry = self.vars.get('saved_entry')
        self.next_entry = self.vars.get('next_entry')
        self.prev_saved_entry = self.vars.get('prev_saved_entry')
        if self.prev_saved_entry and not self.next_entry:
            # One-shot boot recorded via prev_saved_entry: saved_entry holds
            # the one-shot target and prev_saved_entry the real saved entry.
            self.bootonce_entry = self.saved_entry
            self.saved_entry = self.prev_saved_entry
        else:
            self.bootonce_entry = self.next_entry

        self.saved_kernel = self.entries.get(self.saved_entry)
        self.bootonce_kernel = self.entries.get(self.bootonce_entry)

        if self.default_entry == 'saved':
            self.default_kernel = self.saved_kernel
        else:
            self.default_kernel = self.entries.get(self.default_entry)

    def update(self):
        """Regenerate grub.cfg with update-grub and reload state."""
        self.call_root(['update-grub'])
        self.load()

    def uptime(self):
        """Return the host's ``uptime`` output, or '' if it cannot be read."""
        try:
            return self.call_output(['uptime']).strip()
        except Exception:
            return ''

    def distro(self):
        """Return ``lsb_release -ds`` output, or '' if it cannot be read."""
        try:
            return self.call_output(['lsb_release', '-ds']).strip()
        except Exception:
            return ''

    def taints(self, struct=True):
        """Decode /proc/sys/kernel/tainted into "LETTER:description" strings.

        Returns ``['???']`` if the value cannot be read. Set bits beyond the
        known table are reported as ``'?:Unknown bit N'``. *struct* is unused.
        """
        try:
            tainted = int(self.call_output(['cat', '/proc/sys/kernel/tainted']))
        except Exception:
            return ['???']
        # Position in this list == bit number in the taint mask.
        taints = [
            'P:Proprietary module',
            'F:Forced module load',
            'S:SMP cpu out of specs',
            'R:Forced module remove',
            'M:Machine check exception',
            'B:Bad page state',
            'U:User did unsafe operations',
            'D:Kernel died after oops or bug',
            'A:ACPI table overridden',
            'W:Kernel warning',
            'C:Staging module',
            'I:Firmware workaround',
            'O:Out-of-tree module',
            'E:Unsigned module',
            'L:Soft lockup',
            'K:Livepatch',
            'X:Auxiliary taint',
            'T:Randstruct',
        ]
        bit = 0
        ret = []
        while tainted >= (1 << bit):
            if tainted & (1 << bit):
                if bit < len(taints):
                    ret.append(taints[bit])
                else:
                    ret.append('?:Unknown bit {}'.format(bit))
            bit += 1
        return ret

    @staticmethod
    def _sed_escape(text):
        """Escape *text* for literal use in the replacement part of sed ``s///``."""
        return text.replace('\\', '\\\\').replace('/', '\\/').replace('&', '\\&')

    def edit_config(self, key, val):
        """Rewrite an existing ``KEY=...`` line of /etc/default/grub to ``KEY="val"``.

        Nothing changes if no such line exists. *val* is escaped so that '/',
        '&' and backslashes are written literally; *key* is used as-is.
        """
        replacement = key + '="' + self._sed_escape(val) + '"'
        self.call_root(['sed', 's/^' + key + '=.*$/' + replacement + '/', '-i', self.etc_path])

    def set_etc_default(self, entry):
        """Set GRUB_DEFAULT to *entry*, run update-grub and verify the result.

        :raises Exception: if the regenerated grub.cfg does not select *entry*.
        """
        self.edit_config('GRUB_DEFAULT', entry)
        self.update()
        if self.default_entry != entry:
            raise Exception('Set GRUB_DEFAULT="{}" but default_entry="{}"'.format(entry, self.default_entry))

    def lock_etc_default(self):
        """Make /etc/default/grub immutable (``chattr +i``)."""
        self.call_root(['chattr', '+i', self.etc_path])

    def unlock_etc_default(self):
        """Undo :meth:`lock_etc_default` (``chattr -i``)."""
        self.call_root(['chattr', '-i', self.etc_path])

    def _extlinux_has_kernel(self, kernel):
        """Return True if /extlinux.conf and *kernel*'s vmlinuz and initrd exist."""
        try:
            for path in ('/extlinux.conf',
                         '/boot/vmlinuz-' + kernel,
                         '/boot/initrd.img-' + kernel):
                self.call_output(['test', '-f', path])
        except subprocess.CalledProcessError:
            return False
        return True

    def set_default(self, kernel, mode='auto'):
        """Make *kernel* the kernel booted next.

        Modes:
          * ``saved``    -- grub-set-default; switch GRUB_DEFAULT to "saved" if needed.
          * ``bootonce`` -- grub-reboot, a one-shot choice for the next boot.
          * ``default``  -- write the entry into GRUB_DEFAULT and run update-grub.
          * ``auto``     -- grub-set-default if GRUB_DEFAULT is already "saved",
            otherwise like ``default``. Falls back to ``symlink`` when there is
            no GRUB entry but /extlinux.conf and the kernel files exist.
          * ``kexec``    -- load the kernel with kexec, reusing the current cmdline.
          * ``symlink``  -- point /vmlinuz and /initrd.img at the kernel's files.

        :raises Exception: if no GRUB entry exists for *kernel* (GRUB modes),
            *mode* is unknown, or the reloaded state does not select *kernel*.
        """
        entry = None
        if mode not in ('kexec', 'symlink'):
            entry = self.kernels.get(kernel)
            if entry is None:
                # No GRUB entry: an extlinux host can still be switched by symlinks.
                if mode == 'auto' and self._extlinux_has_kernel(kernel):
                    mode = 'symlink'
                if mode != 'symlink':
                    raise Exception("Grub entry not found for kernel " + kernel)

        if mode == 'saved':
            self.call_root(['grub-set-default', "'" + entry + "'"])
            if self.default_entry != 'saved':
                self.set_etc_default('saved')
        elif mode == 'bootonce':
            self.call_root(['grub-reboot', "'" + entry + "'"])
        elif mode == 'default':
            self.set_etc_default(entry)
        elif mode == 'auto':
            if self.default_entry == 'saved':
                self.call_root(['grub-set-default', "'" + entry + "'"])
            else:
                self.set_etc_default(entry)
        elif mode == 'kexec':
            self.call_root(['kexec', '--load', '/boot/vmlinuz-' + kernel, '--initrd=/boot/initrd.img-' + kernel, '--reuse-cmdline'])
        elif mode == 'symlink':
            # Relative targets: the links live in / so they resolve to /boot/...
            self.call_root(['ln', '-sf', 'boot/vmlinuz-' + kernel, '/vmlinuz'])
            self.call_root(['ln', '-sf', 'boot/initrd.img-' + kernel, '/initrd.img'])
        else:
            raise Exception('unknown mode: ' + mode)

        # A pending one-shot entry would override the new persistent default.
        if mode in ('saved', 'default', 'auto') and self.next_entry is not None:
            self.call_root(['grub-editenv', self.env_path, 'unset', 'next_entry'])

        if self.default_entry != 'saved' and self.saved_entry is not None:
            self.call_root(['grub-editenv', self.env_path, 'unset', 'saved_entry'])

        self.load()

        # Verify explicitly (not with assert, which -O strips). A one-shot boot
        # shows up as the bootonce kernel, not as the default kernel.
        if mode == 'bootonce':
            if self.bootonce_kernel != kernel:
                raise Exception('Booting once into {} but bootonce_kernel={}'.format(kernel, self.bootonce_kernel))
        elif mode not in ('kexec', 'symlink'):
            if self.default_kernel != kernel:
                raise Exception('Set default {} but default_kernel={}'.format(kernel, self.default_kernel))

    def get_boot_id(self, timeout=10):
        """Return the host's boot_id, or None (closing the connection) on any failure."""
        try:
            return self.call_output(['cat', '/proc/sys/kernel/random/boot_id'], timeout=timeout).strip()
        except Exception:
            self.remote.close()
            return None

    def reboot(self, wait=True, timeout=600):
        """Reboot the host and, if *wait*, block until it is back and reload state.

        "Back" means boot_id is readable and differs from its pre-reboot value.

        :raises Exception: if shutdown fails or the host is not back within
            *timeout* seconds.
        """
        boot_id = self.get_boot_id()
        deadline = time.time() + timeout
        try:
            exitcode = self.call_root(['shutdown', '--reboot', 'now'], timeout=10)
            if exitcode not in [0, 255]:
                raise Exception("Reboot failed, exitcode: {}".format(exitcode))
        except subprocess.CalledProcessError as e:
            # 255 is accepted as success above (presumably ssh's code for a
            # dropped connection), so treat it the same when raised here.
            if e.returncode > 0 and e.returncode != 255:
                raise
            self.remote.close()
        except subprocess.TimeoutExpired:
            self.remote.close()
        if not wait:
            # The host is going down; reading its state now would fail or be stale.
            return
        while True:
            current_id = self.get_boot_id(timeout=1)
            if current_id is not None and current_id != boot_id:
                break
            print('.', end='', flush=True)
            time.sleep(1)
            if time.time() > deadline:
                raise Exception("Timeout exceeded")
        print()
        self.load()

    def install(self, kernel, headers=True, dbgsym=False, dbgsym_extra=False, extra=False, tools=False):
        """Install linux-image-<kernel> plus the selected companion packages via apt-get.

        DEBIAN_FRONTEND=noninteractive is passed as ``env`` to the remote calls.
        """
        pkg = ['linux-image-{}'.format(kernel)]
        if headers:
            pkg.append('linux-headers-{}'.format(kernel))
        if extra:
            pkg.append('linux-image-extra-{}'.format(kernel))
        if dbgsym:
            pkg.append('linux-image-{}-dbgsym'.format(kernel))
        if dbgsym_extra:
            pkg.append('linux-image-extra-{}-dbgsym'.format(kernel))
        if tools:
            # '-' not '=': apt reads "name=ver" as a version pin, not a package name.
            pkg.append('linux-tools-{}'.format(kernel))
        env = {'DEBIAN_FRONTEND': "noninteractive"}
        self.call_root(['apt-get', 'update'], env=env)
        self.call_root(['apt-get', 'install', '--yes'] + pkg, env=env)
        self.load()

    def tarball_release(self, tarball):
        """Return the release of the first ``*boot/vmlinuz-<release>`` in a local *tarball*.

        Runs ``tar`` locally, not on the remote host.
        """
        out = subprocess.check_output(['tar', 'tf', tarball, '--wildcards', '*boot/vmlinuz-*'], encoding='utf-8')
        for name in out.splitlines():
            if '/vmlinuz-' in name:
                # Use one line only: with several matches, splitting the whole
                # output would leak the next path into the release string.
                return name.split('/vmlinuz-', 1)[1]
        raise Exception('No boot/vmlinuz-* found in ' + tarball)

    def install_tarball(self, tarball, release=None):
        """Unpack a local gzip *tarball* into / on the host, build its initramfs, update GRUB."""
        if release is None:
            release = self.tarball_release(tarball)
        with open(tarball, 'rb') as stdin:
            self.call_root(['tar',
                            '--extract',
                            '--no-overwrite-dir',
                            '--no-same-owner',
                            '--no-same-permissions',
                            '--gzip',
                            '-C', '/'],
                           stdin=stdin)
        self.call_root(['update-initramfs', '-c', '-k', release])
        self.update()

    @staticmethod
    def _check_kernel_name(kernel):
        """Reject names that would make the manual ``rm``/``find`` removal unsafe."""
        if not kernel or '*' in kernel or '..' in kernel or '/' in kernel:
            raise ValueError('Refusing to remove files for kernel {!r}'.format(kernel))

    def remove(self, kernel, protect_current=True, force=False):
        """Remove *kernel*'s packages; with *force*, delete its files if apt-get fails.

        :raises Exception: if *kernel* is the running kernel and *protect_current*.
        :raises ValueError: if the manual removal path gets an unsafe kernel name.
        """
        if protect_current and kernel == self.current_kernel:
            raise Exception("Cannot remove current kernel")
        pattern = '^linux-.*-{}$'.format(kernel)
        if not force:
            self.call_root(['apt-get', 'remove', '--yes', pattern])
        else:
            try:
                self.call_root(['apt-get', 'remove', '--yes', pattern])
            except Exception:
                # Validate explicitly: asserts vanish under `python -O`, and an
                # empty name would turn into `rm -fr /lib/modules/`.
                self._check_kernel_name(kernel)
                self.call_root(['rm', '-fr', '/lib/modules/' + kernel])
                self.call_root(['find', '/boot', '-name', '*-' + kernel, '-print', '-delete'])
                self.call_root(['update-initramfs', '-d', '-k', kernel])
                self.call_root(['update-grub'])
        self.load()


def main():
    """Command-line entry point: report a host's boot state and optionally change it.

    Actions run in this order: update, report, install, tarball, unlock,
    set default, lock, remove, force-remove, force-remove-all, reboot.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Kernel version remote control')
    parser.add_argument('-H', '--host', metavar='HOSTNAME', help='default: localhost')
    parser.add_argument('--ssh-port', metavar='PORT', help='ssh port', type=int)
    parser.add_argument('--private-key', metavar='FILE', help='ssh private key file')
    parser.add_argument('--domain', metavar='DOMAIN', default='search.yandex.net', help='default: %(default)s')
    parser.add_argument('--login', metavar='USERNAME', help='default: current')
    parser.add_argument('--update', action='store_true', help='call update-grub')
    parser.add_argument('--saved', dest='mode', default='auto', const='saved', action='store_const', help='set GRUB_DEFAULT=saved')
    parser.add_argument('--default', dest='mode', const='default', action='store_const', help='set GRUB_DEFAULT directly')
    parser.add_argument('--bootonce', dest='mode', const='bootonce', action='store_const', help='set GRUB_DEFAULT=saved and boot kernel once using grub-reboot')
    parser.add_argument('--kexec', dest='mode', const='kexec', action='store_const', help='load kernel using kexec')
    parser.add_argument('--symlink', dest='mode', const='symlink', action='store_const', help='set symlink /vmlinuz and /initrd.img')
    parser.add_argument('--lock', action='store_true', help='make /etc/default/grub immutable')
    parser.add_argument('--unlock', action='store_true', help='undo /etc/default/grub immutable')
    parser.add_argument('--reboot', action='store_true', help='reboot into new default kernel')
    parser.add_argument('--force-reboot', action='store_true', help='reboot unconditionally')
    parser.add_argument('--install', action='store_true', help='install kernel package')
    parser.add_argument('--tarball', help='install kernel from tarball')
    parser.add_argument('--remove', metavar='VERSION', help='remove kernel package')
    parser.add_argument('--force-remove', metavar='VERSION', help='remove kernel files')
    parser.add_argument('--force-remove-all', action='store_true', help='remove all kernels except current')
    parser.add_argument('kernel', nargs='?', metavar='VERSION', help='set default kernel')

    args = parser.parse_args()

    # Without a version, --install would try to install "linux-image-None".
    if args.install and args.kernel is None:
        parser.error('--install requires VERSION')

    hostname = args.host
    if hostname is None and args.ssh_port is not None:
        hostname = 'localhost'
    if hostname is not None and hostname != 'localhost' and '.' not in hostname and args.domain:
        hostname += '.' + args.domain

    private_key = None
    if args.private_key is not None:
        with open(args.private_key) as key_file:
            private_key = key_file.read()

    menu = KernelMenu(hostname=hostname, ssh_port=args.ssh_port,
                      username=args.login, private_key=private_key)

    if args.update:
        menu.update()

    print('kernels:')
    for k in menu.installed_kernels:
        print(k)
    print()

    print('config:')
    for k, e in menu.config.items():
        print(k, '=', e)
    print()

    print('menu:')
    for k, e in menu.kernels.items():
        print(k, '\t', e)
    print()

    print('cmdline:', menu.current_cmdline)
    print('current:', menu.current_kernel)
    print('distro:', menu.distro())
    print('uptime:', menu.uptime())
    print('taints:', ', '.join(menu.taints()))
    print('default:', menu.default_kernel, '\t', menu.default_entry)
    print('saved:', menu.saved_kernel, '\t', menu.saved_entry)

    if menu.bootonce_kernel:
        print('bootonce:', menu.bootonce_kernel, 'entry', menu.bootonce_entry)

    if args.install:
        print('installing', args.kernel)
        menu.install(args.kernel)

    if args.tarball:
        if args.kernel is None:
            args.kernel = menu.tarball_release(args.tarball)
        menu.install_tarball(args.tarball, args.kernel)

    if args.unlock:
        print('unlock', menu.etc_path)
        menu.unlock_etc_default()

    if args.kernel:
        print('set', args.kernel)
        menu.set_default(args.kernel, mode=args.mode)

    if args.lock:
        print('lock', menu.etc_path)
        menu.lock_etc_default()

    if args.remove:
        print('removing', args.remove)
        menu.remove(args.remove)
        print('default:', menu.default_kernel)

    if args.force_remove:
        print('removing', args.force_remove)
        menu.remove(args.force_remove, force=True)
        print('default:', menu.default_kernel)

    if args.force_remove_all:
        # Snapshot: every remove() reloads and replaces menu.kernels.
        for kernel in list(menu.kernels):
            if kernel != menu.current_kernel:
                print('removing', kernel)
                menu.remove(kernel, force=True)

    if args.force_reboot or (args.reboot and args.kernel != menu.current_kernel):
        print('booting', args.kernel)
        start = time.time()
        menu.reboot()
        finish = time.time()
        print('reboot:', int(finish - start), 'seconds')
        print('cmdline:', menu.current_cmdline)
        print('current:', menu.current_kernel)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
