import os
import pwd
import sys
import errno
import signal
import shutil
import logging
import tarfile
import collections

import itertools as it
import functools as ft

import subprocess32 as sp

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc

from sandbox.common import fs as common_fs
from sandbox.common import os as common_os
from sandbox.common import config as common_config
from sandbox.common import format as common_format
from sandbox.common import package as common_package
from sandbox.common import patterns as common_patterns
from sandbox.common import itertools as common_itertools

from sandbox.agentr import client as agentr_client

from sandbox.client import base, errors, system


# Environment variable naming the Sandbox code directory (it contains ``bin/executor.py``).
SANDBOX_DIR = "SANDBOX_DIR"

# Module-wide logger, used by class-level helpers that have no per-command logger at hand.
logger = logging.getLogger(name=__name__)


class Platform(base.Serializable):
    """
    Base class for the platforms a task executor can run on.

    Instantiating ``Platform(cmd)`` directly dispatches (in ``__new__``) to a concrete subclass chosen from the
    host OS family, the client's LXC/Porto settings, the command's ``container`` argument and whether the command
    is a privileged one. Subclasses override the hooks defined here (``spawn``, ``cleanup``, ``ramdrive``,
    ``cgroup`` and others) to provide platform-specific behaviour.
    """

    # Code package to synchronize before execution: copied from `src` to `dst` when their versions differ.
    Package = collections.namedtuple("Package", ("name", "src", "dst"))

    _cmd = None  # for IDE only

    def __new__(cls, cmd):
        """
        Pick and instantiate the concrete platform class for `cmd`.

        :param cmd: command to execute; its ``args`` mapping is consulted for the ``container`` key
        :raises errors.ExecutorFailed: a container is requested but neither LXC nor Porto is enabled,
            or a privileged command is requested without a container
        :raises Exception: the host OS family is not supported
        """
        from sandbox.client.platforms import (
            LinuxPlatform,
            LXCPlatform,
            PrivilegedLXCPlatform,
            PortoPlatform,
            PrivilegedPortoPlatform,
            BSDPlatform,
            OSXPlatform,
            CygwinPlatform,
            WSLPlatform,
        )
        from sandbox.client.commands import (
            ExecutePrivilegedTaskCommand
        )

        if cls != Platform:
            # A concrete subclass was requested explicitly, so no dispatching is needed.
            return super(Platform, Platform).__new__(cls, cmd)
        if common_config.Registry().this.system.family == ctm.OSFamily.FREEBSD:
            cls = BSDPlatform
        elif common_config.Registry().this.system.family in (ctm.OSFamily.LINUX, ctm.OSFamily.LINUX_ARM):
            container = cmd.args.get("container")
            if (
                container and
                not common_config.Registry().client.lxc.enabled and
                not common_config.Registry().client.porto.enabled
            ):
                logger.error("Neither LXC, no Porto are available")
                raise errors.ExecutorFailed
            if isinstance(cmd, ExecutePrivilegedTaskCommand):
                if common_config.Registry().client.lxc.enabled and container:
                    cls = PrivilegedLXCPlatform
                elif common_config.Registry().client.porto.enabled and container:
                    cls = PrivilegedPortoPlatform
                else:
                    logger.error("Cannot execute privileged task without a container.")
                    raise errors.ExecutorFailed
            else:
                # On Porto-enabled hosts every non-privileged command goes through Porto, container or not.
                if common_config.Registry().client.porto.enabled:
                    cls = PortoPlatform
                elif container:
                    # The check above guarantees LXC is enabled here (a container is requested, Porto is not on).
                    if common_config.Registry().client.lxc.enabled:
                        cls = LXCPlatform
                elif ctc.Tag.WINDOWS in common_config.Registry().client.tags:
                    cls = WSLPlatform
                else:
                    cls = LinuxPlatform
        elif common_config.Registry().this.system.family in ctm.OSFamily.Group.OSX:
            cls = OSXPlatform
        elif common_config.Registry().this.system.family == ctm.OSFamily.CYGWIN:
            cls = CygwinPlatform
        else:
            raise Exception("Unsupported platform {!r}".format(common_config.Registry().this.system.family))
        return super(Platform, Platform).__new__(cls, cmd)

    def __init__(self, cmd):
        self._cmd = cmd
        # Reuse the command's logger so platform messages end up in the command's log.
        self.logger = self._cmd.logger
        super(Platform, self).__init__()

    def __repr__(self):
        return "<{}>".format(self.name)

    @classmethod
    def native_path(cls, path):
        """ Convert `path` to the platform's native form; the base implementation returns it unchanged. """
        return path

    @property
    def name(self):
        """ Name of the concrete platform class. """
        return type(self).__name__

    @property
    def executable(self):
        """ Python interpreter used to run the executor: the current one in local mode, the service venv's otherwise. """
        if system.local_mode():
            return os.path.realpath(sys.executable)
        return os.path.join(system.SERVICE_USER.home, "venv", "bin", "python")

    @property
    def executor_path(self):
        """ Path to ``executor.py``: taken from the ``SANDBOX_DIR`` environment variable in local mode. """
        if system.local_mode():
            return os.path.join(os.environ[SANDBOX_DIR], "bin", "executor.py")
        return os.path.join(system.SERVICE_USER.home, "client", "sandbox", "bin", "executor.py")

    @property
    def preexecutor_path(self):
        """ Path to the ``preexecutor`` binary, depending on the installation type. """
        if common_config.Registry().common.installation == ctm.Installation.LOCAL:
            return os.path.join(common_config.Registry().common.dirs.data, "preexecutor", "preexecutor")
        if common_config.Registry().common.installation == ctm.Installation.TEST:
            return os.path.join(os.environ[SANDBOX_DIR], "executor", "preexecutor", "preexecutor")
        return os.path.join(system.SERVICE_USER.home, "preexecutor")

    @property
    def config_path(self):
        """ Path to the settings file handed to the executor via the config environment variable. """
        if common_config.Registry().common.installation in ctm.Installation.Group.LOCAL:
            return str(common_config.Registry().custom)
        return os.path.join(system.SERVICE_USER.home, "client", "sandbox", "etc", "settings.yaml")

    @property
    def venv(self):
        """ Yields the venv package (derived from the ``EXECUTABLE`` env var) when not in local mode. """
        if not system.local_mode() and os.environ.get("EXECUTABLE"):
            yield self.Package(
                "venv",
                # EXECUTABLE is expected to be <venv>/bin/<python>, hence two dirname() calls.
                os.path.dirname(os.path.dirname(os.environ["EXECUTABLE"])),
                os.path.join(system.SERVICE_USER.home, "venv")
            )

    @property
    def client(self):
        """ Yields the client code package (parent of ``SANDBOX_DIR``) when not in local mode. """
        if not system.local_mode():
            yield self.Package(
                "client",
                os.path.dirname(os.environ[SANDBOX_DIR]),
                os.path.join(system.SERVICE_USER.home, "client")
            )

    @property
    def tasks(self):
        """ Yields the tasks code package (configured ``code_dir``) when not in local mode. """
        if not system.local_mode():
            yield self.Package(
                "tasks",
                common_config.Registry().client.tasks.code_dir,
                os.path.join(system.SERVICE_USER.home, "tasks")
            )

    @common_patterns.singleton_property
    def agentr(self):
        """
        AgentR session created via ``Session.service`` with this platform's logger.

        :rtype: sandbox.agentr.client.Session
        """
        return agentr_client.Session.service(self.logger)

    @classmethod
    def ensure_taskdir(cls, task_id):
        """
        Create (if missing) the task's working directory under the configured tasks data dir.

        Every missing path component produced by ``ctt.relpath(task_id)`` is created and chowned to the
        unprivileged user.

        :return: full path to the task's working directory
        """
        task_workdir = common_config.Registry().client.tasks.data_dir
        sbid = pwd.getpwnam(system.UNPRIVILEGED_USER.login)

        def ensure(path):
            if not os.path.exists(path):
                logger.debug("Creating task's directory %r owned by user %r", path, system.UNPRIVILEGED_USER.login)
                os.mkdir(path)
                os.chown(path, sbid.pw_uid, sbid.pw_gid)
            return path

        for subdir in ctt.relpath(task_id):
            task_workdir = ensure(os.path.join(task_workdir, subdir))
        return task_workdir

    @classmethod
    def empty_dir(cls, path, create=True, rm_msg="Removing directory %r", mk_msg="Creating directory %r"):
        """
        Remove `path` if it exists and, when `create` is true, recreate it as an empty directory.

        :param path: directory to wipe
        :param create: recreate the directory after removal
        :param rm_msg: debug message (with one ``%r`` placeholder) logged before removal
        :param mk_msg: debug message (with one ``%r`` placeholder) logged before creation
        :return: `path`
        """
        if os.path.exists(path):
            logger.debug(rm_msg, path)
            common_fs.recursive_remove_immutable(path)
            common_fs.cleanup(path)
        if create:
            logger.debug(mk_msg, path)
            os.makedirs(path)
        return path

    @classmethod
    def logged_command(cls, prefix, cmd, timeout=None, privileged=False):
        """
        Run `cmd` with its output captured and raise if it exits with a non-zero code.

        :param prefix: human-readable action description used in log and error messages
        :param cmd: command (argument list) passed to ``Popen``
        :param timeout: seconds to wait for the command; ``None`` waits indefinitely. On expiry the process
            is killed and the method returns without raising.
        :param privileged: run ``common_os.User.Privileges().__enter__`` in the child before exec
        :raises errors.InfraError: the command exited with a non-zero return code
        """
        logger.debug(
            "%s %s command %r with %s timeout", prefix, ("by privileged" if privileged else "by"), cmd, timeout
        )
        kws = dict(close_fds=True, stdout=sp.PIPE, stderr=sp.PIPE)
        if privileged:
            kws["preexec_fn"] = common_os.User.Privileges().__enter__
        p = sp.Popen(cmd, **kws)
        try:
            # `timeout` has to be passed here, otherwise `TimeoutExpired` can never be raised.
            outs, errs = p.communicate(timeout=timeout)
        except sp.TimeoutExpired as ex:
            logger.error("Command timed out: %s. Killing process.", ex)
            p.kill()
            # Reap the killed child and drain its pipes so no zombie or open descriptors are left behind.
            p.communicate()
            return
        if p.returncode:
            logger.error(
                "%s failed with return code %d. STDOUT:\n%s\nSTDERR:\n%s", prefix, p.returncode, outs, errs
            )
            cls._logged_command_check_critical_errors(errs)
            raise errors.InfraError(
                "error while {}".format(prefix.lower()), stdout=outs, stderr=errs, returncode=p.returncode
            )

    @classmethod
    def _logged_command_check_critical_errors(cls, errs):
        """ Hook called with a failed command's stderr before ``InfraError`` is raised; no-op by default. """
        pass

    def prepare(self, tasks_rid):
        """
        Bring the venv, client and (when applicable) tasks code packages up to date.

        :param tasks_rid: id of the tasks code resource, or a falsy value if there is none
        :raises errors.InvalidJob: `tasks_rid` is given but the command's image type is INVALID
        """
        self.logger.info("Actualizing packages.")
        if tasks_rid and self._cmd.exec_type == ctt.ImageType.INVALID:
            raise errors.InvalidJob("Invalid resource to execute task: #{}".format(tasks_rid))

        with system.UserPrivileges.lock:
            # Always try to update venv and client
            pkg_sources = [self.venv, self.client]
            # Always download tasks code for custom tasks code archive
            force_tasks = tasks_rid and self._cmd.exec_type == ctt.ImageType.CUSTOM_ARCHIVE
            # Update tasks code archive for non-binary executors
            if tasks_rid and self._cmd.exec_type in ctt.ImageType.Group.EXTRACTABLE:
                pkg_path = self.agentr.resource_sync(tasks_rid, fastbone=False)
                pkg_meta = self.agentr.resource_meta(tasks_rid)
                pkg_meta = {"revision": pkg_meta["attributes"].get("commit_revision", "0")}
                updater = common_package.PackageUpdater(self.logger)
                if updater.update_package("tasks", pkg_meta, force=force_tasks, pkg_path=pkg_path):
                    # Presumably drops the cached `projects` package so the updated code is imported afresh.
                    sys.modules.pop("projects", None)
                # Update tasks only for non-binary executors
                pkg_sources.append(self.tasks)

            for pkg in common_itertools.chain(*pkg_sources):
                actual = common_package.directory_version(pkg.src)
                current = common_package.directory_version(pkg.dst)
                self.logger.info(
                    "Current %s %r version %r, %r actual %r", pkg.name, pkg.dst, current, pkg.src, actual
                )
                # Copy when the destination is missing/outdated (and differs from the source),
                # or unconditionally for force-updated tasks code.
                if (
                    ((not current or actual != current) and pkg.dst != pkg.src) or
                    (pkg.name == "tasks" and force_tasks)
                ):
                    common_fs.cleanup(pkg.dst)
                    if os.path.isfile(pkg.src):
                        shutil.copy(pkg.src, pkg.dst)
                    else:
                        shutil.copytree(pkg.src, pkg.dst, symlinks=True)
                    common_fs.chmod_for_path(pkg.dst, "g-w")

    def setup_agentr(self, agentr):
        """ Setup AgentR for executing task """
        return agentr

    def on_system_error(self):
        """ Action on system error """
        pass

    def spawn(self, executor_args):
        """
        Start the executor via the preexecutor binary.

        :param executor_args: mapping of executor arguments; the ``cgroup`` key is filled in here
        :return: tuple of the started ``system.TaskLiner`` and the (updated) `executor_args`
        """
        with system.UserPrivileges():
            executor_args["cgroup"] = self.cgroup and self.cgroup.name

        self._cmd.executor_args = executor_args  # must not be executed under root
        self._cmd.save_state()

        env = dict(os.environ)
        env["HOME"] = system.UNPRIVILEGED_USER.home
        env["LOGNAME"] = system.UNPRIVILEGED_USER.login
        env["LANG"] = "en_US.UTF8"
        env.pop("PYTHONPATH", None)
        env[common_config.Registry.CONFIG_ENV_VAR] = self.config_path
        return system.TaskLiner(
            common_format.obfuscate_token(self._cmd.token),
            self.logger,
            [self.preexecutor_path, self._cmd.executor_args],
            env,
            user="root" if common_os.User.has_root else None,
            cgroup=self.cgroup,
        ), executor_args

    def terminate(self):
        pass

    def cancel(self):
        pass

    @classmethod
    def maintain(cls):
        pass

    @property
    def cgroup(self):
        """ Cgroup the executor runs in; the base platform uses none. """
        return None

    def send_signal(self, sig):
        """
        Deliver `sig` to the task's processes.

        In local mode only the task liner process is signalled; otherwise every process owned by the
        unprivileged user is, ignoring ``OSError`` for processes that cannot be signalled.
        """
        if system.local_mode():
            if self._cmd.liner:
                with system.UserPrivileges():
                    os.kill(self._cmd.liner.pid, sig)
        else:
            for proc in common_os.processes():
                try:
                    if proc.uid == system.UNPRIVILEGED_USER.uid:
                        os.kill(proc.pid, sig)
                except OSError:
                    pass

    def suspend(self):
        self.send_signal(signal.SIGSTOP)

    def resume(self):
        self.send_signal(signal.SIGCONT)

    @property
    def _shell_command(self):
        return "/bin/bash"

    @property
    def shell_command(self):
        """ Interactive shell command line, run as the unprivileged user via sudo when running as root. """
        if common_os.User.has_root:
            return "/usr/bin/sudo -EHu {} {}".format(
                system.UNPRIVILEGED_USER.login, self._shell_command
            )
        else:
            return self._shell_command

    @property
    def _ps_command(self):
        return "/bin/ps wwuxf"

    @property
    def ps_command(self):
        """ Process listing command line; on sandbox-storage hosts it also lists the service user's processes. """
        if common_os.User.has_root:
            return (
                "/bin/bash -c '/usr/bin/sudo -Hu {} {} ; /usr/bin/sudo -Hu {} {}'".format(
                    system.SERVICE_USER.login, self._ps_command,
                    system.UNPRIVILEGED_USER.login, self._ps_command,
                )  # Especially for mocksoul@
                if common_config.Registry().this.id.startswith("sandbox-storage") else
                "/usr/bin/sudo -Hu {} {}".format(
                    system.UNPRIVILEGED_USER.login, self._ps_command
                )
            )
        else:
            return self._ps_command

    @property
    def attach_command(self):
        """ Debugger attach command line template with ``{port}``, ``{host}`` and ``{pid}`` placeholders. """
        exe = self.executable
        return " ".join([
            exe,
            os.path.join(os.path.dirname(os.path.dirname(exe)), "pydevd_attach_to_process", "attach_pydevd.py"),
            "--port", "{port}", "--host", "{host}", "--pid", "{pid}"
        ])

    @property
    def ramdrive(self):
        """ Mounted RAM drive; not supported by the base platform. """
        return None

    @ramdrive.setter
    def ramdrive(self, value):
        self.logger.error("Attempt to mount RAM drive on unsupported platform.")
        raise errors.ExecutorFailed

    @ramdrive.deleter
    def ramdrive(self):
        pass

    @staticmethod
    def _remove_path(path, exclude_paths=()):
        """
        Best-effort removal of a file, or of a directory's contents (the directory itself is kept).

        Entries whose full path starts with any of `exclude_paths` are skipped. Errors are logged, not raised.
        """
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                logger.debug("Removing files in %r", path)
                # Bottom-up walk so directories are already empty by the time rmdir() is called on them.
                for root, dirs, files in os.walk(path, topdown=False):
                    for name in it.chain(dirs, files):
                        full_path = os.path.join(root, name)
                        if any(full_path.startswith(ex_path) for ex_path in exclude_paths):
                            continue
                        try:
                            if os.path.isdir(full_path) and not os.path.islink(full_path):
                                os.rmdir(full_path)
                            else:
                                os.remove(full_path)
                        except OSError as e:
                            logger.error("Error while removing %r: %s", full_path, e)
            else:
                logger.debug("Removing %r", path)
                os.remove(path)
        except Exception as e:
            logger.error("Error while removing files in %r: %s", path, e)

    def _unmount_fuse(self):
        pass

    @classmethod
    def _clean_fs_debris(cls, home_dir=None):
        """
        Remove files left by the task: top-level ``/tmp`` entries owned by the unprivileged user
        (outside local mode) and the contents of `home_dir` (the unprivileged user's home by default).
        """
        if not home_dir and system.local_mode():
            return
        home_dir = home_dir or os.path.abspath(system.UNPRIVILEGED_USER.home)
        if not os.path.isdir(home_dir) or len(home_dir) < 2:
            logger.warning("Path is not a directory: %s. Skip cleanup", home_dir)
            return
        if not system.local_mode():
            tmp_dir = "/tmp"
            logger.debug("Cleaning %r", tmp_dir)
            # Only the top level of /tmp is needed. A top-down walk yields `tmp_dir` itself first, whereas
            # `topdown=False` would yield the deepest sub-directory first.
            root, dirs, files = next(os.walk(tmp_dir), ("", (), ()))
            for path in it.imap(ft.partial(os.path.join, root), it.chain(dirs, files)):
                try:
                    if os.lstat(path).st_uid == system.UNPRIVILEGED_USER.uid:
                        cls._remove_path(path)
                except OSError as ex:
                    if ex.errno == errno.ENOENT:
                        logger.warning("No such file or directory: %s", path)
                        continue
                    logger.exception("Error when removing path %s", path)
        cls._remove_path(home_dir, common_config.Registry().client.sandbox_home_cleanup_exclude)

    def _kill_executor(self, tag):
        """ Terminate the task liner left from a previous execution identified by `tag`, if any. """
        truken = common_format.obfuscate_token(tag)
        self.logger.debug("Killing processes from previous execution by tag %r", truken)
        liner = None
        try:
            if truken in system.TaskLiner.active_sockets:
                liner = system.TaskLiner(truken, self.logger)
                self.logger.warning("Killing %r", liner)
                liner.terminate()
        except Exception:
            # Log with traceback so the cause of the failure is not lost.
            self.logger.exception("Cannot kill %s", liner if liner else "processes from previous execution")
            system.TaskLiner.drop_stale_socket(truken)

    def _clean_proc_debris(self):
        """ SIGKILL every process of the unprivileged user, repeating up to 5 passes while something was killed. """
        if system.SERVICE_USER == system.UNPRIVILEGED_USER:
            return
        self.logger.debug("Killing all processes of user %r", system.UNPRIVILEGED_USER.login)
        for _ in xrange(5):
            were_killed = False
            for p in common_os.processes():
                try:
                    if p.uid == system.UNPRIVILEGED_USER.uid:
                        os.kill(p.pid, signal.SIGKILL)
                        # `os.kill` returns None, so the flag has to be set explicitly after a successful kill.
                        were_killed = True
                        self.logger.debug("Killing process #%d", p.pid)
                except OSError:
                    pass
            if not were_killed:
                break

    @classmethod
    def _safe_chown(cls, user, path, recursive=False):
        """
        Change ownership of `path` (and, if `recursive`, everything below it) to `user`.

        ``lchown`` is used, so symlinks themselves are chowned rather than their targets. Failures are logged
        as warnings and do not stop the traversal.
        """
        def chown(_path):
            try:
                os.lchown(_path, user.uid, user.gid)
            except OSError:
                logger.warning("Unable to chown %r to %d.%d", _path, user.uid, user.gid)
        logger.debug("Changing privileges on %r", path)
        chown(path)
        if recursive:
            common_fs.recursive_remove_immutable(path)
            for root, dirs, files in os.walk(path):
                for p in common_itertools.chain(files, dirs):
                    chown(os.path.join(root, p))

    @staticmethod
    def _restore_home(home_dir, tarball_path=None, job_logger=None):
        """
        Extract the home directory tarball into `home_dir`.

        :param home_dir: destination directory
        :param tarball_path: tarball to extract; defaults to the configured ``sandbox_home_tarball``.
            In local mode nothing is extracted unless this is given explicitly.
        :param job_logger: logger to use; the module logger by default
        """
        if job_logger is None:
            job_logger = logger
        # Report the tarball that will actually be used: an explicit argument wins over the configured one.
        job_logger.debug(
            "Going to extract %s to %s",
            tarball_path or common_config.Registry().client.sandbox_home_tarball, home_dir
        )
        if system.local_mode() and tarball_path is None:
            return
        tarball_path = tarball_path or common_config.Registry().client.sandbox_home_tarball
        job_logger.info("Extracting %s to %s", tarball_path, home_dir)
        with tarfile.open(tarball_path) as tar:
            tar.extractall(home_dir)
        job_logger.info("%s successfully extracted", tarball_path)

    def cleanup(self):
        """ Kill leftover task processes, clean the file system, restore the home dir and drop the RAM drive. """
        self.logger.info("Performing platform cleanup.")
        with system.PrivilegedSubprocess(("platform cleanup", self._cmd is not None and self._cmd.token)):
            # Resume first: stopped processes cannot handle termination requests.
            self.resume()
            if self._cmd is not None and self._cmd.token:
                self._kill_executor(self._cmd.token)
            self._clean_proc_debris()
            self._unmount_fuse()
            self._clean_fs_debris()
            self._restore_home(system.UNPRIVILEGED_USER.home, job_logger=self.logger)
        del self.ramdrive

    def restore_system_files(self):
        pass

    def sigterm(self, executor_pid):
        """ Send SIGTERM to the executor process. """
        self.logger.debug("Going to kill process %s with SIGTERM", executor_pid)
        os.kill(executor_pid, signal.SIGTERM)


