#!/usr/bin/env python3

import argparse
import contextlib
import logging
import os
import shlex
import shutil
import subprocess
import sys
import tempfile

import library.python.resource as resource


@contextlib.contextmanager
def tmpdir(**kwargs):
    """Yield a fresh temporary directory and delete it recursively afterwards.

    All keyword arguments are handed to :func:`tempfile.mkdtemp` (for example
    ``prefix=``). Removal happens whether the ``with`` body succeeds or raises.
    """
    scratch_path = tempfile.mkdtemp(**kwargs)
    try:
        yield scratch_path
    finally:
        # Runs on normal exit and on exceptions from the with-body alike.
        shutil.rmtree(scratch_path)


# Host packages that install_deps() installs with apt-get before a build.
BUILD_DEPS = [
    'busybox-static',
    'cpio',
    'squashfs-tools',
]

# Kernel modules copied (together with the dependencies modprobe reports)
# into the initrd and listed as insmod lines in init_base_modules.sh.
BASE_MODULES = [
    '9p',
    'squashfs',
    'overlay',
    'nls_iso8859_1',  # required for vfat cloud-init image
]


def run(args, logger=None, check=True, text=False, stdin=subprocess.DEVNULL, **kwargs):
    """Log a command and execute it with :func:`subprocess.run`.

    :param args: command argv as a list of strings (a one-element list holding
        a shell command line when ``shell=True`` is passed via *kwargs*).
    :param logger: logger the command line is announced on; when ``None`` the
        root ``logging`` module is used.
    :param check: raise ``subprocess.CalledProcessError`` on non-zero exit.
    :param text: decode captured output as text (passed as
        ``universal_newlines``).
    :param stdin: child's stdin; defaults to ``/dev/null`` so the child never
        reads from the caller's terminal.
    :param kwargs: forwarded unchanged to :func:`subprocess.run`.
    :returns: the ``subprocess.CompletedProcess`` instance.
    """
    # Previously `logger` was accepted but ignored; honour it when supplied.
    log = logging if logger is None else logger
    log.info("+ '" + "' '".join(args) + "'")
    return subprocess.run(args, check=check, universal_newlines=text, stdin=stdin, **kwargs)


def run_output(args, text=True, **kwargs):
    """Run *args* through :func:`run` and return the captured stdout.

    :param text: return ``str`` when true (default), ``bytes`` otherwise.
    :param kwargs: forwarded to :func:`run` (e.g. ``check``, ``cwd``, ``logger``).

    The command line is logged exactly once, by :func:`run`; logging it here as
    well would duplicate every entry.
    """
    return run(args, stdout=subprocess.PIPE, text=text, **kwargs).stdout


def install_deps(args):
    """Install the packages listed in BUILD_DEPS via ``sudo apt-get``.

    With ``args.dry_run`` set, the install command is only logged and the
    process exits with status 0, so nothing after this call runs (including
    the rest of build_initrd when invoked from there). Otherwise the package
    index is refreshed first and the packages are installed.
    """
    install_cmd = ['sudo', 'apt-get', 'install', '--no-install-recommends', '-y']
    install_cmd.extend(BUILD_DEPS)

    if not args.dry_run:
        run(['sudo', 'apt-get', 'update', '-y'])
        run(install_cmd)
        return

    logging.info("# Install build dependencies.")
    logging.info(" ".join(install_cmd))
    sys.exit(0)


def install_kernel(args):
    """Install the kernel image/extra/tools packages for ``args.version``.

    With ``args.dry_run`` set, the install command is only logged and the
    process exits with status 0. Otherwise the package index is refreshed,
    ``linux-headers-<version>`` is removed, and the packages are installed
    (downgrades allowed).
    """
    version = args.version
    packages = [
        'linux-image-' + version,
        'linux-image-extra-' + version,
        # NOTE(review): `=` is apt's version-pin syntax (package "linux-tools"
        # at this version), unlike the `-` used above; confirm this is intended.
        'linux-tools=' + version,
    ]
    install_cmd = ['sudo', 'apt-get', 'install', '--no-install-recommends', '-y',
                   '--allow-downgrades'] + packages

    if not args.dry_run:
        run(['sudo', 'apt-get', 'update', '-y'])
        run(['sudo', 'apt-get', 'remove', '-y', 'linux-headers-' + version])
        run(install_cmd)
        return

    logging.info("# Install kernel.")
    logging.info(" ".join(install_cmd))
    sys.exit(0)


def _copy_vmlinuz(mod_root, version, out_dir):
    """Copy ``<mod_root>/boot/vmlinuz-<version>`` to ``<out_dir>/vmlinuz``."""
    # os.path.join rather than string concatenation so --mod-path works with or
    # without a trailing slash.
    src = os.path.join(mod_root, 'boot', 'vmlinuz-' + version)
    shutil.copy2(src, os.path.join(out_dir, 'vmlinuz'))


