#!/usr/bin/python

import Queue
import logging
import os
import sys
import time
import uuid
import shutil

from threading import Thread, Lock

from porto.exceptions import (SocketError as PortoSocketError,
                              ContainerDoesNotExist as PortoContainerNotExists)
from google.protobuf import json_format
from infra.vmagent.src.vmagent_pb import vmagent_pb2, vmagent_api_pb2

from . import config as server_config
from . import helpers
from .helpers import log_trace, log_msg, SkynetError

# One gibibyte in bytes. _replace_disk_image subtracts it from the allocated
# storage size as headroom when computing the usable disk size.
GIGA = 1024 * 1024 * 1024


def _init_state(state_type, msg=""):
    """Create a new ``VMState`` message of type ``state_type``.

    The ``info`` field is populated only when ``msg`` is non-empty; an empty
    ``msg`` leaves ``info`` unset.
    """
    field_values = {'type': state_type}
    if msg:
        field_values['info'] = msg
    return vmagent_pb2.VMState(**field_values)


class BadAction(Exception):
    """Raised when a requested VM action or pushed config is rejected."""


# NOTE: this name shadows the builtin RuntimeError inside this module. The
# base-class expression below is evaluated before the new name is bound, so
# it refers to the builtin. Deriving from it keeps this error catchable both
# as this module's RuntimeError and as the builtin RuntimeError.
class RuntimeError(RuntimeError):
    """Raised when the worker cannot perform a required runtime operation."""


class DownloadingImageError(Exception):
    """Raised when fetching or validating the VM disk image fails."""


# Template states without an info message. They are passed to
# VMWorker._set_state, which copies them (CopyFrom) into the worker's own
# VMState rather than storing these objects.
EMPTY = _init_state(vmagent_pb2.VMState.EMPTY)
CONFIGURED = _init_state(vmagent_pb2.VMState.CONFIGURED)
STOPPED = _init_state(vmagent_pb2.VMState.STOPPED)
RUNNING = _init_state(vmagent_pb2.VMState.RUNNING)
BUSY = _init_state(vmagent_pb2.VMState.BUSY)
PREPARING = _init_state(vmagent_pb2.VMState.PREPARING)
CRASHED = _init_state(vmagent_pb2.VMState.CRASHED)
# State types from which VMWorker._check_qemu may autorun qemu.
READY_TO_START = frozenset((vmagent_pb2.VMState.CONFIGURED, vmagent_pb2.VMState.CRASHED))