class ResolvConfMixin(object):
    """
    Mixin that passes the executor the ``resolv.conf`` contents matching the task's DNS type.

    The host class must provide ``self._cmd`` whose ``args`` mapping may contain a ``dns`` key.
    """

    # Base path; variants are looked up as "<resolv_conf_path>.<name>".
    resolv_conf_path = "/etc/resolv.conf"

    def read_resolv_conf(self, names=None):
        """
        Return the contents of the first existing resolv.conf variant.

        :param names: variant suffix (or suffixes) to try in order; an empty suffix means ``resolv_conf_path``
            itself. Defaults to ``("orig", "")``, i.e. ``/etc/resolv.conf.orig`` then ``/etc/resolv.conf``.
        :return: file contents, or an empty string if none of the candidates exists
        """
        if not names:
            names = ("orig", "")
        # NOTE(review): `common_itertools.chain` presumably also accepts a single name (see `resolv_conf`).
        for name in common_itertools.chain(names):
            path = self.resolv_conf_path + ("." + name if name else "")
            if os.path.exists(path):
                # Close the file deterministically instead of leaking the descriptor until garbage collection.
                with open(path, "r") as f:
                    return f.read()
        return ""

    @property
    def dns_type(self):
        """ DNS type requested by the command (its ``dns`` argument), or ``None``. """
        return self._cmd.args.get("dns")

    @property
    def resolv_conf(self):
        """
        resolv.conf contents for the requested DNS type.

        For DNS64 only the ``.<DNS64>`` variant is read (an empty string if it is missing);
        otherwise the default variants are tried.
        """
        if self.dns_type == ctm.DnsType.DNS64:
            return self.read_resolv_conf(ctm.DnsType.DNS64)
        return self.read_resolv_conf(None)

    def set_resolv_conf_arg(self, executor_args):
        """ Store the DNS type and resolv.conf contents into `executor_args` for task commands only. """
        from sandbox.client.commands import TaskCommand
        if isinstance(self._cmd, TaskCommand):
            executor_args["dns"] = self.dns_type
            executor_args["resolv.conf"] = self.resolv_conf