def _build_modules_sqfs(mod_root, version, build_dir, out_dir):
    """Pack ``<mod_root>/lib/modules/<version>`` into ``<out_dir>/modules.sqfs``."""
    staging_root = os.path.join(build_dir, 'modules')
    # Keep the lib/modules/<version> layout inside the squashfs image.
    modules_dir = os.path.join(staging_root, 'lib', 'modules', version)
    os.makedirs(os.path.dirname(modules_dir))
    run(['cp', '-r', os.path.join(mod_root, 'lib', 'modules', version), modules_dir])
    run(['mksquashfs', staging_root, os.path.join(out_dir, 'modules.sqfs'),
         '-comp', 'lzo', '-noappend'])


def _install_init(initrd_dir, init_script):
    """Install ``/init`` into the initrd: *init_script* if given, else the bundled init.sh."""
    init_dst = os.path.join(initrd_dir, 'init')
    if init_script:
        os.chmod(shutil.copy2(os.path.abspath(init_script), init_dst), 0o755)
        return

    init_data = resource.find('arcadia/infra/kernel/tools/kvm_kernel/init/init.sh')
    # Explicit check instead of `assert` (stripped under `python -O`); also
    # covers a missing resource instead of failing on `None.decode`.
    if not init_data:
        raise RuntimeError('bundled init.sh resource is missing or empty')
    # Write the raw bytes: no utf-8 decode / locale re-encode round trip, and
    # the descriptor is closed deterministically.
    fd = os.open(init_dst, os.O_CREAT | os.O_WRONLY, 0o555)
    with os.fdopen(fd, 'wb') as init_file:
        init_file.write(init_data)


def _install_busybox(initrd_dir):
    """Copy the host busybox into the initrd and populate ``bin/`` with its applets."""
    busybox_host = shutil.which('busybox')
    if busybox_host is None:
        raise RuntimeError("busybox not found in PATH (install 'busybox-static', see install-deps)")
    os.mkdir(os.path.join(initrd_dir, 'bin'))
    busybox_bin = os.path.join(initrd_dir, 'busybox')
    shutil.copy2(busybox_host, busybox_bin)
    run([busybox_bin, '--install', 'bin'], cwd=initrd_dir)


def _strip_root(path, root):
    """Return *path* re-rooted at ``/`` when it lies under absolute *root*, else unchanged."""
    prefix = root.rstrip('/')
    # Only strip a leading directory prefix; str.replace would also rewrite
    # any later occurrence and leave a doubled slash behind.
    if prefix and path.startswith(prefix + '/'):
        return path[len(prefix):]
    return path


def _install_base_modules(initrd_dir, mod_root, version):
    """Copy BASE_MODULES and their dependencies into the initrd and write
    ``init_base_modules.sh`` with one ``insmod`` line per module file, in the
    order ``modprobe --show-depends`` reports them."""
    mod_root_abs = os.path.abspath(mod_root)
    installed = set()
    with open(os.path.join(initrd_dir, 'init_base_modules.sh'), 'w') as script:
        for module in BASE_MODULES:
            deps = run_output(['modprobe', '--show-depends', '-d', mod_root_abs,
                               '--set-version=' + version, module])
            for line in deps.splitlines():
                tokens = line.split()
                # Only "insmod <path> ..." lines name a file to ship; check the
                # length first so short/blank lines cannot raise IndexError.
                if len(tokens) < 2 or tokens[0] != 'insmod':
                    continue
                module_path = tokens[1]
                # A dependency shared by several base modules is shipped and
                # loaded once; a second insmod of it would fail (EEXIST).
                if module_path in installed:
                    continue
                installed.add(module_path)
                dst = _strip_root(module_path, mod_root_abs)
                # Plain concatenation on purpose: dst is absolute, so
                # os.path.join would discard initrd_dir.
                abs_dst = initrd_dir + dst
                os.makedirs(os.path.dirname(abs_dst), exist_ok=True)
                shutil.copy2(module_path, abs_dst)
                script.write('insmod {}\n'.format(dst))
        os.fchmod(script.fileno(), 0o755)


def _pack_initrd(initrd_dir, out_dir):
    """Archive *initrd_dir* as gzip'd newc cpio ``<out_dir>/initrd.img`` and
    hard-link it as ``<out_dir>/kvm-initrd.img``."""
    initrd_img = os.path.join(out_dir, 'initrd.img')
    pipeline = ('find . -print0 | cpio --null --create --format=newc | gzip > '
                + shlex.quote(initrd_img))
    # bash with pipefail: a plain /bin/sh pipeline reports only gzip's status,
    # so a failing find/cpio would silently produce a truncated initrd. The
    # output path is quoted because it derives from the user-supplied `out`.
    run(['bash', '-o', 'pipefail', '-c', pipeline], cwd=initrd_dir)
    run(['ln', '-f', 'initrd.img', 'kvm-initrd.img'], cwd=out_dir)


