import os
import signal
import logging

import subprocess as sp

from sandbox.common import os as common_os
from sandbox.common import config as common_config
from sandbox.common import patterns as common_patterns
from sandbox.common import itertools as common_itertools

from sandbox.agentr import errors as agentr_errors
from sandbox.client import errors, system

from . import base


logger = logging.getLogger(name=__name__)


# Upper bound, in seconds, on waiting for a cgroup to become empty before it is dropped.
CGROUP_DROP_TIMEOUT = 180  # i.e. 3 minutes


class LinuxPlatform(base.Platform, base.ResolvConfMixin):
    """Linux executor platform.

    On top of ``base.Platform`` it tracks executor processes via cgroups (freeze,
    signal, drop), mounts/unmounts per-task RAM drives, unmounts FUSE file
    systems and enables turbo boost through agentr.
    """

    @classmethod
    def _cgroup_name(cls, name):
        """Return ``/sandbox/executor/<name>``; an empty ``name`` yields ``/sandbox/executor``."""
        return os.path.join("/sandbox", "executor", name).rstrip(os.sep)

    def _cgroup_id(self):
        """Return the command's ``token`` attribute as a string (``""`` if it has none)."""
        return str(getattr(self._cmd, "token", ""))

    @classmethod
    def make_cgroup(cls, name):
        """Create the cgroup ``name`` owned by the unprivileged user and return ``cg.create()``.

        Returns ``None`` when the process has no root rights or Porto is enabled
        in the client configuration.
        """
        if not common_os.User.has_root or common_config.Registry().client.porto.enabled:
            return None
        cg = common_os.CGroup(name, owner=system.UNPRIVILEGED_USER.login)
        # NOTE(review): presumably CGroup(...) can evaluate to None where cgroups are unsupported — confirm
        if cg is None:
            return None
        logger.debug("Creating CGroup %s", cg)
        return cg.create()

    @common_patterns.singleton_property
    def cgroup(self):
        """The executor's own cgroup, named after the command token (may be ``None``)."""
        return self.make_cgroup(self._cgroup_name(self._cgroup_id()))

    def _resolve_cgroups(self, cgroups):
        """Return ``cgroups`` unchanged, or ``[self.cgroup]`` (``[]`` if it is falsy) when ``None``."""
        if cgroups is not None:
            return cgroups
        own = self.cgroup
        return [own] if own else []

    @staticmethod
    def _to_native_str(data):
        """Return subprocess output as the interpreter's native ``str``.

        Python 3 pipes yield ``bytes``, which cannot be searched or split with
        ``str`` literals, so they are decoded as UTF-8 (undecodable bytes are
        replaced). Native strings (Python 2 pipe output) are returned unchanged;
        ``None`` becomes ``""``.
        """
        if data is None:
            return ""
        if isinstance(data, str):
            return data
        return data.decode("utf-8", "replace")

    @classmethod
    def _get_cgroup_processes(cls, group):
        """Yield the PIDs of all subgroups of ``group`` (recursively), then its own ``tasks``."""
        for subgroup in group:
            for pid in cls._get_cgroup_processes(subgroup):
                yield pid
        for pid in group.tasks:
            yield pid

    def send_signal(self, sig, cgroups=None):
        """Send ``sig`` to every process of ``cgroups`` (default: the executor's cgroup).

        Delegates to the base implementation when there is no cgroup to target.

        :param sig: signal number passed to ``os.kill``
        :param cgroups: cgroups to target; ``None`` means the executor's own cgroup
        """
        cgroups = self._resolve_cgroups(cgroups)
        if not cgroups:
            return super(LinuxPlatform, self).send_signal(sig)
        for cgroup in cgroups:
            for pid in self._get_cgroup_processes(cgroup.freezer):
                try:
                    os.kill(pid, sig)
                except OSError:
                    # best effort: the process may already be gone or not signallable
                    pass

    def _set_freeze_state(self, state, cgroups=None):
        """Write ``state`` (``"FROZEN"`` / ``"THAWED"``) to the freezer of each cgroup."""
        for cgroup in self._resolve_cgroups(cgroups):
            cgroup.freezer["state"] = state
            self.logger.debug("Cgroup %r %s", cgroup.name, state.lower())

    def _drop_cgroup(self, cgroups=None):
        """Delete each cgroup once no processes remain in it.

        Waits via ``common_itertools.progressive_waiter`` (bounded by
        ``CGROUP_DROP_TIMEOUT``). If processes are still alive afterwards, they are
        logged and an emergency client shutdown is requested.
        """
        for cgroup in self._resolve_cgroups(cgroups):
            # the lambda is consumed synchronously below, so binding the loop variable is safe
            res, slept = common_itertools.progressive_waiter(
                0, 1, CGROUP_DROP_TIMEOUT,
                lambda: next(self._get_cgroup_processes(cgroup.freezer), None) is None,
            )
            if res:
                self.logger.debug("Dropping cgroup %r", cgroup.name)
                cgroup.delete()
            else:
                self.logger.error(
                    "Impossible to drop cgroup %r in %s seconds, still alive pids: %r",
                    cgroup.name, slept, list(self._get_cgroup_processes(cgroup.freezer)),
                )
                from sandbox.client.commands import ShutdownClientCommand
                ShutdownClientCommand.emergency_shutdown()

    def _clean_proc_debris(self, cgroups=None):
        """Freeze, SIGKILL, thaw and then drop the given cgroups (default: the executor's)."""
        # presumably frozen first so processes cannot spawn new ones while being killed
        self._set_freeze_state("FROZEN", cgroups=cgroups)
        self.send_signal(signal.SIGKILL, cgroups=cgroups)
        self._set_freeze_state("THAWED", cgroups=cgroups)
        self._drop_cgroup(cgroups=cgroups)

    def suspend(self):
        """Run the base ``suspend`` with the executor cgroup frozen; always thaw afterwards."""
        self._set_freeze_state("FROZEN")
        try:
            super(LinuxPlatform, self).suspend()
        finally:
            # never leave the cgroup frozen if the base call fails
            self._set_freeze_state("THAWED")

    def resume(self):
        """Run the base ``resume`` with the executor cgroup frozen; always thaw afterwards."""
        self._set_freeze_state("FROZEN")
        try:
            super(LinuxPlatform, self).resume()
        finally:
            # never leave the cgroup frozen if the base call fails
            self._set_freeze_state("THAWED")

    @property
    def ramdrive(self):
        """Per-task RAM drive path ``<client.tasks.ramdrive>/<task_id>``, or ``None`` without a task id."""
        if self._cmd.task_id:
            return os.path.join(common_config.Registry().client.tasks.ramdrive, str(self._cmd.task_id))
        return None

    @ramdrive.setter
    def ramdrive(self, value):
        """Mount the RAM drive; ``value`` is ``(fs_type, size_in_megabytes)``."""
        self._mount_ramdrive(value)

    @ramdrive.deleter
    def ramdrive(self):
        """Unmount the task RAM drive, if there is one, under ``system.UserPrivileges``."""
        mountpoint = self.ramdrive
        if mountpoint:
            with system.UserPrivileges():
                self.unmount_ramdrive(mountpoint, logger=self.logger)

    def _mount_ramdrive(self, value, command_prefix=None):
        """Mount a RAM drive at :attr:`ramdrive` unless something is already mounted there.

        :param value: ``(rd_type, size)`` - file system type for ``mount -t`` and size in megabytes
        :param command_prefix: optional argv prefix prepended to every command
        :raises errors.ExecutorFailed: if ``mount`` exits with a non-zero code
        """
        rd_type, size = value
        command_prefix = command_prefix or []
        ramdrive = self.ramdrive
        with system.UserPrivileges():
            if not sp.call(command_prefix + ["/bin/mountpoint", "-q", ramdrive]):
                self.logger.info("RAM drive %r already exists: skipping mounting ", ramdrive)
                return
            self.logger.info(
                "Mounting RAM drive of type %r and size %sM at %r.%s",
                rd_type, size, ramdrive,
                " Command prefix: {}".format(command_prefix) if command_prefix else "",
            )
            if not os.path.exists(ramdrive):
                os.makedirs(ramdrive)
            else:
                self.logger.warning("RAM mount point %r already exists", ramdrive)
            p = sp.Popen(
                command_prefix + ["/bin/mount", "-t", rd_type, "-o", "size={}M".format(size), rd_type, ramdrive],
                stdout=sp.PIPE,
                stderr=sp.PIPE,
            )
            stdout, stderr = p.communicate()
        if p.returncode:
            self.logger.error(
                "RAM drive mounting failed with code {rc}, output follows:\n"
                "{hr}STDERR{hr}\n{stderr}\n{hr}STDOUT{hr}\n{stdout}".format(
                    hr="-" * 40, stdout=stdout, stderr=stderr, rc=p.returncode
                )
            )
            raise errors.ExecutorFailed

    @classmethod
    def unmount_ramdrive(cls, mountpoint, command_prefix=None, logger=logger):
        """Unmount the RAM drive at ``mountpoint`` and remove the directory.

        Keeps calling ``umount`` while ``mountpoint`` is still a mount point, allowing
        at most two failed attempts. On a failed attempt, ``lsof +D`` output is logged
        for diagnostics.

        :param mountpoint: path of the RAM drive mount point
        :param command_prefix: optional argv prefix prepended to every command
        :param logger: logger for progress and failure messages
        :return: ``False`` if the last ``umount`` attempt failed with "device is busy",
            ``True`` otherwise (including when nothing was mounted)
        """
        command_prefix = command_prefix or []
        busy_marker = "device is busy"
        result = True
        tries = 0
        while tries < 2 and not sp.call(command_prefix + ["/bin/mountpoint", "-q", mountpoint]):
            logger.info("Unmounting RAM drive %r, try %r", mountpoint, tries)
            p = sp.Popen(command_prefix + ["/bin/umount", mountpoint], stdout=sp.PIPE, stderr=sp.PIPE)
            stdout, stderr = p.communicate()
            if not p.returncode:
                result = True
                try:
                    os.rmdir(mountpoint)
                except OSError as ex:
                    logger.warning("Failed to remove mountpoint %r: %r", mountpoint, ex)
                continue

            lsof_cmd = command_prefix + ["lsof", "+D", mountpoint]
            try:
                lsof_out, _ = sp.Popen(lsof_cmd, stdout=sp.PIPE, stderr=sp.STDOUT).communicate()
            except OSError as ex:
                # lsof is diagnostics only; its absence must not abort the unmount loop
                lsof_out = "<failed to run lsof: {!r}>".format(ex)
            logger.warning(
                "Unmounting of RAM drive failed with code {rc}, output follows:\n"
                "{hr}STDERR{hr}\n{stderr}\n{hr}STDOUT{hr}\n{stdout}\nOutput of {lsof_cmd}:\n{lsof_out}".format(
                    hr="-" * 40, stdout=stdout, stderr=stderr, rc=p.returncode,
                    lsof_cmd=lsof_cmd, lsof_out=lsof_out,
                )
            )
            # the result reflects the most recent attempt only
            result = (
                busy_marker not in cls._to_native_str(stderr)
                and busy_marker not in cls._to_native_str(stdout)
            )
            tries += 1

        return result

    def _unmount_fuse(self, command_prefix=None):
        """Lazily unmount (``fusermount -uz``) every ``fuse.*`` mount except ``fuse.lxcfs``.

        In local mode, the ``fuse.arc`` mount at the configured devbox ``arc_repo`` is kept.
        Without ``command_prefix``, nothing is done when the service user equals the
        unprivileged user. Failures are logged, not raised.

        :param command_prefix: optional argv prefix prepended to every command
        """
        if command_prefix is None:
            # do not perform this on local sandbox outside of LXC container
            if system.SERVICE_USER == system.UNPRIVILEGED_USER:
                return
            command_prefix = []
        try:
            mounts = sp.check_output(
                command_prefix + ["findmnt", "-nl", "-o", "fstype,target"],
                stderr=sp.STDOUT,
            )
        except sp.CalledProcessError as exc:
            # findmnt is missing on Ubuntu Lucid, for which fuse is not supported anyway
            if exc.returncode != 255:
                self.logger.error("Unable to find mounts:\n%s", exc.output)
            return
        for line in self._to_native_str(mounts).strip().split("\n"):
            parts = line.split(None, 1)
            if len(parts) != 2:
                # blank (e.g. empty findmnt output) or malformed line
                continue
            fstype, mount = parts
            if system.local_mode() and fstype == "fuse.arc" and mount == common_config.Registry().devbox.arc_repo:
                continue
            if not fstype.startswith("fuse.") or fstype == "fuse.lxcfs":
                continue
            try:
                sp.check_output(command_prefix + ["fusermount", "-uz", mount], stderr=sp.STDOUT)
            except sp.CalledProcessError as exc:
                self.logger.error("Unable to unmount %s:\n%s", mount, exc.output)
            except OSError as exc:
                # e.g. fusermount is not installed; OSError carries no ``output`` attribute
                self.logger.error("Unable to unmount %s:\n%s", mount, exc)
            else:
                self.logger.debug("Unmounted %s of type %s", mount, fstype)

    @property
    def tasks(self):
        """Yield at most the first task package of the base platform.

        Before yielding, ``pkg.dst`` is passed to ``/bin/umount`` (best effort:
        output is discarded and the return code ignored).
        """
        pkg = next(iter(super(LinuxPlatform, self).tasks), None)
        if not pkg:
            return
        # TODO: SANDBOX-5068: Backward-compatibility for tasks with fixed __archive__
        with open(os.devnull, "wb") as devnull:
            sp.call(
                ["/bin/umount", pkg.dst],
                stdout=devnull, stderr=sp.STDOUT,
                # entered in the child before exec; presumably switches to privileged user — confirm
                preexec_fn=common_os.User.Privileges().__enter__
            )
        yield pkg

    def prepare(self, tasks_rid):
        """Try to enable turbo boost via agentr, then run the base ``prepare``.

        An ``agentr_errors.InvalidPlatform`` from agentr is logged and ignored.
        """
        try:
            self.agentr.turbo_boost(True)
        except agentr_errors.InvalidPlatform as er:
            self.logger.error("Impossible to enable turbo boost: %s", er)
        return super(LinuxPlatform, self).prepare(tasks_rid)

    def spawn(self, executor_args):
        """Add the resolv.conf argument to ``executor_args`` and delegate to the base ``spawn``."""
        self.set_resolv_conf_arg(executor_args)
        return super(LinuxPlatform, self).spawn(executor_args)
