import os
import sys
import time
import logging
import datetime

import subprocess as sp

import pathlib2

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc

import sandbox.agentr.client

from sandbox.common import fs as common_fs
from sandbox.common import rest as common_rest
from sandbox.common import config as common_config
from sandbox.common import format as common_format

from sandbox.client import errors, platforms, system
from sandbox.client.commands import base


# Module-level logger; used by class-level code such as `ResetClientCommand.reset_jobs`
# (instance methods log through `self.logger` instead).
logger = logging.getLogger(__name__)


class ServiceCommand(base.Command):
    """ Common base class for all service commands.

    Seeds the command arguments with the service token and zeroed `iteration` /
    `reserved_space` values before running the base command initializer.
    """

    def __init__(self):
        service_defaults = dict(
            id=ctc.ServiceTokens.SERVICE_TOKEN,
            iteration=0,
            reserved_space=0,
        )
        # NOTE(review): `self.args` is used before the base initializer runs, so it is presumably
        # provided at class level — confirm it is not a dict shared between instances.
        self.args.update(service_defaults)
        super(ServiceCommand, self).__init__()
        self.iteration = 0


class IdleCommand(ServiceCommand):
    """ Service command executed while the client is idle.

    Its `spawn` removes stale executor args/result files from the executor run directory and
    runs LXC and/or Porto platform maintenance (when enabled) before the base `spawn`.
    """
    command = ctc.Command.IDLE

    def __init__(self, *_):
        client_settings = common_config.Registry().client
        mds_token = common_fs.read_settings_value_from_file(client_settings.mds.up.token, True)
        if mds_token:
            # Only pass the MDS token (from `client.mds.up.token` settings) when one is configured.
            self.args["mds_token"] = mds_token
        super(IdleCommand, self).__init__()
        self.platform = platforms.Platform(self)
        self.kill_timeout = client_settings.service_task_timeout

    def spawn(self):
        client_settings = common_config.Registry().client
        run_dir = client_settings.executor.dirs.run
        # (file name prefix, warning message) pairs, checked in this order.
        stale_kinds = (
            (base.EXECUTOR_ARGS_PREFIX, "Removing stale args file %r"),
            (base.EXECUTOR_RESULT_PREFIX, "Removing stale result file %r"),
        )
        for entry in os.listdir(run_dir):
            message = next((msg for prefix, msg in stale_kinds if entry.startswith(prefix)), None)
            if message is None:
                continue
            self.logger.warning(message, entry)
            os.unlink(os.path.join(run_dir, entry))

        if client_settings.lxc.enabled:
            platforms.PrivilegedLXCPlatform.maintain()

        if client_settings.porto.enabled:
            platforms.PortoPlatform.maintain()

        super(IdleCommand, self).spawn()


class ClearCommand(ServiceCommand):
    """ Service command that cleans up disk space on the host. """
    command = ctc.Command.CLEAR

    def __init__(self, *_):
        super(ClearCommand, self).__init__()
        self.platform = platforms.Platform(self)
        timeout = common_config.Registry().client.service_task_timeout
        self.kill_timeout = timeout

    def on_terminate(self):
        """ Return the command to run after this one terminates: an idle command. """
        idle = base.Command(IdleCommand.command)
        return idle


