from __future__ import print_function
from __future__ import division
import atexit
import datetime
import hashlib
import logging
import os
import requests
import sys
import subprocess
import time
import tempfile
import yaml
import qemu


class QVMSetupError(Exception):
    """Raised when the VM cannot be prepared for use.

    In this module it signals that the forwarded SSH port could not be
    found in the ``info usernet`` monitor output after the guest started.
    """


class QVMRuntimeError(Exception):
    """Raised when an operation on a running guest fails.

    In this module it covers a checked SSH command returning a non-zero
    status and the guest's SSH service not answering before a timeout.
    """


def format_script(cmd_list):
    """Render an argv list as a readable, shell-like multi-line string.

    Every token is followed by a single space. A token that starts with
    ``-`` is additionally preceded by a backslash-newline continuation and
    an eight-space indent, so each option begins on its own line.

    :param cmd_list: list of string tokens (must be a ``list``).
    :return: the formatted command string (ends with a trailing space
        unless ``cmd_list`` is empty).
    """
    assert isinstance(cmd_list, list)
    continuation = "\\\n        "
    pieces = []
    for token in cmd_list:
        lead = continuation if token.startswith("-") else ""
        pieces.append(lead + token + " ")
    return "".join(pieces)


def guess_syslib(qemu_bin):
    """Guess the QEMU data directory relative to the QEMU binary.

    Workaround for the arcadia deploy technique, where the QEMU bundle is
    untarred as a tree of symlinks and QEMU cannot resolve its own system
    library path.

    :param qemu_bin: path to the QEMU binary.
    :return: the first candidate directory that contains ``bios-256k.bin``,
        or an empty string when none does.
    """
    bin_dir = os.path.dirname(qemu_bin)
    for rel_path in ('../share/qemu',):
        candidate = os.path.join(bin_dir, rel_path)
        bios_image = os.path.join(candidate, 'bios-256k.bin')
        if os.path.exists(bios_image):
            return candidate
    return ''