def build_initrd(args):
    """Build the kernel artifacts for kernel ``args.version``.

    Produces ``vmlinuz``, ``initrd.img`` and ``modules.sqfs`` (only
    ``initrd.img`` with ``args.initrd_only``), taking kernel files from the
    root ``args.mod_path``. If ``args.out`` is an existing directory the
    artifacts are written there (plus a ``kvm-initrd.img`` hard link);
    otherwise they are staged in a temporary directory and archived into
    ``args.out`` with ``tar -a`` (compression chosen from the file suffix).

    ``args.install_deps`` / ``args.install_kernel`` run the respective apt
    steps first. ``args.dry_run`` is only consulted by those steps; the build
    itself does not check it.

    :raises ValueError: if ``args.version`` is ``None``.
    :raises RuntimeError: if busybox or the bundled init.sh is unavailable.
    :raises subprocess.CalledProcessError: if any external command fails.
    """
    if args.initrd_only:
        targets = ['initrd.img']
    else:
        targets = ['vmlinuz', 'initrd.img', 'modules.sqfs']

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.version is None:
        raise ValueError('kernel version is required')

    if args.install_deps:
        install_deps(args)
    if args.install_kernel:
        install_kernel(args)

    with tmpdir(prefix='build') as build_dir:
        out_is_dir = os.path.isdir(args.out)
        if out_is_dir:
            out_dir = os.path.abspath(args.out)
        else:
            # Stage into a scratch directory; it is archived into args.out below.
            out_dir = os.path.abspath(os.path.join(build_dir, 'out'))
            os.mkdir(out_dir)

        if 'vmlinuz' in targets:
            _copy_vmlinuz(args.mod_path, args.version, out_dir)
        if 'modules.sqfs' in targets:
            _build_modules_sqfs(args.mod_path, args.version, build_dir, out_dir)

        # Generate the initrd tree, then pack it.
        initrd_dir = os.path.join(build_dir, 'initrd')
        os.mkdir(initrd_dir)
        _install_init(initrd_dir, args.init)
        _install_busybox(initrd_dir)
        _install_base_modules(initrd_dir, args.mod_path, args.version)
        # Place for custom init scripts
        os.mkdir(os.path.join(initrd_dir, 'scripts'))
        _pack_initrd(initrd_dir, out_dir)

        # Build tar archive if the output is not an existing directory.
        if not out_is_dir:
            run(['tar', '-acf', args.out, '-C', out_dir] + targets)


def main(cmdline=None):
    """Command-line entry point.

    :param cmdline: argument list to parse; defaults to ``sys.argv[1:]``.

    Sub-commands:
      * ``install-deps`` -- install the host build dependencies;
      * ``build`` -- build vmlinuz / initrd.img / modules.sqfs for a kernel version.

    Exits with a usage error (status 2) when no sub-command is given.
    """
    if cmdline is None:
        cmdline = sys.argv[1:]

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(title="Possible extra actions", dest='command')

    deps = subparsers.add_parser(name="install-deps")
    deps.set_defaults(handle=install_deps)
    # Every sub-command needs --quiet: it is read unconditionally below, and
    # without it `install-deps` crashed with AttributeError.
    deps.add_argument("-q", "--quiet", default=False, action='store_true')
    deps.add_argument("--dry-run", default=False, action='store_true')

    build = subparsers.add_parser(name="build")
    build.set_defaults(handle=build_initrd)
    build.add_argument("-q", "--quiet", default=False, action='store_true')
    build.add_argument("--dry-run", default=False, action='store_true')
    build.add_argument("--install-deps", default=False, action='store_true')
    build.add_argument("--install-kernel", default=False, action='store_true')
    build.add_argument("--initrd-only", default=False, action='store_true')
    build.add_argument("-m", "--mod-path", default='/')
    build.add_argument("--init", help='init script')
    build.add_argument("version", help='kernel version')
    build.add_argument("out", help="directory or tarball")
    args = parser.parse_args(cmdline)

    # Sub-parsers are optional by default, so with no sub-command the namespace
    # has neither `quiet` nor `handle`; report a usage error instead of crashing.
    if args.command is None:
        parser.error("a sub-command is required (install-deps or build)")

    logging.basicConfig(level=logging.WARN if args.quiet else logging.INFO)
    args.handle(args)


# Script entry point: parse sys.argv and dispatch to the chosen sub-command.
if __name__ == '__main__':
    main()