class ReloadClientCommand(ServiceCommand):
    """ Reload client and fileserver command.

    `spawn` checks for skynet updates and tears down state left by LXC containers, Porto or
    ramdrives. It then marks the command as successful, stops the pinger thread and shuts
    the AgentR service down.
    """
    SKYNET_UPGRADE_TIMEOUT = 1800  # Timeout for the skynet upgrade process execution.

    command = ctc.ReloadCommand.RESTART
    # Command returned from `on_terminate`; switched to a shutdown command when the
    # skynet update raises an unexpected error.
    subsequent = None

    def __init__(self, *_):
        super(ReloadClientCommand, self).__init__()

    def __update_skynet(self):
        """ Run skynet's `gosky` updater (skipped in local mode) and log its outcome.

        If `gosky` exits with code 1, a SHUTDOWN command is spawned and `sys.exit(1)` is called.
        Any other exception schedules `ShutdownClientCommand` as the subsequent command.
        """
        if system.local_mode():
            return
        self.logger.info("Checking for skynet updates.")
        try:
            # The mtime of /skynet/.info is compared before/after to detect an actual update.
            mtime = os.stat("/skynet/.info").st_mtime
            with system.UserPrivileges():
                p = sp.Popen(
                    "/Berkanavt/skynet/bin/gosky",
                    stdout=sp.PIPE, stderr=sp.PIPE, shell=True
                )
                stdout, stderr = p.communicate()
            if p.returncode:
                self.logger.warning(
                    "Skynet update failed with code {rc}, output follows:\n"
                    "{hr}STDERR{hr}\n{stderr}\n{hr}STDOUT{hr}\n{stdout}".format(
                        hr="-" * 40, stdout=stdout, stderr=stderr, rc=p.returncode
                    )
                )
                if p.returncode == 1:
                    base.Command(ctc.ReloadCommand.SHUTDOWN).spawn()
                    sys.exit(1)
            updated = os.stat("/skynet/.info").st_mtime != mtime
            self.logger.info("Skynet %supdated. Output follows:\n%s", ("not " if not updated else ""), stdout)
        except Exception:
            self.logger.exception("Error occurred during update. Shutting down the service.")
            self.subsequent = ShutdownClientCommand()

    def poll(self):
        """ Return the current status, or None while no status has been set. """
        return self.status or None

    def spawn(self):
        """ Update skynet, clean up container/ramdrive leftovers and stop the client services. """
        from sandbox.client.pinger import PingSandboxServerThread
        pt = PingSandboxServerThread()
        # Raise the kamikadze thread TTL for the duration of the skynet upgrade.
        pt._kamikadze_thread.ttl = self.SKYNET_UPGRADE_TIMEOUT
        self.__update_skynet()
        pt._kamikadze_thread.ttl = common_config.Registry().client.idle_time * 100

        if common_config.Registry().client.lxc.enabled:
            self._cleanup_lxc()
        elif common_config.Registry().client.porto.enabled:
            platforms.porto.PortoContainerRegistry().destroy_namespace()
            platforms.porto.PortoContainerRegistry().remove_unused_layers()
        elif common_config.Registry().this.system.family == ctm.OSFamily.LINUX:
            self._unmount_ramdrives()

        with self._status_lock:
            self.status = ctt.Status.SUCCESS

        self.logger.info("Going to %s", self.command.upper())
        pt.stop()
        pt.stopped.wait(common_config.Registry().client.idle_time)
        self.agentr = None
        sandbox.agentr.client.Service(self.logger).shutdown()

    def _cleanup_lxc(self):
        """ Destroy running LXC containers, remove numbered rootfs templates and LXC log files. """
        # An explicit loop is required: a lazy `map()` (Python 3) would never run the side effects.
        # Iterate over a snapshot, since destroying a container may modify `running`.
        for name in list(platforms.LXCPlatform.running.keys()):
            platforms.LXCPlatform._destroy_container(platforms.Container(name=name), stop=True)

        self.logger.info("Removing LXC container templates")
        with open(os.devnull, "w") as devnull:
            for tmpl in pathlib2.Path(common_config.Registry().client.lxc.rootfs.template).iterdir():
                if tmpl.name.isdigit():
                    self.logger.debug("Removing '%s'", tmpl)
                    with system.UserPrivileges():
                        sp.call(["/bin/umount", str(tmpl)], stderr=devnull, stdout=devnull)
                        sp.call(["/bin/rm", "-rf", str(tmpl)])

        self.logger.info("Removing LXC containers' log files")
        with system.UserPrivileges():
            platforms.LXCPlatform.empty_dir("/var/log/lxc")

    def _unmount_ramdrives(self):
        """ Unmount every tmpfs mount whose mount point contains "/ramdrive". """
        # universal_newlines makes `communicate()` return text on Python 3 too (bytes otherwise).
        df, dferr = sp.Popen(
            ["df", "-P", "-t", "tmpfs"], stdout=sp.PIPE, stderr=sp.PIPE, universal_newlines=True
        ).communicate()
        if dferr:
            self.logger.warning("Failed to find temporary filesystems: %s", dferr)
            return
        mounts = self._parse_ramdrive_mounts(df)
        for mount in mounts:
            self.logger.info("Found mounted ramdrive: %s", mount)
        with system.UserPrivileges():
            for mount in mounts:
                platforms.LinuxPlatform.unmount_ramdrive(mount, logger=self.logger)

    @staticmethod
    def _parse_ramdrive_mounts(df_output):
        """ Return mount points containing "/ramdrive" from `df -P` text output (header skipped). """
        mounts = []
        for line in df_output.split("\n")[1:]:
            # The mount point is whatever follows the last "%" (the capacity column).
            split = line.rsplit("%", 1)
            if len(split) > 1:
                mount = split[1].strip()
                if "/ramdrive" in mount:
                    mounts.append(mount)
        return mounts

    def on_terminate(self):
        """ Return the command to run after this one (None unless a shutdown was scheduled). """
        return self.subsequent


