#!/usr/bin/env python
import argparse
import contextlib
import logging
import os
import shlex
import shutil
import subprocess
import sys
import tempfile
import time

import retrying

import library.python.svn_version as sv
import library.python.resource as resource
# from infra.kernel.tools.kvm_kernel.builder.lib.main import build_initrd
import qemu  # noqa
from qemutool.guest.qavm import QAVM

# Guest image tree and bundled ssh key path, relative to the Arcadia root.
TESTVM_ROOT = "infra/qemu/guest_images/qavm"
SSH_KEY_FILE = os.path.join(TESTVM_ROOT, "keys/id_rsa")
# qemu binary location relative to an Arcadia build root.
QEMU_BIN = "infra/qemu/bin/qemu/bin/qemu-system-x86_64"

# ``-i`` alias -> Arcadia directory whose ``ya.make.autoupdate`` resource is used
# as the Sandbox lookup query for the VM image (see get_rootfs).
# An alias of the form "ENV:FLAVOR" selects the "vm-image-FLAVOR" variant.
CACHED_IMG_MAP = {
    # rtc-focal
    "rtc-focal": "infra/environments/rtc-focal/release/vm-image",
    "rtc-focal-gpu": "infra/environments/rtc-focal-gpu/release/vm-image",
    "rtc-focal-gpu:dev": "infra/environments/rtc-focal-gpu/release/vm-image-dev",

    # rtc-bionic
    "rtc-bionic": "infra/environments/rtc-bionic/release/vm-image",
    "rtc-bionic:unstable": "infra/environments/rtc-bionic/release/vm-image-unstable",
    "rtc-bionic:experimental": "infra/environments/rtc-bionic/release/vm-image-experimental",

    # rtc-xenial
    "rtc-xenial": "infra/environments/rtc-xenial/release/vm-image",
    "rtc-xenial:ng": "infra/environments/rtc-xenial/release/vm-image-ng",
    "rtc-xenial:unstable": "infra/environments/rtc-xenial/release/vm-image-unstable",
    "rtc-xenial:experimental": "infra/environments/rtc-xenial/release/vm-image-experimental",

    # rtc-xenial-gpu
    "rtc-xenial-gpu": "infra/environments/rtc-xenial-gpu/release/vm-image",
    "rtc-xenial-gpu:dev": "infra/environments/rtc-xenial-gpu/release/vm-image-dev",

    # rtc-precise
    "rtc-precise": "infra/environments/rtc-precise/release/vm-image",
    "rtc-precise:ng": "infra/environments/rtc-precise/release/vm-image-ng",
    "rtc-precise:unstable": "infra/environments/rtc-precise/release/vm-image-unstable",

    # rtc-trusty
    "rtc-trusty": "infra/environments/rtc-trusty/release/vm-image",
    "rtc-trusty:ng": "infra/environments/rtc-trusty/release/vm-image-ng",
    "rtc-trusty:unstable": "infra/environments/rtc-trusty/release/vm-image-unstable",

    # vanilla images
    "vanilla-focal": "infra/environments/vanilla-focal/release/vm-image",
    "vanilla-bionic": "infra/environments/vanilla-bionic/release/vm-image",
    "vanilla-xenial": "infra/environments/vanilla-xenial/release/vm-image",
    "vanilla-precise": "infra/environments/vanilla-precise/release/vm-image",
    "vanilla-trusty": "infra/environments/vanilla-trusty/release/vm-image",

    # qavm images ("qavm" is the default for -i)
    "qavm-bionic": "infra/environments/qavm-bionic/release/vm-image",
    "qavm-xenial": "infra/environments/qavm-xenial/release/vm-image",
    "qavm-precise": "infra/environments/qavm-precise/release/vm-image",
    "qavm-trusty": "infra/environments/qavm-trusty/release/vm-image",
    "qavm": "infra/environments/qavm/release/vm-image",
}