class VMWorker(object):
    """Drive the lifecycle of a single qemu virtual machine.

    Requests are queued with ``push_task`` (at most one pending request) and
    executed by a background thread launched with ``start``. That thread
    (``process``/``_loop``) also polls the qemu porto container and keeps the
    VM state current. The current ``VMState`` is persisted to
    ``server_config.STATE_PATH`` and the current ``VMConfig`` to
    ``server_config.CONFIG_PATH``.

    ``_state_lock`` is taken by the public getters (``get_state``,
    ``get_config``, ``get_net_state``) and by writers of the state, the
    network flags and the shutdown flag.
    """

    def __init__(self, mode="gencfg"):
        # At most one pending request; push_task reports further ones.
        self._queue = Queue.Queue(maxsize=1)
        self._config = vmagent_pb2.VMConfig()
        self._state = vmagent_pb2.VMState()
        self._link_alive = False
        self._net_alive = False
        self._state_lock = Lock()
        self._shutdown_needed = False
        self._state.CopyFrom(EMPTY)
        self._qemu = None  # porto container running qemu, or None
        self._process = None  # worker Thread, created by start()
        self._mode = mode
        self.log = logging.getLogger('VMWorker')

    @staticmethod
    def _write_file_atomically(path, data):
        """Write ``data`` (bytes) to ``path`` via a temp file plus rename.

        A plain in-place write can leave a truncated protobuf behind if the
        process dies mid-write, which would later make ``ParseFromString``
        fail at startup. ``os.rename`` replaces the target atomically on
        POSIX when both paths are on the same filesystem.
        """
        tmp_path = path + '.tmp'
        with open(tmp_path, 'wb') as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.rename(tmp_path, path)

    def emergency_purge(self):
        """Revert the VM, remove its disk and storage, and reset to EMPTY.

        Calls ``helpers.remove_disk`` on the configured resource path and
        ``helpers.purge_dir_contents`` on ``STORAGE_PATH``, then stores an
        empty config.
        """
        self._set_state(BUSY)
        self._revert_qemu()
        helpers.remove_disk(self._config.disk.resource.path)
        helpers.purge_dir_contents(server_config.STORAGE_PATH)
        self._store_config(vmagent_pb2.VMConfig())
        self._set_state(EMPTY)

    def init(self):
        """Restore worker state on startup.

        Chooses between the last saved config (``CONFIG_PATH``) and the
        payload config (``DEFAULT_CONFIG_PATH``):

        * neither present: store an empty config;
        * payload id differs from the last id: apply the payload config;
        * no disk image yet: apply an available config (last one preferred);
        * otherwise restore the persisted state and poll any existing qemu
          container.
        """
        self.log.info('Starting worker...')
        last_config = self._get_last_config()
        payload_config = self._get_payload_config()
        if last_config is None and payload_config is None:
            self._store_config(vmagent_pb2.VMConfig())
            return

        # Store last config in case of update
        if last_config is not None:
            self._store_config(last_config)
        if (payload_config is not None and last_config is not None
                and payload_config.id != last_config.id):
            # Id mismatch means a new payload config file appeared, so apply it
            self.log.info('{} is newer, will use it'.format(server_config.DEFAULT_CONFIG_PATH))
            self._handle_config(payload_config, timeout=None)
            return

        # No disk yet: apply whichever config is available, preferring the
        # last one. At this point their ids are equal if both are present.
        if not os.path.exists(server_config.IMAGE_PATH):
            chosen = last_config if last_config is not None else payload_config
            self._handle_config(chosen, timeout=None)
            return

        # Restore last persisted state and generation
        state = self._get_init_state()
        self._set_state(state)

        # Disk exists: look up an already existing qemu container (new_one=False)
        self._qemu = helpers.get_qemu_container(new_one=False)
        self._check_qemu()

    def _get_last_config(self):
        """
        Parse the configuration from the last saved binary file, or return None.
        """
        if os.path.exists(server_config.CONFIG_PATH):
            self.log.info('Found config in {} file'.format(server_config.CONFIG_PATH))
            config_pb = vmagent_pb2.VMConfig()
            with open(server_config.CONFIG_PATH, 'rb') as f:
                config_pb.ParseFromString(f.read())
            return config_pb
        return None

    def _get_payload_config(self):
        """
        Parse the configuration from the ISS payload JSON file, or return None.
        """
        if os.path.exists(server_config.DEFAULT_CONFIG_PATH):
            self.log.info('Found config in {} file'.format(server_config.DEFAULT_CONFIG_PATH))
            config_pb = vmagent_pb2.VMConfig()
            with open(server_config.DEFAULT_CONFIG_PATH) as f:
                json_format.Parse(f.read(), config_pb)
            return config_pb
        return None

    def _store_config(self, config):
        """Make ``config`` current and persist it to ``CONFIG_PATH``."""
        self._config = config
        self._write_file_atomically(server_config.CONFIG_PATH, config.SerializeToString())

    def _get_init_state(self):
        """Load the persisted state, mapping RUNNING/EMPTY to CRASHED.

        A persisted RUNNING or EMPTY state is treated as an unexpected stop.
        Without a state file, STOPPED is assumed.
        """
        if os.path.exists(server_config.STATE_PATH):
            state = vmagent_pb2.VMState()
            with open(server_config.STATE_PATH, 'rb') as f:
                state.ParseFromString(f.read())
            if state.type == vmagent_pb2.VMState.RUNNING:
                self.log.warn('Found last state RUNNING, assuming unexpected stop')
                state.type = vmagent_pb2.VMState.CRASHED
            if state.type == vmagent_pb2.VMState.EMPTY:
                self.log.warn('Found last state EMPTY, assuming unexpected stop')
                state.type = vmagent_pb2.VMState.CRASHED
            return state
        # Most likely vmagent update, do nothing
        self.log.warn('File with last status not found, assuming STOPPED state')
        return STOPPED

    def check_request(self, task):
        """Validate a PUSH_CONFIG request that carries an image torrent.

        Raises:
            BadAction: non-positive vcpu/mem, or a Windows workload on an
                agent where ``server_config.WINDOWS_READY`` is false.
        """
        if task.action != vmagent_api_pb2.VMActionRequest.PUSH_CONFIG:
            return None

        if not task.config.disk.resource.rb_torrent:
            return None

        if task.config.vcpu <= 0:
            raise BadAction("Invalid vcpu count")

        if task.config.mem <= 0:
            raise BadAction("Invalid mem demand")

        if task.config.type == vmagent_pb2.VMConfig.WINDOWS and not server_config.WINDOWS_READY:
            raise BadAction("Cannot start windows workload on this agent")

        return None

    def is_alive(self):
        """Return True if the worker thread was started and is still running."""
        return self._process is not None and self._process.is_alive()

    def start(self):
        """Launch the worker thread running ``process``."""
        # Reset the flag before the thread exists. Resetting it inside
        # process() could clobber a stop() issued right after start() and
        # make that stop() wait forever in join().
        with self._state_lock:
            self._shutdown_needed = False
        self._process = Thread(target=self.process)
        self._process.start()

    def stop(self):
        """Request shutdown and wait for the worker thread, if started."""
        with self._state_lock:
            self._shutdown_needed = True
        if self._process is not None:
            self._process.join()

    def push_task(self, task):
        """Validate and enqueue ``task``.

        Returns None on success, or an error string if the request is invalid
        or another request is already pending.
        """
        try:
            self.check_request(task)
        except BadAction as e:
            return e.message

        try:
            self._queue.put(task, block=False)
        except Queue.Full:
            return "Request already pending"

        return None

    def get_config(self):
        """Return a copy of the current config."""
        with self._state_lock:
            config = vmagent_pb2.VMConfig()
            config.CopyFrom(self._config)
        return config

    def get_state(self):
        """Return a copy of the current state."""
        with self._state_lock:
            res = vmagent_pb2.VMState()
            res.CopyFrom(self._state)
        return res

    def _set_state(self, state):
        """Switch to ``state`` and persist it; return a copy of the prior state.

        The call is a no-op when type and info are unchanged. An existing
        backup record is carried over with its URLs cleared. When ``state``
        has no generation, the generation is incremented on RUNNING/PREPARING
        and kept otherwise. BUSY states are not written to ``STATE_PATH``.
        """
        with self._state_lock:
            prev = vmagent_pb2.VMState()
            prev.CopyFrom(self._state)
            if prev.type == state.type and prev.info == state.info:
                return prev
            self._state.CopyFrom(state)
            if prev.backup.HasField('creation_time'):
                self._state.backup.CopyFrom(prev.backup)
                self._state.backup.image_url = ''
                self._state.backup.delta_url = ''
            if not state.generation:
                if state.type in (vmagent_pb2.VMState.RUNNING, vmagent_pb2.VMState.PREPARING):
                    self._state.generation = prev.generation + 1
                else:
                    self._state.generation = prev.generation
            if state.type != vmagent_pb2.VMState.BUSY:
                # BUSY is transient; do not persist it
                self._write_file_atomically(server_config.STATE_PATH, self._state.SerializeToString())
            self.log.info('Set new state: %s', self._state)
        return prev

    def _set_backup(self, backup):
        """Record ``backup`` in the state, mark it STOPPED, and persist it."""
        with self._state_lock:
            self._state.type = vmagent_pb2.VMState.STOPPED
            self._state.backup.CopyFrom(backup)
            self._write_file_atomically(server_config.STATE_PATH, self._state.SerializeToString())

    def get_net_state(self):
        """Return ``(link_alive, net_alive)`` as last measured."""
        with self._state_lock:
            return (self._link_alive, self._net_alive)

    def _set_net_state(self, link_alive, net_alive):
        with self._state_lock:
            self._link_alive = link_alive
            self._net_alive = net_alive

    def _check_qemu(self):
        """Poll the qemu container and update state and network flags.

        With no container, starts qemu when autorun is enabled and the state
        type is in ``READY_TO_START``. Otherwise waits up to one second: if
        qemu exited, records STOPPED (exit code "0") or CRASHED, destroys the
        container and clears the network flags. If qemu is still running,
        records RUNNING and refreshes the ping/connect flags.

        Returns the qemu exit code string if qemu exited, else None.
        """
        ret = None

        if not self._qemu:
            if self._config.autorun and self._state.type in READY_TO_START:
                self.log.info('Autorun enabled, current state = {}, starting VM'.format(
                    vmagent_pb2.VMState.VMStateType.Name(self._state.type)
                ))
                self._start_qemu()
            return None

        try:
            wait_result = self._qemu.Wait(timeout=1000)  # ms here
            if "qemu" in wait_result:
                ret = self._qemu.GetProperty("exit_code")
                # stderr = self._qemu.GetProperty("stderr")

                if ret == "0":
                    self._set_state(STOPPED)
                else:
                    self._set_state(
                        _init_state(
                            vmagent_pb2.VMState.CRASHED,
                            # FIXME: add stdout/stderr here
                            "qemu exit code: {}".format(ret)
                        ))

                try:
                    self._qemu.Destroy()
                except PortoContainerNotExists:
                    pass

                self._qemu = None
                self._set_net_state(False, False)

            else:
                self._set_state(RUNNING)
                self._set_net_state(helpers.ping_vm(), helpers.connect_vm())

        except PortoContainerNotExists:
            self._qemu = None
            self._set_state(_init_state(vmagent_pb2.VMState.CRASHED,
                                        "qemu container dropped unexpectedly"))

        except PortoSocketError as e:
            log_msg("Got porto socket error: {}".format(e))
            self._set_state(_init_state(vmagent_pb2.VMState.CRASHED,
                                        "Porto socket error"))

        return ret

    def _restart_qemu(self, timeout=10):
        """Shut down and start qemu again.

        NOTE: ``_shutdown_qemu`` only sends ``system_powerdown`` and does not
        wait, so ``_start_qemu`` still sees ``self._qemu`` set. The RESTART
        action is disabled in ``_handle_command`` for this reason.
        """
        if not self._qemu:
            return None

        self._set_state(BUSY)
        self._shutdown_qemu(timeout)
        self._set_state(STOPPED)
        self._start_qemu()

    def _revert_qemu(self):
        """Power off qemu and return the VM to its freshly configured state.

        For RAW disks the current config is re-applied via ``_handle_config``.
        Otherwise the delta is dropped and the state becomes CONFIGURED
        (autorunning qemu if enabled). Without a configured id the state
        becomes EMPTY. Returns the poweroff exit code, or None.
        """
        ret = None
        self._set_state(BUSY)
        if self._qemu:
            ret = self._poweroff_qemu()

        if self._config.disk.type == vmagent_pb2.VMDisk.RAW:
            if self._config.id != "":
                self._handle_config(self._config, timeout=None)
            else:
                # Without this the state stayed BUSY forever
                self._set_state(EMPTY)
            return ret

        helpers.drop_delta()
        if self._config.id != "":
            self._set_state(CONFIGURED)
            if self._config.autorun:
                self.log.info('Autorun enabled, starting VM')
                self._start_qemu()
        else:
            self._set_state(EMPTY)

        return ret

    def _reset_qemu(self):
        """Send ``system_reset`` to the qemu monitor if qemu is running."""
        if not self._qemu:
            return None

        helpers.qemu_mon("system_reset")

        return None

    def _shutdown_qemu(self, timeout=10):
        """Send ``system_powerdown`` to the qemu monitor; does not wait.

        ``timeout`` is currently unused. Monitor errors are logged, not raised.
        """
        if not self._qemu:
            return None

        try:
            helpers.qemu_mon("system_powerdown")
        except Exception as e:
            log_msg("Cannot send shutdown command: {}".format(e))

    def _poweroff_qemu(self, timeout=10):
        """Stop qemu, destroy its container and record STOPPED or CRASHED.

        First tries the monitor ``quit`` command and waits ``timeout``
        seconds. If that yields no exit code, the container is stopped via
        porto (the exit code is then reported as '9').

        Returns the exit code string, or None when qemu is not running.

        Raises:
            RuntimeError: the porto container could not be stopped.
        """
        if not self._qemu:
            return None

        self._set_state(BUSY)
        ret = None

        try:
            helpers.qemu_mon("quit")

            wait_result = self._qemu.Wait(timeout=timeout * 1000)
            if 'qemu' not in wait_result:
                log_msg("Qemu mon quit timeout exceeded")
            else:
                ret = self._qemu.GetProperty('exit_code')

        except Exception as e:
            log_msg('Cannot stop porto with monitor: {}'.format(e))

        if ret is None:
            try:
                self._qemu.Stop(timeout=timeout)
                ret = '9'
            except Exception as e:
                raise RuntimeError("Cannot stop qemu: {}".format(e))

        if ret != '0':
            self._set_state(_init_state(
                vmagent_pb2.VMState.CRASHED,
                "Qemu exited with code: {}".format(ret)
            ))
        else:
            self._set_state(STOPPED)

        # The container may already be gone; treat that like _check_qemu does
        try:
            self._qemu.Destroy()
        except PortoContainerNotExists:
            pass
        self._qemu = None

        return ret

    def _start_qemu(self, rescue=False):
        """Create a qemu porto container and launch the VM.

        With ``rescue`` the rescue launcher script is used. On failure the
        container is destroyed and the state becomes CRASHED, including the
        last 20 stderr lines. On success the state becomes RUNNING.
        """
        if self._qemu:
            self.log.warn('VM started already')
            self._set_state(RUNNING)
            return
        if self._state.type in (vmagent_pb2.VMState.EMPTY, vmagent_pb2.VMState.INVALID):
            self.log.warn('Cannot start: vm in %s state', vmagent_pb2.VMState.VMStateType.Name(self._state.type))
            return

        try:
            self._qemu = helpers.get_qemu_container()

            self._set_state(BUSY)
            env = os.environ.copy()

            env['VCPU_NR'] = str(self._config.vcpu)
            # Integer division: bytes -> MiB without a fractional part
            env['MEM'] = str(self._config.mem // 1048576)
            env['DISK_SIZE'] = str(self._config.disk.delta_size)
            env['IMG_PATH'] = server_config.IMG_PATH
            env['CURRENT_PATH'] = server_config.CURRENT_PATH
            env['STORAGE_PATH'] = server_config.STORAGE_PATH
            env['INSTANCE_ID'] = server_config.INSTANCE_ID
            env['LOG_DIR'] = server_config.LOG_DIR
            env['SERIAL_PATH'] = server_config.SERIAL_PATH
            env['VM_IP'] = server_config.VM_IP
            env['VM_AUX_IP'] = server_config.VM_AUX_IP
            env['MODE'] = self._mode
            env['IMAGE_TYPE'] = vmagent_pb2.VMDisk.ImageType.Name(self._config.disk.type)
            env['SIZE_FOR_RAW'] = str(self._config.disk.delta_size)

            for var, value in env.items():
                self._qemu.SetProperty('env[{}]'.format(var), value)

            if rescue:
                qemu_script = server_config.RESCUE_LAUNCHER_PATH
            elif self._config.type == vmagent_pb2.VMConfig.WINDOWS:
                qemu_script = server_config.QEMU_WIN_LAUNCHER_PATH
            else:
                qemu_script = server_config.QEMU_LAUNCHER_PATH

            self._qemu.SetProperty("stdin_path", "/dev/null")
            self._qemu.SetProperty("command", " ".join(["/bin/bash", "-x", qemu_script]))
            self._qemu.SetProperty("cwd", os.getcwd())

            self._qemu.Start()

            helpers.qemu_mon("set_password vnc {}".format(self._config.access_info.vnc_password))

        except Exception as e:
            log_msg("Exception while starting: {}".format(e))
            log_trace(sys.exc_info()[2])

            stderr = ""

            if self._qemu:
                try:
                    # GetProperty may yield None; keep stderr a string
                    stderr = self._qemu.GetProperty('stderr') or ""
                    self._qemu.Destroy()
                except Exception:
                    pass

                self._qemu = None

            self._set_state(_init_state(
                vmagent_pb2.VMState.CRASHED,
                "Failed to start qemu: {}, stderr: {}".format(
                    e.message, " ; ".join(stderr.split('\n')[-20:])
                )))

            return

        self._set_state(RUNNING)

    def _backup(self):
        """Power off the VM, share its disk resources and record a backup."""
        # Stop virtual machine, you've been warned
        self._poweroff_qemu()
        self._set_state(BUSY)
        backup = vmagent_pb2.Backup()
        if self._config.disk.type == vmagent_pb2.VMDisk.DELTA:
            backup.image_url = helpers.share_resource(self._config.disk.resource.path)
            backup.delta_url = helpers.share_resource(server_config.DELTA_PATH)
        elif self._config.disk.type == vmagent_pb2.VMDisk.RAW:
            backup.image_url = helpers.share_resource(self._config.disk.resource.path)
        backup.creation_time.GetCurrentTime()
        if self._mode == 'gencfg':
            self._run_backup_tasks(backup)
        self._set_backup(backup)

    def _run_backup_tasks(self, backup):
        """Start the sandbox upload (RAW) or merge (DELTA) task for ``backup``.

        Failures are logged and leave ``backup.sandbox_task_url`` unset.
        """
        size = helpers.get_image_size(server_config.IMAGE_PATH)
        dc = server_config.HOST_DC
        if self._config.disk.type == vmagent_pb2.VMDisk.RAW:
            try:
                backup.sandbox_task_url = helpers.run_sandbox_upload_task(backup.image_url, size, dc)
            except Exception as e:
                self.log.error('Run sandbox task failed: {}'.format(e))
        elif self._config.disk.type == vmagent_pb2.VMDisk.DELTA:
            size += helpers.get_image_size(server_config.DELTA_PATH)
            try:
                backup.sandbox_task_url = helpers.run_sandbox_merge_task(backup.image_url, backup.delta_url, size, dc)
            except Exception as e:
                self.log.error('Run sandbox task failed: {}'.format(e))

    def _loop(self, prev):
        """Restore ``prev`` state, then process queued tasks and poll qemu."""
        self._set_state(prev)

        self._qemu = helpers.get_qemu_container(new_one=False)
        self._check_qemu()

        while not self._shutdown_needed:
            try:
                task = self._queue.get(timeout=0.1)
                self._handle_cmd(task)
            except Queue.Empty:
                pass

            self._check_qemu()

    def process(self):
        """Worker thread body: run ``_loop`` until ``stop`` is requested.

        In gencfg mode ``_loop`` runs inside ``helpers.SkynetContainer()``.
        On an unhandled exception the state captured before the attempt is
        restored, the error is logged, and the loop retries after 5 seconds.
        """
        while not self._shutdown_needed:
            prev = self._set_state(_init_state(
                vmagent_pb2.VMState.BUSY,
                "Starting skynet container"
            ))

            try:
                if self._mode == "gencfg":
                    with helpers.SkynetContainer():
                        self._loop(prev)
                else:
                    self._loop(prev)

            except Exception as e:
                self._set_state(prev)
                # Format the exception itself: e.message is empty for
                # multi-argument errors such as IOError/OSError.
                log_msg('Unhandled exception in worker thread: {}'.format(e))
                log_trace(sys.exc_info()[2])
                time.sleep(5)

    def _handle_cmd(self, task):
        """Dispatch a queued request to config handling or a VM action.

        Unexpected errors from an action are logged and set the CRASHED state.
        """
        if task.action == vmagent_api_pb2.VMActionRequest.PUSH_CONFIG:
            self._handle_config(task.config, task.timeout)
            return

        try:
            self._handle_command(task.action, task.timeout)
        except BadAction as e:
            log_msg(e.message)
        except Exception as e:
            log_trace(sys.exc_info()[2])
            log_msg('{}'.format(e))

            self._set_state(
                _init_state(
                    vmagent_pb2.VMState.CRASHED,
                    # FIXME: add stdout/stderr here
                    "last command execution error"
                ))

    def _handle_command(self, action, timeout):
        """Execute a VM action.

        Raises:
            BadAction: the action is unknown or not supported.
        """
        if action == vmagent_api_pb2.VMActionRequest.START:
            self._start_qemu()

        elif action == vmagent_api_pb2.VMActionRequest.SHUTDOWN:
            self._shutdown_qemu()

        # works incorrect, comment for now
        # elif action == vmagent_api_pb2.VMActionRequest.RESTART:
        #     self._restart_qemu()

        elif action == vmagent_api_pb2.VMActionRequest.RESET:
            self._reset_qemu()

        elif action == vmagent_api_pb2.VMActionRequest.POWEROFF:
            self._poweroff_qemu()

        elif action == vmagent_api_pb2.VMActionRequest.HARD_RESET:
            self._revert_qemu()

        elif action == vmagent_api_pb2.VMActionRequest.RESCUE:
            self._start_qemu(rescue=True)

        # NIY
        # elif action == vmagent_api_pb2.VMActionRequest.SNAPSHOT:
        #     pass

        elif action == vmagent_api_pb2.VMActionRequest.BACKUP:
            self._backup()

        else:
            raise BadAction("Unknown action")

    def _handle_config(self, config, timeout):
        """Apply ``config``: power off, fetch the image if needed, store it.

        The disk image is replaced only when none exists or the torrent
        changed. On a download failure the image is removed and the state
        becomes INVALID. Otherwise a missing id and VNC password are generated
        and the state becomes CONFIGURED.
        """
        if not config.disk.resource.rb_torrent:
            self.log.warn('Invalid image url, config aborted')
            return

        self.log.info('Applying config\n{}'.format(config))
        self._poweroff_qemu()

        old_config = self.get_config()
        self._store_config(vmagent_pb2.VMConfig())
        self._set_state(PREPARING)
        if not os.path.exists(server_config.IMAGE_PATH) \
                or config.disk.resource.rb_torrent != old_config.disk.resource.rb_torrent:
            try:
                self._replace_disk_image(config)
            except DownloadingImageError as e:
                # lexists also catches a dangling symlink, which exists() misses
                if os.path.lexists(server_config.IMAGE_PATH):
                    os.unlink(server_config.IMAGE_PATH)
                shutil.rmtree(server_config.IMAGE_FOLDER_PATH, ignore_errors=True)
                self._set_state(_init_state(vmagent_pb2.VMState.INVALID, e.message))
                return

        if not config.id:
            config.id = str(uuid.uuid4())

        if not config.access_info.vnc_password:
            config.access_info.vnc_password = helpers.gen_vnc_password()

        self._store_config(config)
        self._set_state(CONFIGURED)

    def _replace_disk_image(self, config):
        """Download the image for ``config`` and symlink it at ``IMAGE_PATH``.

        Fills in ``config.disk.resource.path``, ``config.disk.delta_size``
        (also stored in ``server_config.DISK_SIZE``) and ``config.disk.size``.

        Raises:
            DownloadingImageError: download failed, no readable resource, the
                symlink could not be created, or the storage is too small.
        """
        self.log.info('Replacing disk image file')
        helpers.drop_delta()
        helpers.remove_disk(server_config.IMAGE_PATH)
        shutil.rmtree(server_config.IMAGE_FOLDER_PATH, ignore_errors=True)
        os.makedirs(server_config.IMAGE_FOLDER_PATH)

        try:
            helpers.get_resource(config.disk.resource.rb_torrent, server_config.IMAGE_FOLDER_PATH)
        except SkynetError as e:
            raise DownloadingImageError("Sky get error: {}".format(e.message))
        entries = os.listdir(server_config.IMAGE_FOLDER_PATH)
        if not entries:
            # Previously an IndexError escaped here instead of a handled error
            raise DownloadingImageError("Disk resource not found")
        resource_path = os.path.join(server_config.IMAGE_FOLDER_PATH, entries[0])
        if not os.access(resource_path, os.R_OK):
            raise DownloadingImageError("Disk resource not found")

        try:
            os.symlink(resource_path, server_config.IMAGE_PATH)
        except OSError as e:
            raise DownloadingImageError("Cannot make symlink to image: {}".format(e))

        allocated_size = helpers.get_storage_size(server_config.STORAGE_PATH)
        image_size = helpers.get_image_size(server_config.IMAGE_PATH)
        virtual_size = helpers.get_virtual_size(server_config.IMAGE_PATH)
        # Keep GIGA bytes of headroom. A DELTA disk also shares the storage
        # with the base image.
        delta_size = allocated_size - image_size - GIGA
        delta_raw_size = allocated_size - GIGA
        disk_size = delta_size if config.disk.type == vmagent_pb2.VMDisk.DELTA else delta_raw_size

        if virtual_size > disk_size:
            msg = "Allocated size does not enough for this image: should be at least ({}), got({})".format(
                virtual_size - disk_size + allocated_size,
                allocated_size
            )
            raise DownloadingImageError(msg)

        config.disk.resource.path = resource_path
        server_config.DISK_SIZE = config.disk.delta_size = disk_size
        config.disk.size = image_size


# Running this module directly does nothing; VMWorker is meant to be imported.
if __name__ == "__main__":
    pass