class RebootClientBase(ReloadClientCommand):
    """ Base class for reboot commands.

    On non-local installations with a Wall-e token, asks Wall-e to perform `walle_command`
    on this host. Otherwise, or when every Wall-e attempt fails with an HTTP error, it runs
    `_local_execute`. It then waits for the pinger thread to stop.
    """

    command = None  # set by subclasses
    walle_command = None  # Wall-e host action name, set by subclasses
    _WALLE_MAX_TRIES = 3
    _MAINTENANCE_TICKET = "SANDBOXALERT-63"
    _MAINTENANCE_TIMEOUT = 30 * 60 + 10  # seconds

    def _local_execute(self):
        """ Perform the action on the host itself; a no-op here, overridden by subclasses. """
        pass

    def _execute_via_walle(self, walle_token, author, comment):
        """ Ask Wall-e to perform `walle_command` for this host, retrying on HTTP errors.

        For a reboot on an OS X host, the host is put into maintenance with a "reboot" CMS action
        and then rebooted locally. After `_WALLE_MAX_TRIES` failed attempts the action is
        executed locally.
        """
        walle_client = common_rest.Client(
            common_config.Registry().common.walle.api,
            auth=common_fs.read_settings_value_from_file(walle_token),
        )
        os_family = common_config.Registry().this.system.family
        reason = "Author: {}. Reason: {}".format(author, comment)
        for attempt in range(self._WALLE_MAX_TRIES):
            try:
                host = walle_client.hosts[common_config.Registry().this.fqdn]
                if self.walle_command == ctc.ReloadCommand.REBOOT and os_family in ctm.OSFamily.Group.OSX:
                    host["set-maintenance"](
                        reason=reason,
                        ignore_cms=True,
                        cms_task_action="reboot",
                        ticket_key=self._MAINTENANCE_TICKET,
                        timeout_status="ready",
                        timeout_time=int(time.time()) + self._MAINTENANCE_TIMEOUT,
                    )
                    time.sleep(common_config.Registry().client.idle_time)
                    self._local_execute()
                else:
                    host[self.walle_command](
                        reason=reason,
                        ignore_cms=True,
                    )
                break
            except common_rest.Client.HTTPError:
                if attempt + 1 == self._WALLE_MAX_TRIES:
                    self._local_execute()
                else:
                    time.sleep(common_config.Registry().client.idle_time)

    def spawn(self):
        from sandbox.client.pinger import PingSandboxServerThread
        # Tolerate both a missing and an explicitly empty/None "options" argument.
        options = self.args.get("options") or {}
        author = options.get("author")
        comment = options.get("comment")
        PingSandboxServerThread()._kamikadze_thread.ttl = common_config.Registry().client.idle_time * 100

        self.logger.info("STARTING %s with author %s and message: %s", self.command.upper(), author, comment)
        walle_token = common_config.Registry().common.walle.token
        if not walle_token:
            self.logger.warning("There is no Wall-e token, rebooting host on my own")

        if walle_token and common_config.Registry().common.installation in ctm.Installation.Group.NONLOCAL:
            self._execute_via_walle(walle_token, author, comment)
        else:
            self._local_execute()

        self.logger.info("Going to %s", self.command.upper())
        wait_limit = common_config.Registry().client.idle_time * 99
        start_time = datetime.datetime.utcnow()
        # `total_seconds()` is used instead of `timedelta.seconds`. The latter drops the days
        # component and yields ~86399 for a negative delta (clock stepped back), ending the wait early.
        while (datetime.datetime.utcnow() - start_time).total_seconds() < wait_limit:
            time.sleep(1)
            if PingSandboxServerThread().stopped.is_set():
                break
        PingSandboxServerThread().stop()