def get_cachedir(path=""):
    """Return (creating it if needed) the vmexec cache directory, optionally a sub-path of it.

    The cache root is $YA_CACHE_DIR when set and non-empty, otherwise ~/.ya; vmexec
    data lives under ``<root>/vmexec/<path>``.
    """
    root = os.getenv('YA_CACHE_DIR')
    if not root:
        root = os.path.join(os.path.expanduser('~'), '.ya')
    target = os.path.join(root, 'vmexec', path)
    os.makedirs(target, exist_ok=True)
    return target


@contextlib.contextmanager
def tmpdir(**kwargs):
    """Yield a fresh temporary directory that is recursively removed on exit.

    Keyword arguments are forwarded to ``tempfile.mkdtemp`` (prefix, suffix, dir).
    """
    created = tempfile.mkdtemp(**kwargs)
    with contextlib.ExitStack() as cleanup:
        # Registered before yielding, so the tree is removed even if the body raises.
        cleanup.callback(shutil.rmtree, created)
        yield created


def run(args, logger=None, check=True, text=False, stdin=subprocess.DEVNULL, **kwargs):
    """Run a command after tracing it at DEBUG level.

    Args:
        args: argv list; traced as ``+ 'arg0' 'arg1' ...``.
        logger: logger used for the trace; defaults to the "vmexec" logger when omitted.
        check: raise subprocess.CalledProcessError on a non-zero exit status.
        text: decode output as text (passed as ``universal_newlines``).
        stdin: child's stdin; /dev/null by default so the child cannot block on input.
        **kwargs: forwarded unchanged to ``subprocess.run``.

    Returns:
        The ``subprocess.CompletedProcess`` instance.
    """
    # The default used to be dereferenced unconditionally, crashing when no logger was passed.
    log = logger if logger is not None else logging.getLogger("vmexec")
    log.debug("+ '" + "' '".join(str(a) for a in args) + "'")
    return subprocess.run(args, check=check, universal_newlines=text, stdin=stdin, **kwargs)


# Module-wide "vmexec" logger; main() sets up output with logging.basicConfig() on the
# root logger, which this logger propagates to.
logger = logging.getLogger(name="vmexec")


def get_yatool(name):
    """Return the path that ``ya tool NAME --print-path`` prints, decoded as UTF-8 and stripped.

    Raises subprocess.CalledProcessError if ``ya`` exits non-zero.
    """
    command = ['ya', 'tool', name, '--print-path']
    raw = subprocess.check_output(command)
    return raw.decode('utf-8').strip()


def tuple_str(tok):
    """Parse a ``-v/--volume`` value ``SRC[:DST]`` into a ``(src, dst)`` tuple.

    Only the first ':' separates the parts, so "a:b:c" -> ("a", "b:c"). Without a ':'
    the volume is mounted at the same path inside the VM: "a" -> ("a", "a").

    Raises:
        argparse.ArgumentTypeError: if the value, SRC or DST is empty. As an argparse
            ``type=`` callable this surfaces as a normal usage error instead of a later
            tuple-unpacking crash or a mount at an empty path.
    """
    if tok is None:
        return tok
    src, sep, dst = tok.partition(':')
    if not sep:
        dst = src
    if not src or not dst:
        raise argparse.ArgumentTypeError("invalid volume {!r}, expected SRC[:DST]".format(tok))
    return (src, dst)