class BaseVM(object):
    """Boot a QEMU/KVM guest and run commands in it over SSH.

    The constructor only assembles the QEMU command line and writes the SSH
    key pair into the working directory; the guest is started by ``boot()``.
    SSH reaches the guest through a user-mode network host forward on
    127.0.0.1; the host port is not fixed in the command line (``0`` in the
    ``hostfwd`` spec) and is read back from the ``info usernet`` monitor
    command after launch.

    ``guest_user`` starts out as ``None``; call ``set_auth()`` or
    ``set_qa_auth()`` (or pass ``user`` explicitly) before using SSH helpers
    that rely on the default user.
    """

    name = "#BaseVM"
    LINUX_CMDLINE = "root=LABEL=rootfs rootfstype=ext4 rootflags=errors=panic ro console=tty1 console=ttyS0,115200n8 loglevel=7 oops=panic panic=60 biosdevname=0 net.ifnames=0"

    def __init__(self, qemu_bin, ssh_key, ssh_pub_key,
                 vcpus=2, mem="2G", wdir=None, rootfs=None, cloud_init=None,
                 logger=None, serial_log=None, ssh_stdout=None, ssh_stderr=None):
        """Prepare QEMU arguments and SSH key files.

        :param qemu_bin: path to the QEMU binary.
        :param ssh_key: private key text, written to ``<wdir>/id_rsa``.
        :param ssh_pub_key: public key text, written to ``<wdir>/id_rsa.pub``.
        :param vcpus: value for ``-smp``.
        :param mem: value for ``-m``.
        :param wdir: working directory for key files (default: current dir).
        :param rootfs: root disk image (default: ``<name>.img``).
        :param cloud_init: optional mapping ``filename -> content``; each
            entry is dumped as YAML into a temporary directory at boot.
        :param logger: logger to use (default: module logger).
        :param serial_log: file for the guest serial console; ``None`` sends
            the serial console to stdio.
        :param ssh_stdout: default stdout for SSH commands.
        :param ssh_stderr: default stderr for SSH commands.
        :raises qemu.QEMUMachineError: if KVM is not available.
        """
        if not qemu.kvm_available():
            # The original built this exception but never raised it.
            raise qemu.QEMUMachineError('kvm is not available')

        self.guest_user = None
        self._guest = None
        self._qemu_bin = qemu_bin
        self._wdir = wdir or os.getcwd()
        self._rootfs = rootfs or self.name + ".img"
        self.cloud_init = cloud_init
        self._cloud_init_dir = None
        self.ssh_key = ssh_key
        self.ssh_pub_key = ssh_pub_key

        self._ssh_pub_key_file = os.path.join(self._wdir, "id_rsa.pub")
        with open(self._ssh_pub_key_file, "w") as f:
            f.write(self.ssh_pub_key)

        self._ssh_key_file = os.path.join(self._wdir, "id_rsa")
        # Create the private key with restrictive permissions from the start
        # so it is never readable by others, then chmod in case the file
        # already existed with a wider mode.
        fd = os.open(self._ssh_key_file,
                     os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(self.ssh_key)
        os.chmod(self._ssh_key_file, 0o600)

        self.logger = logger or logging.getLogger(__name__)
        self.ssh_stdout = ssh_stdout
        self.ssh_stderr = ssh_stderr

        self._args = ['-nodefaults']
        syslib = guess_syslib(qemu_bin)
        if syslib:
            self._args += ['-L', syslib]
        self._args += [
            "-enable-kvm",
            "-smp", str(vcpus),
            "-m", mem,
            "-netdev", "user,id=vnet,hostfwd=:127.0.0.1:0-:22,ipv6-net=fdc::/64,ipv6-host=fdc::1",
            "-device", "virtio-net-pci,netdev=vnet",
            "-device", "virtio-rng-pci"]

        self.serial_log = serial_log
        if serial_log is None:
            self._args += ['-serial', 'stdio']
        else:
            self._args += ['-serial', 'file:' + serial_log]

        self._data_args = []

    def set_auth(self, user, passwd, root_passwd):
        """Set guest credentials; falsy values fall back to qemu/qemupass."""
        self.guest_user = user or "qemu"
        self.guest_pass = passwd or "qemupass"
        self.root_pass = root_passwd or "qemupass"

    def set_qa_auth(self, user="qemu", passwd="qemupass"):
        """Set guest credentials using the same password for user and root."""
        self.set_auth(user, passwd, passwd)

    def _make_ssh_cmd(self, cmd, user=None, force_term=False, debug=False, timeout=60):
        """Build the ssh argv for running ``cmd`` in the guest.

        :param cmd: a single command string, a sequence of tokens, or
            ``None`` for no remote command.
        :param user: login name (default: ``self.guest_user``).
        :param force_term: pass ``-t`` instead of ``-T``.
        :param debug: pass ``-vvv`` instead of ``-q``.
        :param timeout: ``ConnectTimeout``; a third of it is used as
            ``ServerAliveInterval``.
        :return: list of argv tokens.
        """
        ssh_cmd = ["/usr/bin/ssh",
                   "-F", "/dev/null",
                   "-vvv" if debug else "-q",
                   "-t" if force_term else "-T",
                   "-o", "ConnectTimeout={}".format(timeout),
                   "-o", "ServerAliveInterval={}".format(timeout // 3),
                   "-o", "ServerAliveCountMax=3",
                   "-o", "StrictHostKeyChecking=no",
                   "-o", "UserKnownHostsFile=" + os.devnull,
                   "-p", str(self.ssh_port),
                   "-i", self._ssh_key_file,
                   "-l", user or self.guest_user,
                   "127.0.0.1"]

        if cmd is None:
            cmd = []
        elif isinstance(cmd, str):
            cmd = [cmd]

        return ssh_cmd + list(cmd)

    def _ssh_cmd(self, cmd, user=None, check=False,
                 force_term=False, debug=False, timeout=60,
                 stdin=None, stdout=None, stderr=None):
        """Run ``cmd`` in the guest over ssh and return its exit status.

        ``stdout`` / ``stderr`` default to the streams given to the
        constructor.

        :raises QVMRuntimeError: if ``check`` is true and the status is
            non-zero.
        """
        ssh_cmd = self._make_ssh_cmd(cmd, user=user, force_term=force_term, debug=debug, timeout=timeout)
        self.logger.debug("ssh_cmd: %s", format_script(ssh_cmd))
        if stdout is None:
            stdout = self.ssh_stdout
        if stderr is None:
            stderr = self.ssh_stderr
        r = subprocess.call(ssh_cmd, stdin=stdin, stdout=stdout, stderr=stderr)
        if check and r != 0:
            raise QVMRuntimeError("SSH command failed: %s" % (cmd,))
        return r

    def ssh(self, cmd, stdin=None, stdout=None, stderr=None):
        """Run ``cmd`` as the guest user; return the exit status."""
        return self._ssh_cmd(cmd, stdin=stdin, stdout=stdout, stderr=stderr)

    def ssh_interactive(self, cmd=None, user=None):
        """Run ``cmd`` (or a login shell) with a tty on this process's stdio."""
        return self._ssh_cmd(cmd, user=user, force_term=True, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr)

    def ssh_root(self, cmd, stdin=None, stdout=None, stderr=None):
        """Run ``cmd`` as root; return the exit status."""
        return self._ssh_cmd(cmd, user='root', stdin=stdin, stdout=stdout, stderr=stderr)

    def ssh_check(self, cmd, stdin=None, stdout=None, stderr=None):
        """Run ``cmd`` as the guest user; raise QVMRuntimeError on failure."""
        self._ssh_cmd(cmd, check=True, stdin=stdin, stdout=stdout, stderr=stderr)

    def ssh_root_check(self, cmd, stdin=None, stdout=None, stderr=None):
        """Run ``cmd`` as root; raise QVMRuntimeError on failure."""
        self._ssh_cmd(cmd, user='root', check=True, stdin=stdin, stdout=stdout, stderr=stderr)

    def boot(self, rootfs=None, rootfs_snapshot=True, kernel=None, initrd=None, kargs=None, extra_args=None):
        """Launch the guest and discover the forwarded SSH port.

        :param rootfs: root disk image (default: the constructor's rootfs).
        :param rootfs_snapshot: append ``snapshot=on`` to the root drive.
        :param kernel: optional kernel for direct kernel boot.
        :param initrd: optional initrd (only used together with ``kernel``).
        :param kargs: kernel command line for direct kernel boot
            (default: ``LINUX_CMDLINE``).
        :param extra_args: extra QEMU arguments appended after the root drive.
        :raises QVMSetupError: if the SSH host-forward port cannot be found.
        """
        if not rootfs:
            rootfs = self._rootfs

        assert rootfs
        if rootfs_snapshot:
            rootfs += ",snapshot=on"

        args = self._args + ["-nographic"]
        rootfs_boot_opt = ",bootindex=0"
        if kernel:
            args += ["-kernel", kernel]
            # Disable boot index in case of direct kernel boot
            rootfs_boot_opt = ""
            if initrd:
                args += ["-initrd", initrd]
            if kargs:
                args += ["-append", kargs]
            else:
                args += ["-append", BaseVM.LINUX_CMDLINE]

        # Add rootfs options
        args += ["-drive", "file=%s,if=none,id=drive0,cache=writeback,discard=unmap" % rootfs,
                 "-device", "virtio-blk,drive=drive0%s" % rootfs_boot_opt]

        if extra_args:
            args += list(extra_args)
        if self.cloud_init:
            # Expose the cloud-init files to the guest as a read-only FAT
            # drive labelled "cidata".
            self._cloud_init_dir = tempfile.TemporaryDirectory(prefix='cloud-init_')
            args += ['-drive', 'file=fat:' + self._cloud_init_dir.name +
                     ',if=virtio,file.label=cidata,readonly=on']
            for filename, content in self.cloud_init.items():
                with open(os.path.join(self._cloud_init_dir.name, filename), 'w') as f:
                    if filename == 'user-data':
                        f.write('#cloud-config\n')
                    f.write(yaml.dump(content))

        self.logger.info("QEMU command:")
        self.logger.info("%s", format_script([self._qemu_bin] + args))
        guest = qemu.QEMUMachine(binary=self._qemu_bin, args=args, name=self.name)
        try:
            guest.launch()
        except BaseException:
            # Log diagnostics for any failure (including interrupts), then
            # propagate the original exception unchanged.
            self.logger.error("VM start failed:")
            self.logger.error("%s", format_script([self._qemu_bin] + args))
            self.logger.error("guest log:")
            self.logger.error(guest.get_log())
            raise

        self._guest = guest
        atexit.register(self.shutdown)

        resp = guest.qmp("human-monitor-command",
                         command_line="info usernet")
        self.ssh_port = None
        for line in resp["return"].splitlines():
            fields = line.split()
            # The fourth column of a host-forward row is taken as the host port.
            if "TCP[HOST_FORWARD]" in fields and "22" in fields and len(fields) > 3:
                self.ssh_port = fields[3]
        if not self.ssh_port:
            raise QVMSetupError("Cannot find ssh port from 'info usernet': %s" % (resp,))

        self.logger.info("SSH port: %s", self.ssh_port)

        self.logger.info("SSH command:")
        self.logger.info("%s", format_script(self._make_ssh_cmd([], force_term=True, user="root")))

    def wait_ssh(self, debug=False, timeout=300):
        """Poll the guest with ``ssh true`` once a second until it answers.

        :param debug: run ssh with ``-vvv``.
        :param timeout: overall deadline in seconds.
        :raises QVMRuntimeError: if the guest does not answer in time.
        """
        end = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
        while datetime.datetime.now() < end:
            if self._ssh_cmd(['true'], debug=debug, timeout=10) == 0:
                return
            remain = (end - datetime.datetime.now()).total_seconds()
            self.logger.debug("%ds before timeout", remain)
            time.sleep(1)
        raise QVMRuntimeError("Timeout while waiting for guest ssh, timeout:%d" % timeout)

    def shutdown(self):
        """Shut the guest down if it was started."""
        if self._guest:
            self._guest.shutdown()

    def wait(self):
        """Delegate to the guest object's ``wait()``."""
        self._guest.wait()

    def qmp(self, *args, **kwargs):
        """Forward a QMP command to the guest and return its response."""
        return self._guest.qmp(*args, **kwargs)

    def _fetch_img_with_cache(self, url, sha1sum=None):
        """Download ``url`` into ``~/.cache/img/download`` and return the path.

        The cache file name is the SHA-1 of the URL. A cached file is reused
        when ``sha1sum`` is not given or matches the file's SHA-1; otherwise
        the URL is downloaded to a ``.download`` file, verified, and renamed
        into place.

        :param url: image URL (text or bytes).
        :param sha1sum: optional expected hex SHA-1 of the file (text or bytes).
        :return: path of the cached file.
        :raises QVMSetupError: if the downloaded file does not match
            ``sha1sum``.
        """
        # Normalise text/bytes once: hashlib needs bytes, the digest
        # comparison needs text.
        url_bytes = url if isinstance(url, bytes) else url.encode("utf-8")
        expected = sha1sum.decode("ascii") if isinstance(sha1sum, bytes) else sha1sum

        def file_sha1(fname):
            digest = hashlib.sha1()
            with open(fname, 'rb') as f:
                for chunk in iter(lambda: f.read(1 << 20), b''):
                    digest.update(chunk)
            return digest.hexdigest()

        def check(fname):
            if not expected:
                return True
            return file_sha1(fname) == expected.strip().lower()

        def _do_fetch(url, out):
            # NOTE the stream=True parameter below
            with requests.get(url, stream=True) as r:
                # Do not store an HTTP error page as the image.
                r.raise_for_status()
                with open(out, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        # filter out keep-alive new chunks
                        if chunk:
                            f.write(chunk)

        cache_dir = os.path.expanduser("~/.cache/img/download")
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        fname = os.path.join(cache_dir, hashlib.sha1(url_bytes).hexdigest())
        if os.path.exists(fname) and check(fname):
            self.logger.debug("Found good file in cache %s", fname)
            return fname

        tmp_name = fname + ".download"
        self.logger.debug("Fetch %s to %s...", url, fname)
        _do_fetch(url, tmp_name)
        if not check(tmp_name):
            os.remove(tmp_name)
            raise QVMSetupError("Checksum mismatch for %s: expected sha1 %s" % (url, expected))
        self.logger.debug("Save to cache %s -> %s", url, fname)
        os.rename(tmp_name, fname)
        return fname