class RebootClientCommand(RebootClientBase, ReloadClientCommand):
    """ Reboot client command.

    Reboots the host via Wall-e or, as a fallback, with the platform's own reboot command.
    """
    command = ctc.ReloadCommand.REBOOT
    walle_command = "reboot"

    def setup_agentr(self):
        """ No-op override of `setup_agentr`. """

    def _local_execute(self):
        """ Reboot the host by itself, retrying every `idle_time` seconds until the pinger thread stops. """
        from sandbox.client.pinger import PingSandboxServerThread
        self.logger.info("Make %s command by self.", self.command.upper())
        os_family = common_config.Registry().this.system.family
        client_tags = common_config.Registry().client.tags
        reboot_cmd = self.__get_reboot_cmd(os_family, client_tags)
        # OS X hosts tagged M1 get their Rosetta cache cleaned before every reboot attempt.
        needs_rosetta_cleanup = os_family in ctm.OSFamily.Group.OSX and ctc.Tag.M1 in client_tags
        with system.UserPrivileges():
            while True:
                if needs_rosetta_cleanup:
                    platforms.OSXPlatform.cleanup_rozetta_cache()
                sp.check_call(reboot_cmd)
                if PingSandboxServerThread().stopped.is_set():
                    break
                time.sleep(common_config.Registry().client.idle_time)

    @staticmethod
    def __get_reboot_cmd(os_family, client_tags):
        """ Return the argv of the reboot command for the given OS family and client tags. """
        if os_family == ctm.OSFamily.CYGWIN:
            return ["/cygdrive/c/Windows/system32/shutdown", "-t", "0", "-r", "-f"]
        if os_family in ctm.OSFamily.Group.OSX:
            return ["/sbin/shutdown", "-r", "now"]
        if ctc.Tag.WINDOWS in client_tags:
            return ['/mnt/c/Windows/System32/shutdown.exe', '/r', '/f', '/t', '0']
        return ["/sbin/reboot"]


class PowerOffClientCommand(RebootClientBase, ReloadClientCommand):
    """ Power off client command.

    Powers the host off via Wall-e or, as a fallback, with the platform's own shutdown command.
    """
    command = ctc.ReloadCommand.POWEROFF
    walle_command = "power-off"

    def setup_agentr(self):
        """ No-op override of `setup_agentr`. """

    def _local_execute(self):
        """ Power the host off by itself, retrying every `idle_time` seconds until the pinger thread stops. """
        from sandbox.client.pinger import PingSandboxServerThread
        self.logger.info("Make %s command by self.", self.command.upper())
        if common_config.Registry().this.system.family == ctm.OSFamily.CYGWIN:
            # "-s" is shutdown.exe's explicit shutdown action (the reboot command uses "-r").
            poweroff_cmd = ["/cygdrive/c/Windows/system32/shutdown", "-t", "0", "-s", "-f"]
        else:
            # Every argument must be a separate list item. Without a shell, a single
            # "/sbin/shutdown -Ph now" string is looked up as an executable name and fails.
            poweroff_cmd = ["/sbin/shutdown", "-Ph", "now"]
        with system.UserPrivileges():
            while True:
                sp.check_call(poweroff_cmd)
                if not PingSandboxServerThread().stopped.is_set():
                    time.sleep(common_config.Registry().client.idle_time)
                else:
                    break