def init_parser():
    """Build the vmexec command-line parser.

    Arguments the parser does not recognise are returned by ``parse_known_args``; main()
    requires them to start with ``--`` and runs the rest as the command inside the VM.
    """
    parser = argparse.ArgumentParser(description='              -- COMMAND ARGUMENT...\n\nvmexec (rev:%s)' % sv.svn_revision(),
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--boot-timeout', type=int, default=300, help="seconds to wait vm boot: default: %(default)s")
    parser.add_argument('-S', '--arcadia-source', dest='source_root', type=str, help='arcadia source root')
    parser.add_argument('-B', '--arcadia-build', dest='build_root', type=str, help='arcadia build root')
    parser.add_argument('-D', '--arcadia-distbuild', dest='distbuild_root', help='arcadia distbuild root')
    parser.add_argument('--ignore-env', dest='use_env', default=True, action='store_false', help='Ignore environment variables')
    parser.add_argument("-L", "--logdir", help="logdir", default=None)
    parser.add_argument("-I", "--rootfs", help="Path to rootfs image (Arcadia related)")
    # NOTE(review): `choices` limits -i to CACHED_IMG_MAP aliases, so the "arcadia_path:"
    # form handled by get_rootfs() cannot be selected from the command line.
    parser.add_argument("-i", dest="cached_rootfs", help="Precreated rootfs", default="qavm", choices=CACHED_IMG_MAP.keys())
    parser.add_argument("--mem", help="Amount of memory for running VM", default="2G")
    # type=int so a user-supplied value has the same type as the default (and is validated).
    parser.add_argument("--vcpus", type=int, help="Amount of cores for running VM", default=2)
    parser.add_argument("--kernel", dest='kernel', type=str, default=None, help="linux kernel")
    parser.add_argument("--initrd", dest='initrd', type=str, default=None, help="linux initrd")
    parser.add_argument("--modules", dest='modules', type=str, default=None, help="linux modules image or dir")
    # NOTE(review): args.rwdelta is not read anywhere in this file.
    parser.add_argument('--rwdelta', dest='rwdelta', default=False, action='store_true', help='Use kvm-init boot scheme')
    parser.add_argument("--append", dest='append', type=str, default=None, help="linux kernel command line")
    # NOTE(review): store_false with default=True means passing --gen-initrd turns it *off*,
    # contrary to its help text; args.gen_initrd is not read anywhere in this file.
    parser.add_argument('--gen-initrd', default=True, action='store_false', help='generate initrd')
    parser.add_argument("-K", "--kernel-release-dir", dest="krelease_dir", default=None, help="Kernel release dir (with vmlinuz, modules.sqfs, kvm-initrd.img)")
    parser.add_argument("--tmpfs", help="Create and mount tmpfs dir inside VM")
    parser.add_argument("-v", "--volume", dest='volumes', type=tuple_str, action='append', default=[], help="Mount volume")
    parser.add_argument("--volume-security", help="virtfs security option", default="none", choices=['none', 'mapped-file', 'mapped-xattr', 'passthrough'])
    parser.add_argument("-w", "--workdir", dest='workdir', type=str, default=os.getcwd(), help="Working directory inside the VM")
    parser.add_argument("--ssh-key", help="Path to ssh key to access VM (Arcadia related)")
    parser.add_argument('--qemu-bin', type=str, help='custom qemu binary')
    parser.add_argument("--qemu-opts", type=str, help="custom qemu options")
    parser.add_argument('--trap', '--shell', dest='trap', default=False, action='store_true', help='Start vm, and wait for interactive actions')

    return parser


def _setup_arc_env(args):
    if not args.source_root:
        args.source_root = os.environ.get('ARCADIA_SOURCE_ROOT')
    if not args.source_root:
        args.source_root = os.environ.get('ARCADIA_ROOT')

    if not args.build_root:
        args.build_root = os.environ.get('ARCADIA_BUILD_ROOT')
    if not args.distbuild_root:
        args.distbuild_root = os.environ.get('ARCADIA_ROOT_DISTBUILD')


def _do_virtfs_opt(src, dst, security, tag, env=None):
    # use security_model=mapped-file for compatibility with porto layer exporting
    # porto sets owner/mode on exported archive and gets EPERM with security_model=none
    opt = ["-virtfs", "local,path={},mount_tag={},security_model={}".format(src, tag, security)]
    return (opt, dst, tag, env)


def _prep_virtfs(args):
    """Collect every virtfs share for the VM as ``(qemu_args, guest_path, tag, env_name)`` tuples.

    The distbuild, build and source roots (whichever are set, in that order) are each shared
    at the same path in the guest and tagged/exported under their environment variable name.
    Then each ``-v`` volume follows, tagged ``volume-1``, ``volume-2``, ... with no env export.
    """
    roots = (
        (args.distbuild_root, 'ARCADIA_ROOT_DISTBUILD'),
        (args.build_root, 'ARCADIA_BUILD_ROOT'),
        (args.source_root, 'ARCADIA_ROOT'),
    )
    shares = [_do_virtfs_opt(root, root, args.volume_security, var, var)
              for root, var in roots if root]
    for number, (src, dst) in enumerate(args.volumes, start=1):
        shares.append(_do_virtfs_opt(src, dst, args.volume_security, 'volume-{}'.format(number), None))
    return shares


def _mount_virtfs(vm, path, tag, env_name=None):
    """Mount the 9p share *tag* inside the guest, running root commands over ssh.

    The share is mounted at ``os.path.realpath(path)`` (resolved on the host). When that
    differs from *path* (a relative path or a host symlink), a guest symlink *path* ->
    real path is created as well. When *env_name* is given, ``export ENV_NAME=path`` is
    appended to /vmexec_env.sh so the guest script can locate the mount.

    Every path is shell-quoted, so names with spaces or shell metacharacters are handled.
    """
    real_path = os.path.realpath(path)
    q_real = shlex.quote(real_path)
    q_path = shlex.quote(path)
    vm.ssh_root_check("mkdir -p {}".format(q_real))
    if real_path != path:
        parent = os.path.dirname(path)
        # A bare relative name has no parent component; `mkdir -p` with no operand fails.
        if parent:
            vm.ssh_root_check("mkdir -p {}".format(shlex.quote(parent)))
        vm.ssh_root_check("ln -s {} {}".format(q_real, q_path))
    vm.ssh_root_check("mount -t 9p {} {} -o trans=virtio,rw".format(shlex.quote(tag), q_real))
    if env_name:
        # Quote the path inside the export line, then quote the whole line for echo.
        export_line = "export {}={}".format(env_name, q_path)
        vm.ssh_root_check("echo {} | tee -a /vmexec_env.sh".format(shlex.quote(export_line)),
                          stdout=sys.stderr, stderr=sys.stderr)


def _read_key(fname, key_name):
    if fname:
        assert os.path.exists(fname)
        with open(fname) as f:
            return f.read().decode("utf-8")
    else:
        return resource.find('arcadia/infra/qemu/guest_images/qavm/keys/' + key_name).decode("utf-8")


def get_qemu_binary(args):
    """Locate the qemu-system-x86_64 binary for the VM.

    Resolution order: ``--qemu-bin``; ``<build_root>/QEMU_BIN`` if that file exists; the
    path printed by ``ya tool qemu``.

    Returns:
        The binary path, or '' (with a warning logged) if ``ya tool qemu`` reports a path
        that does not exist.

    Exits the process with status 1 if ``ya tool qemu`` cannot be run.
    """
    if args.qemu_bin:
        return args.qemu_bin
    if args.build_root:
        built = os.path.join(args.build_root, QEMU_BIN)
        if os.path.exists(built):
            return built
    try:
        path = get_yatool('qemu')
    except Exception as e:
        logger.error("Can't find qemu-binary, err: {}".format(str(e)))
        sys.exit(1)
    logger.debug('qemu_path: {}'.format(path))
    if os.path.exists(path):
        return path
    # Previously this fell through silently, leaving the VM to fail with an empty binary path.
    logger.warning("qemu binary reported by 'ya tool qemu' does not exist: {}".format(path))
    return ''


def get_rootfs(args):
    """Resolve the rootfs image to boot.

    Order:
      1. ``-I/--rootfs``, joined onto the build root when one is known;
      2. the qavm image inside the build root, when a build root is known;
      3. the Sandbox image for ``-i`` (a CACHED_IMG_MAP alias or ``arcadia_path:<dir>``),
         cached as a symlink under the vmexec cache dir (see _fetch_cached_rootfs).

    Returns:
        The image path, or '' on failure (the reason is logged).
    """
    if args.rootfs:
        if args.build_root:
            return os.path.join(args.build_root, args.rootfs)
        return args.rootfs
    # Default for vmexec w/o -I arguments
    if args.build_root:
        return os.path.join(args.build_root, "infra/environments/qavm/release/vm-image/rootfs.img")
    try:
        return _fetch_cached_rootfs(args.cached_rootfs)
    except Exception as e:
        logger.error("Can not get rootfs resource, error: {}".format(str(e)))
    return ''


def _fetch_cached_rootfs(cached_rootfs):
    """Return a rootfs path for the ``-i`` value, refreshing it from Sandbox when the cache is stale.

    The cache entry is a ``rootfs.img`` symlink that is trusted for 24 hours (by its own
    ctime). Returns '' when the alias, its query resource, or a matching Sandbox resource
    is unknown; subprocess and filesystem errors propagate to the caller.
    """
    prefix = "arcadia_path:"
    if cached_rootfs.startswith(prefix):
        # Arcadia-relative directory holding ya.make.autoupdate. Stripping slashes keeps
        # the cache path inside the cache root and matches the relative resource keys.
        sb_path = cached_rootfs[len(prefix):].strip('/')
    elif cached_rootfs in CACHED_IMG_MAP:
        sb_path = CACHED_IMG_MAP[cached_rootfs]
    else:
        logger.error("cached_rootfs :{} is unknown".format(cached_rootfs))
        return ''

    cache_dir = get_cachedir(os.path.join("cache", sb_path))
    cache_file = os.path.join(cache_dir, 'rootfs.img')
    # os.path.exists follows the link: a link whose target vanished counts as a miss.
    if os.path.exists(cache_file):
        st = os.stat(cache_file, follow_symlinks=False)
        if time.time() - st.st_ctime < 86400:
            ret = os.readlink(cache_file)
            logger.debug('Use cached rootfs: {}'.format(ret))
            return ret
    # Remove a potentially dangling or outdated symlink before re-creating it.
    with contextlib.suppress(FileNotFoundError):
        os.unlink(cache_file)

    # Cache is outdated, fetch rootfs.img from sandbox
    jq_resource = os.path.join(sb_path, 'ya.make.autoupdate')
    jdata = resource.find(jq_resource)
    # `not jdata` guards both an empty result and a missing-key result (presumably None).
    if not jdata:
        logger.error("Unknown resource file: {}".format(jq_resource))
        return ""
    query = jdata.decode("utf-8")
    with tmpdir() as tdir:
        sbctl = get_yatool('sandboxctl')
        jq_file = os.path.join(tdir, 'jq.json')
        logger.debug("lookup resource {}".format(query))
        with open(jq_file, 'wb') as f:
            f.write(jdata)

        out = subprocess.check_output([sbctl, 'list_resource', '--limit', '1', '-q', '--jq', jq_file])
        resource_id = out.decode('utf-8').strip()
        if not resource_id:
            logger.error("Can not find sandbox resource, query:{}".format(query))
            return ""

        out = subprocess.check_output([sbctl, 'get_resource', '-q', resource_id])
        resource_path = out.decode('utf-8').strip()
        logger.debug("Resource id:{} path:{}".format(resource_id, resource_path))
        os.symlink(resource_path, cache_file)
        logger.debug("New cached rootfs :{}".format(cache_file))
        return cache_file


def main():
    """Boot a QAVM guest, share Arcadia roots and volumes into it, and run the command after ``--``.

    Exits with status 2 on bad command-line usage (argparse), and with status 1 when the
    ssh keys are empty, the rootfs image is missing, or the command fails inside the guest.
    Boot and guest-configuration errors are logged and re-raised.
    """
    parser = init_parser()
    args, sandbox_args = parser.parse_known_args()

    # Unrecognised arguments must be introduced by '--'; everything after it is the guest command.
    if sandbox_args and sandbox_args[0] != '--':
        parser.error('unexpected arguments:{}, use -- as separator'.format(sandbox_args))
    sandbox_args = sandbox_args[1:]

    if args.debug:
        loglevel = logging.DEBUG
    elif args.verbose:
        loglevel = logging.INFO
    else:
        loglevel = logging.WARNING

    logging.basicConfig(level=loglevel,
                        stream=sys.stderr,
                        format="%(asctime)s [%(levelname)-5.5s] %(name)s: %(message)s")

    if not sandbox_args:
        sandbox_args = ['echo', 'hello from vm']

    if args.use_env:
        _setup_arc_env(args)
    if not args.source_root:
        logger.info('ARCADIA_ROOT not set, please pass --arcadia-source, or set ARCADIA_ROOT env')
    if not args.build_root:
        logger.info('ARCADIA_BUILD_ROOT not set, please pass --arcadia-build, or set ARCADIA_BUILD_ROOT env')

    logger.debug("source_root :{}".format(args.source_root))
    logger.debug("build_root :{}".format(args.build_root))
    logger.debug("distbuild_root :{}".format(args.distbuild_root))

    ssh_key_data = _read_key(args.ssh_key, 'id_rsa')
    ssh_pubkey_data = _read_key(args.ssh_key, 'id_rsa.pub')
    # Explicit check rather than assert: asserts are stripped under `python -O`.
    if not ssh_key_data or not ssh_pubkey_data:
        logger.error("ssh key pair is empty (key: {})".format(args.ssh_key or 'bundled qavm key'))
        sys.exit(1)

    qemu_bin = get_qemu_binary(args)
    rootfs = get_rootfs(args)

    if not os.path.exists(rootfs):
        logger.error("rootfs image: {} not found".format(rootfs))
        sys.exit(1)

    name = QAVM.gen_name()
    if not args.logdir:
        args.logdir = os.path.join(get_cachedir(), name)
        os.makedirs(args.logdir, exist_ok=True)
    serial_log = os.path.join(args.logdir, "serial.log")

    vm = QAVM(qemu_bin=qemu_bin,
              name=name,
              ssh_key=ssh_key_data,
              ssh_pub_key=ssh_pubkey_data,
              vcpus=args.vcpus,
              mem=args.mem,
              rootfs=rootfs,
              wdir=args.logdir,
              serial_log=serial_log,
              logger=logger.getChild('qemu'))

    vopt = _prep_virtfs(args)
    logger.debug("Create VM %s" % vm.name)
    boot_opt = ['-cpu', 'host']
    for qemu_args, _dst, _tag, _env in vopt:
        boot_opt += qemu_args
    if args.qemu_opts:
        boot_opt += args.qemu_opts.split()

    rw_snapshot = None
    rw_snapshot_idx = None
    kargs = ""

    if args.krelease_dir:
        args.kernel = os.path.join(args.krelease_dir, 'vmlinuz')
        args.initrd = os.path.join(args.krelease_dir, 'kvm-initrd.img')
        args.modules = os.path.join(args.krelease_dir, 'modules.sqfs')

    if args.modules:
        if os.path.isdir(args.modules):
            boot_opt += ["-virtfs", "local,path={},mount_tag=modules,security_model=none,readonly".format(args.modules)]
            kargs = "root=/dev/vda  modules_dev=9p:modules rwdelta=/dev/vdb"
            rw_snapshot_idx = 1
        else:
            boot_opt += ["-drive", "file={},if=virtio,index=1,snapshot".format(args.modules)]
            kargs = "root=/dev/vda  modules_dev=/dev/vdb rwdelta=/dev/vdc"
            rw_snapshot_idx = 2

        kargs += " console=tty1 console=ttyS0,115200n8 loglevel=7 oops=panic panic=60 biosdevname=0 net.ifnames=0"
        rw_snapshot = os.path.join(args.logdir, 'rw_delta.img')
        if args.append:
            # Separate the user's --append text from the built-in arguments.
            kargs += " " + args.append
        args.append = kargs

    if rw_snapshot:
        boot_opt += ["-drive", "file={},if=virtio,index={}".format(rw_snapshot, rw_snapshot_idx)]
        run(['mkfs.ext4', '-q', '-F', '-O', '^has_journal,^64bit,^metadata_csum', rw_snapshot, '20G'], logger=logger)

    try:
        vm.boot(extra_args=boot_opt, kernel=args.kernel, initrd=args.initrd, kargs=args.append)
    except Exception as e:
        logger.error('Fail to boot vm: %s', e)
        raise
    finally:
        # Unlink the scratch image once boot() returns or fails.
        # NOTE(review): assumes qemu keeps its own open descriptor to it — confirm.
        if rw_snapshot:
            os.unlink(rw_snapshot)
    try:
        vm.wait_ssh(debug=args.debug, timeout=args.boot_timeout)
    except Exception as e:
        logger.error('Fail to boot vm: %s', e)
        raise

    try:
        vm.ssh_root_check('hostname {}'.format(vm.name))
        vm.ssh_root_check("touch /vmexec_env.sh")

        # set up proper default gateway for ipv6 net with 'scope global' addresses
        @retrying.retry(stop_max_delay=30000, wait_fixed=500)
        def _add_ipv6_def_route():
            vm.ssh_root_check("ip -6 route add default via fdc::1 metric 99")

        _add_ipv6_def_route()

        # set up dns servers properly for VMs without keep_resolv_conf hook needed by QYP
        vm.ssh_root_check("if grep -q 10.0.2.3 /etc/resolv.conf; then printf 'nameserver 2a02:6b8:0:3400::1\nnameserver 2a02:6b8::1:1\n' > /etc/resolv.conf; fi")

        for _qemu_args, dst, tag, env in vopt:
            _mount_virtfs(vm, dst, tag, env)

        if args.tmpfs:
            tmpfs = shlex.quote(args.tmpfs)
            vm.ssh_root_check("mkdir -p {}".format(tmpfs))
            vm.ssh_root_check("mount -t tmpfs none {}".format(tmpfs))
            vm.ssh_root_check("echo {} | tee -a /vmexec_env.sh".format(shlex.quote("export TMPFS={}".format(tmpfs))),
                              stdout=sys.stderr, stderr=sys.stderr)
        # The shebang must be the very first line of the file, so no leading newline or
        # indentation. The command words are joined unquoted on purpose, so shell
        # constructs passed after '--' (pipes, &&) keep working.
        script = "\n".join([
            "#!/bin/bash",
            "set -e",
            ". /vmexec_env.sh",
            "cd {}".format(shlex.quote(args.workdir)),
            " ".join(sandbox_args),
        ])
        # shlex.quote survives single quotes in the script; printf '%s\n' does not
        # interpret backslashes the way some shells' echo does.
        vm.ssh_root_check("printf '%s\\n' {} | tee -a /vmexec.sh".format(shlex.quote(script)),
                          stdout=sys.stderr, stderr=sys.stderr)
        vm.ssh_root_check("chmod +x /vmexec.sh")
    except Exception as e:
        logger.error('Fail to config vm: %s', e)
        raise

    if args.trap:
        logger.info("vmexec script paused, you are in interactive ssh")
        ret = vm.ssh_interactive(user='root')
        if ret:
            sys.exit(ret)

    # NOTE(review): nothing in this function stops the VM; confirm QAVM/qemu is torn down on exit.
    try:
        vm.ssh_root_check("/vmexec.sh", stdout=sys.stdout, stderr=sys.stderr)
    except Exception as e:
        logger.error('vmexec script failed: %s', e)
        sys.exit(1)


if __name__ == '__main__':
    main()