class ResetClientCommand(ReloadClientCommand):
    """ Reset the client state and restart the client.

    Stops the job dropper, cancels every non-service job and cleans up its platform,
    stops the pinger thread and resets AgentR.
    """
    command = ctc.ReloadCommand.RESET

    def setup_agentr(self):
        """ No-op override of `setup_agentr`. """

    @classmethod
    def reset_jobs(cls):
        """ Cancel all registered non-service jobs and clean up their platforms.

        Platform cleanup errors are logged and do not stop processing of the remaining jobs.
        """
        for job in base.Command.registry.values():
            if job.token == ctc.ServiceTokens.SERVICE_TOKEN:
                continue
            job.cancel(errors.SessionExpired)
            try:
                platform = job.platform
                if platform:
                    platform.cleanup()
                    job.platform = None
            except Exception:
                logger.exception("Error on job %r platform cleanup.", job)

    def spawn(self):
        from sandbox.client.pinger import PingSandboxServerThread
        pinger = PingSandboxServerThread()
        pinger._kamikadze_thread.ttl = common_config.Registry().client.idle_time * 100

        self.logger.info("Stopping jobdropper")
        dropper_thread = pinger._service_threads.pop("_jobdropper", None)
        # Put a None item into the drop queue; presumably a stop sentinel for the jobdropper — confirm.
        pinger._dropajob.put(None)
        if dropper_thread:
            dropper_thread.join()

        self.reset_jobs()

        pinger.stop()
        pinger.stopped.wait(common_config.Registry().client.idle_time)
        pinger.agentr.reset()

    @classmethod
    def initial_reset(cls):
        """ Reset jobs, clear the command registry, mark the pinger stopped and reset AgentR. """
        from sandbox.client.pinger import PingSandboxServerThread
        pinger = PingSandboxServerThread()
        cls.reset_jobs()
        base.Command.registry = {}
        pinger.set_stopped()
        pinger.agentr.reset()


class CleanupClientCommand(ClearCommand):
    """ Cleanup client command.

    Runs the clear command only when the pinger reports a reason for cleanup. It schedules
    itself again after a successful run.
    """
    command = ctc.ReloadCommand.CLEANUP

    def spawn(self):
        from sandbox.client.pinger import PingSandboxServerThread
        cleanup_reason = PingSandboxServerThread().need_cleanup()
        if cleanup_reason:
            self.logger.info("Cleaning up the client because of lack of free space (%s).", cleanup_reason)
            # Run as a regular clear command from now on.
            self.command = ClearCommand.command
            super(CleanupClientCommand, self).spawn()
            return
        self.logger.info("No need to cleanup.")
        self.status = ctt.Status.EXCEPTION

    def on_terminate(self):
        """ Return another cleanup command after a successful run, otherwise None. """
        if self.status != ctt.Status.SUCCESS:
            return None
        return base.Command(CleanupClientCommand.command)


class ShutdownClientCommand(ReloadClientCommand):
    """ Shutdown client command.

    Stops the job getter and cancels all registered jobs. It then destroys running LXC
    containers (when LXC is enabled), touches the `client_check_stop` marker file in the
    client run directory and stops the pinger thread.
    """
    command = ctc.ReloadCommand.SHUTDOWN

    def setup_agentr(self):
        """ No-op override of `setup_agentr`. """

    def spawn(self):
        from sandbox.client.pinger import PingSandboxServerThread
        pt = PingSandboxServerThread()
        pt._kamikadze_thread.ttl = common_config.Registry().client.idle_time * 100
        self.logger.info("STARTING %s", self.command.upper())

        self.logger.info("Stopping jobgetter")
        jobgetter = pt._service_threads.pop("_jobgetter", None)
        if jobgetter:
            jobgetter.join()

        self.logger.info("Cancelling jobs")
        # Iterate over a snapshot: cancelling a job may modify the registry (a Python 3 dict
        # view would then raise RuntimeError).
        for job in list(base.Command.registry.values()):
            self.logger.info("Cancelling job %s: %s", common_format.obfuscate_token(job.token), job)
            job.cancel(errors.ShuttingDown, reason="Client is shutting down")

        # Snapshot as well: destroying a container may modify `running`, and an iteration error
        # would escape the per-container try/except below.
        containers = (
            list(platforms.LXCPlatform.running.keys())
            if common_config.Registry().client.lxc.enabled else
            []
        )
        for cnt in containers:
            try:
                platforms.PrivilegedLXCPlatform._destroy_container(platforms.Container(name=cnt), stop=True)
            except Exception:
                self.logger.exception("Error on destroying container %r", cnt)

        # Create the marker file, or leave an existing one untouched ("a" mode does not truncate).
        with open(os.path.join(common_config.Registry().client.dirs.run, "client_check_stop"), "a"):
            pass
        self.logger.info("Going to %s", self.command.upper())
        pt.stop()

    @classmethod
    def emergency_shutdown(cls):
        """ Start a shutdown command through the base command class. """
        base.Command(cls.command).start()
