import os
import stat
import time
import signal
import socket
import getpass
import logging

import six

if six.PY2:
    import pathlib2 as pathlib
    import subprocess32 as sp
else:
    import pathlib
    import subprocess as sp

from sandbox.executor.preexecutor import constants

from sandbox.common import os as common_os
from sandbox.common import fs as common_fs
from sandbox.common import config as common_config
from sandbox.common import platform as common_platform
from sandbox.common import itertools as common_itertools

from sandbox.common.windows import wsl
from sandbox.common.windows import subprocess as win_sp

import sandbox.common.types.misc as ctm

from sandbox import sdk2


# Shared "preexecutor" logger, referenced by the functions below that have no `logger`
# parameter of their own (functions with a `logger` parameter shadow this name).
logger = logging.getLogger("preexecutor")


def execute_root_hook(on_start, hook_args, logger):
    """
    Run the container's task start or stop root hook, if it is installed.

    The hook is executed with elevated privileges and limited by ``constants.ROOT_HOOKS_TIMEOUT``.

    :param on_start: True to run ``constants.ON_TASK_START_ROOT_HOOK``,
        False to run ``constants.ON_TASK_STOP_ROOT_HOOK``.
    :param hook_args: list of extra command-line arguments appended to the hook path.
    :param logger: logger receiving progress messages and the hook output.
    :return: an error message if the hook timed out, otherwise None. A missing hook, a non-zero
        exit code or an OSError are only logged.
    """
    path = constants.ON_TASK_START_ROOT_HOOK if on_start else constants.ON_TASK_STOP_ROOT_HOOK
    hook_name = "start" if on_start else "stop"
    if not os.path.exists(path):
        return None

    def as_text(data):
        # Subprocess output is bytes on Python 3; decode it so the log shows real lines
        # instead of a b'...' repr with escaped newlines.
        if isinstance(data, (six.binary_type, six.text_type)):
            return six.ensure_str(data, errors="replace")
        return data

    cmd = [path] + hook_args
    logger.info("Executing container's task %s hook: %r", hook_name, cmd)
    try:
        with common_os.User.Privileges():
            out = check_subprocess_output(cmd, timeout=constants.ROOT_HOOKS_TIMEOUT).strip()
            if out:
                logger.debug("Hook output:\n%s", as_text(out))
    except sp.TimeoutExpired as ex:
        error_msg = "Container's task {} hook execution timed out after {} seconds".format(
            hook_name, constants.ROOT_HOOKS_TIMEOUT
        )
        logger.error(error_msg)
        logger.debug("Hook output:\n%s", as_text(ex.output))
        return error_msg
    except (OSError, sp.CalledProcessError) as ex:
        logger.error("Error on container's task %s hook execution: %s", hook_name, ex)
        if isinstance(ex, sp.CalledProcessError):
            logger.debug("Hook output:\n%s", as_text(ex.output))
    return None


def _kill_process_group(pgid):
    """Send SIGKILL to process group ``pgid``; a group that no longer exists is ignored."""
    import errno

    try:
        os.killpg(pgid, signal.SIGKILL)
    except OSError as ex:
        # The whole group may already be gone by the time we get to kill it.
        if ex.errno != errno.ESRCH:
            raise


def check_subprocess_output(cmd, timeout):
    """
    Run ``cmd`` in a new session and return its combined stdout/stderr output.

    On timeout or interruption the whole process group is killed, not just the direct child,
    so grandchildren spawned by ``cmd`` do not outlive it.

    :param cmd: command as a list of arguments.
    :param timeout: seconds to wait for the command to finish.
    :return: combined output (bytes on Python 3).
    :raises subprocess.TimeoutExpired: the command did not finish in time (``output`` is set).
    :raises subprocess.CalledProcessError: the command exited with a non-zero code.
    """
    # Implementation mostly copied from
    # https://github.com/google/python-subprocess32/blob/b63954f56166eca312afb70b6d3a4af1128f81f3/subprocess32.py#L314
    # Add 'start_new_session' to Popen, replace 'process.kill' with 'os.killpg'
    process = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, start_new_session=True)
    try:
        output, _ = process.communicate(timeout=timeout)
    except sp.TimeoutExpired:
        _kill_process_group(process.pid)
        output, _ = process.communicate()
        raise sp.TimeoutExpired(process.args, timeout, output=output)
    except BaseException:
        # As in the stdlib's check_output: never leave the child running when the wait is
        # interrupted (e.g. KeyboardInterrupt or an exception raised from a signal handler).
        _kill_process_group(process.pid)
        process.wait()
        raise
    retcode = process.poll()
    if retcode:
        raise sp.CalledProcessError(retcode, process.args, output=output)
    return output


def run_atop_subprocess(logdir, cgroup, container=None, logger=None):
    """
    Start ``atop`` recording into ``<logdir>/atop.out``; its stderr goes to ``<logdir>/atop.err``.

    Nothing is started on Windows, or when ``atop.out`` already exists. The ``atop`` binary from the
    client data directory is preferred over one found via ``sdk2.paths.which``. When ``container``
    is given, atop is launched inside it: as a ``<parent>/atop`` porto container via ``portoctl exec``
    for porto containers, and via ``lxc-attach`` otherwise.

    :param logdir: directory for ``atop.out`` and ``atop.err``.
    :param cgroup: optional cgroup name; the child calls ``set_current()`` on it before exec.
    :param container: optional dict-like container description ("container_type", "name").
    :param logger: logger to use; defaults to the "atop" logger.
    :return: PID of the started process if it is still running, otherwise None.
    """
    # Windows has no atop
    if common_platform.on_windows():
        return None

    if logger is None:
        logger = logging.getLogger("atop")

    out_path = os.path.join(logdir, "atop.out")
    if os.path.exists(out_path):
        return None
    err_path = os.path.join(logdir, "atop.err")

    atop_binary = os.path.join(common_config.Registry().client.dirs.data, "atop")
    if not os.path.exists(atop_binary):
        atop_binary = sdk2.paths.which("atop")
        if atop_binary is None:
            logger.warning("Unable to start atop: atop binary is missing")
            return None

    atop_cmd = [atop_binary, "-w", out_path]
    if container and container.get("container_type") == "porto":
        base_container_name = "/".join(container.get("name").split("/")[:-1])
        atop_cmd = [
            "/usr/sbin/portoctl", "exec", "{}/atop".format(base_container_name),
            "enable_porto=false",
            "isolate=false",
            "command={}".format(" ".join(atop_cmd))
        ]
    elif container:
        atop_cmd = ["/usr/bin/lxc-attach", "-n", container.get("name"), "--"] + atop_cmd

    user = common_os.User(getpass.getuser())
    with common_os.User.Privileges():
        logger.debug("Starting atop (out=%s)", out_path)
        cg = None
        try:
            if cgroup:
                cg = common_os.CGroup(cgroup)
        except Exception as ex:
            logger.warning(ex)

        # The child inherits the descriptor; the parent's copy is closed right after spawning
        # instead of leaking until garbage collection.
        with open(err_path, "w") as err_file:
            p = sp.Popen(
                atop_cmd,
                env={"TERM": "xterm"},
                preexec_fn=lambda: cg and cg.set_current(),
                stderr=err_file,
            )
        if common_os.User.has_root:
            if not common_itertools.progressive_waiter(0, 1, 10, lambda: os.path.exists(out_path))[0]:
                p.kill()
                # Reap the killed child so the final poll() reliably reports it as finished.
                p.wait()
                logger.warning("Unable to start atop: output missing")

            # The files were created with elevated privileges; give them to the current user.
            if os.path.exists(out_path):
                os.chown(out_path, user.uid, user.gid)
                logger.debug("Changed ownership for %s to %s/%s", out_path, user.login, user.group)

            if os.path.exists(err_path):
                os.chown(err_path, user.uid, user.gid)
                logger.debug("Changed ownership for %s to %s/%s", err_path, user.login, user.group)

    # poll() is None only while the process is running; an exit code of 0 also means "finished".
    return p.pid if p.poll() is None else None


def redirect_std_file(std_file, path):
    """
    Point the file descriptor behind ``std_file`` at ``path``, positioned at the end of the file.

    ``path`` must already exist (it is opened without ``O_CREAT``). Only the OS-level descriptor
    is replaced; the Python-level buffers of ``std_file`` are not flushed.

    :param std_file: file object exposing ``fileno()`` (e.g. ``sys.stdout`` / ``sys.stderr``).
    :param path: existing file that subsequent writes to the descriptor will be appended to.
    """
    old_fd = std_file.fileno()
    new_fd = os.open(path, os.O_WRONLY)
    if new_fd == old_fd:
        # ``old_fd`` was closed, so os.open() reused its number: the descriptor already refers to
        # ``path``. dup2() would be a no-op and closing ``new_fd`` would close the stream itself.
        os.lseek(new_fd, 0, os.SEEK_END)
        return
    try:
        os.lseek(new_fd, 0, os.SEEK_END)
        os.dup2(new_fd, old_fd)
    finally:
        os.close(new_fd)


def update_resolv_conf(resolv_conf, dns_type, task_id, logger=None):
    """
    Install the DNS resolver configuration for a task into /etc/resolv.conf.

    :param resolv_conf: content to install. When ``dns_type`` is ``ctm.DnsType.LOCAL`` it is
        replaced by the content of ``/etc/resolv.conf.<dns_type>``. Nothing is written if empty.
    :param dns_type: DNS type of the task.
    :param task_id: task ID, used to name the temporary file renamed over /etc/resolv.conf.
    :param logger: logger to use; defaults to the "resolv_conf" logger.
    :return: False if the LOCAL-type source file is missing, True otherwise (always True on Windows).
    """
    # TODO: support DnsType on Windows [SANDBOX-7657]
    if common_platform.on_windows():
        return True

    if logger is None:
        logger = logging.getLogger("resolv_conf")

    resolv_conf_path = "/etc/resolv.conf"
    if dns_type == ctm.DnsType.LOCAL:
        local_path = ".".join([resolv_conf_path, dns_type])
        if not os.path.exists(local_path):
            logger.error("DNS type '%s' is set, but %s is not found.", dns_type, local_path)
            return False
        with open(local_path, "r") as local_file:
            resolv_conf = local_file.read()

    if resolv_conf:
        with common_os.User.Privileges():
            # Write a sibling temporary file and rename it over the target: rename(2) within one
            # directory replaces the file atomically. The Path is deliberately not used as a
            # context manager; that protocol was a no-op and is gone in Python 3.13.
            path = pathlib.Path(resolv_conf_path + "." + str(task_id))
            path.write_bytes(six.ensure_binary(resolv_conf))
            path.rename(resolv_conf_path)
    with open(resolv_conf_path) as current:
        logger.debug("resolv.conf of type '%s':\n%s", dns_type, current.read())
    return True


def check_container_consistency():
    """
    Verify that the container is fit for executing a task.

    The checks run only on Linux and outside TEST installations. They require a usable
    /dev/fuse device and network access to the Sandbox server; each failure is logged.

    :return: True if the checks pass or are skipped, False otherwise.
    """
    skip_checks = (
        not common_platform.on_linux()
        or common_config.Registry().common.installation == ctm.Installation.TEST
    )
    if skip_checks:
        return True

    fuse_ok = ensure_fuse_device()
    if not fuse_ok:
        logger.error("There is no fuse device and it's impossible to create it. See logs for details.")
        return False

    network_ok = check_container_network()
    if not network_ok:
        logger.error("Cannot continue executing because of corrupted network")
        return False

    return True


def check_container_network(connect_timeout=10):
    """
    Check that the Sandbox web server accepts TCP connections over IPv6, retrying for a while.

    The port is the server API port for LOCAL installations and 80 otherwise. Each failed
    attempt is logged.

    :param connect_timeout: seconds a single connection attempt may block; None means no limit.
    :return: the first element of ``common_itertools.progressive_waiter``'s result
        (presumably truthy once a connection succeeded).
    """
    sandbox_web_host = common_config.Registry().server.web.address.host
    port = common_config.Registry().server.api.port if common_config.Registry().common.installation in \
        ctm.Installation.LOCAL else 80

    def connect_to_sandbox():
        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        # Without a timeout a single connect() can block for minutes (kernel SYN retries),
        # far beyond the retry budget passed to progressive_waiter below.
        sock.settimeout(connect_timeout)
        try:
            code = sock.connect_ex((sandbox_web_host, port))
            if code == 0:
                return True
            logger.error("Cannot connect to %s:%s. connect errno: %s", sandbox_web_host, port, code)
            return False
        except socket.error as e:
            logger.error("Cannot connect to %s:%s. %s", sandbox_web_host, port, e)
            return False
        finally:
            sock.close()

    return common_itertools.progressive_waiter(1, 1, 30, connect_to_sandbox, sleep_first=False)[0]


def ensure_fuse_device():
    """
    Make sure /dev/fuse exists and is owned by root:fuse.

    A missing device is created as character device 10:229 with mode 0666. Outside Linux
    nothing is done.

    :return: True on success (or when not on Linux), False if an OSError occurred (it is logged).
    """
    if not common_platform.on_linux():
        return True

    fuse = pathlib.Path("/dev/fuse")
    fuse_path = str(fuse)
    try:
        if not fuse.exists():
            logger.warning("Fuse device is absent. Creating it")
            rw_for_all = (
                stat.S_IRUSR | stat.S_IWUSR
                | stat.S_IRGRP | stat.S_IWGRP
                | stat.S_IROTH | stat.S_IWOTH
            )
            with common_os.User.Privileges():
                os.mknod(fuse_path, rw_for_all | stat.S_IFCHR, os.makedev(10, 229))
                # The mode given to mknod is filtered by the umask, so set it once more explicitly
                os.chmod(fuse_path, rw_for_all)

        owner = common_os.User("root", "fuse")
        info = fuse.stat()
        if (info.st_uid, info.st_gid) != (owner.uid, owner.gid):
            logger.info("Change ownership of %s to %s:%s", fuse, owner.uid, owner.gid)
            with common_os.User.Privileges():
                os.chown(fuse_path, owner.uid, owner.gid)
    except OSError:
        logger.exception("Failed to ensure fuse device")
        return False

    return True


def _clean_proc_debris(logger):
    """
    SIGKILL every process owned by the current user except this one.

    Meant for leftovers of a previous execution; failures to kill are logged, not raised.
    """
    me = common_os.User(getpass.getuser())
    own_pid = os.getpid()
    with common_os.User.Privileges():
        leftovers = (
            proc for proc in common_os.processes()
            if proc.pid != own_pid and proc.uid == me.uid
        )
        for proc in leftovers:
            try:
                os.kill(proc.pid, signal.SIGKILL)
            except OSError:
                logger.exception("Unable to kill process %d", proc.pid)
                continue
            logger.warning("Killed process from previous execution %d", proc.pid)


def switch_cgroup(ipc_args, logger=None):
    """
    Move the current process into the cgroup named by ``ipc_args["cgroup"]``.

    Every subsystem of the cgroup is switched separately and failures are only logged. If no
    subsystem could be switched and ``ipc_args["container"]`` is set, processes left over by
    the current user are killed. Nothing happens on Windows or when no cgroup is given.

    :param ipc_args: dict-like arguments with optional "cgroup" and "container" keys.
    :param logger: logger to use; defaults to the "switch_cgroup" logger.
    """
    # Windows has no cgroups
    if common_platform.on_windows():
        return

    log = logger if logger is not None else logging.getLogger("switch_cgroup")

    cgroup = ipc_args.get("cgroup")
    if not cgroup:
        return

    log.debug("Switching cgroup to %s", cgroup)
    switched_any = False
    try:
        with common_os.User.Privileges():
            subsystems = common_os.CGroup(cgroup) or ()
            for subsystem in subsystems:
                try:
                    subsystem.set_current()
                except Exception as ex:
                    log.error("Failed to switch to subsystem %s: %s", subsystem, ex)
                    continue
                log.info("Switched to subsystem %s", subsystem)
                switched_any = True
    except Exception:
        log.exception("Failed to switch process to cgroup %s", cgroup)

    # A container whose cgroup could not be switched may still hold processes of the
    # previous run; kill everything owned by the unprivileged user.
    if not switched_any and ipc_args.get("container"):
        _clean_proc_debris(log)


def mount_tasks_image(tasks_image_path, tasks_dir, logging_mark, logger=None):
    """
    (Re)mount a squashfs tasks image read-only at ``tasks_dir``.

    Whatever is mounted at ``tasks_dir`` is unmounted first. If ``mount`` exits with code 2
    (treated as a race on loop device selection), the mount is retried up to 5 times with a
    growing delay.

    :param tasks_image_path: path to the squashfs image.
    :param tasks_dir: mount point.
    :param logging_mark: text appended to the error message on failure.
    :param logger: logger to use; defaults to the "mount_tasks_image" logger.
    :raises Exception: mounting failed and the image is not already mounted.
    """
    if logger is None:
        logger = logging.getLogger("mount_tasks_image")
    # The umount exit status is not checked (presumably tasks_dir may simply not be mounted yet).
    sp.call(["/bin/umount", tasks_dir], preexec_fn=common_os.User.Privileges().__enter__)
    # FIXME: SANDBOX-6164: Protection against RC on selecting loop device ID by `mount` command.
    attempts = 5
    already_mounted = "{} is already mounted".format(tasks_image_path)
    for i in six.moves.xrange(attempts + 1):
        logger.info("Mounting tasks image %r to %r", tasks_image_path, tasks_dir)
        cmd = ["/bin/mount", "-v", "-t", "squashfs", "-o", "loop,ro", tasks_image_path, tasks_dir]
        p = sp.Popen(cmd, preexec_fn=common_os.User.Privileges().__enter__, stderr=sp.PIPE)
        _, stderr = p.communicate()
        # stderr is bytes on Python 3: decode it before matching against text, otherwise
        # the "already mounted" check raises TypeError.
        stderr = six.ensure_str(stderr, errors="replace")
        if p.returncode == 2 and i < attempts:
            logger.warning("RC on mounting code. Sleeping a while and trying again")
            time.sleep(i + 2)
            continue
        elif not p.returncode or already_mounted in stderr:
            break
        else:
            raise Exception(
                "Error mounting tasks image by command {}: {}\n{}".format(" ".join(cmd), stderr, logging_mark)
            )


def env_for_binary_task(initial_env=None, tmp_dir=None):
    """
    Build the environment for running a binary task executor.

    :param initial_env: base environment; a copy of ``os.environ`` is used when it is empty or None.
        The mapping passed in is copied and never modified.
    :param tmp_dir: if set, exported as TMP, TEMP and TMPDIR.
    :return: a new environment dict without PYTHONPATH, with the config env var and
        Y_PYTHON_ENTRY_POINT set.
    """
    # Copy so that the caller's dict is not modified by the pop/assignments below.
    env = dict(initial_env) if initial_env else os.environ.copy()
    env.pop("PYTHONPATH", None)
    # Hand the current custom config over to the child process.
    env[common_config.Registry.CONFIG_ENV_VAR] = str(common_config.Registry().custom)
    env["Y_PYTHON_ENTRY_POINT"] = "sandbox.bin.executor:main"  # standard Sandbox entry point
    if tmp_dir:
        env["TMP"] = env["TEMP"] = env["TMPDIR"] = tmp_dir
    return env


def configure_tmp(agentr, work_dir):
    """
    Create ``<work_dir>/tmp`` and let the agent prepare its temporary directory.

    :return: whatever ``agentr.prepare_tmp()`` returns.
    """
    tmp_folder = os.path.join(work_dir, "tmp")
    common_fs.make_folder(tmp_folder)
    prepared = agentr.prepare_tmp()
    return prepared


def extract_resource_path(agentr, tasks_rid, executable_path):
    """
    Resolve the executable inside a synced tasks resource.

    For a multi-file (directory) resource, the binary for the current platform is taken from the
    ``<arch>_platform`` system attribute, relative to ``executable_path``.

    :param agentr: agent client used to fetch the resource metadata.
    :param tasks_rid: tasks resource ID.
    :param executable_path: local path of the synced resource.
    :return: path of the executable.
    :raises ValueError: a directory resource has no binary for the current platform.
    :raises OSError: the resolved path does not exist.
    """
    meta = agentr.resource_meta(tasks_rid)
    resolved = executable_path
    if meta.get("multifile"):
        arch = common_platform.get_arch_from_platform(common_platform.platform())
        relative = meta.get("system_attributes", {}).get(arch + "_platform")
        if relative is None:
            raise ValueError("Task directory-resource doesn't provide {} binary".format(arch))
        resolved = os.path.join(resolved, relative)
    if os.path.exists(resolved):
        return resolved
    raise OSError("Binary path {} not exists".format(resolved))


def prepare_tasks_binary(agentr, tasks_rid, work_dir):
    """
    Sync the tasks resource and return the path of its executable.

    On Windows the synced path is used directly. Elsewhere a symlink with the resource's base
    name is (re)created in ``work_dir`` and the executable is resolved through it.

    :raises Exception: ``work_dir`` already holds a non-symlink entry with that name.
    """
    synced_path = agentr.resource_sync(tasks_rid, fastbone=False)
    if common_platform.on_windows():
        target = synced_path
    else:
        link_path = os.path.join(work_dir, os.path.basename(synced_path))
        if os.path.islink(link_path):
            os.remove(link_path)
        elif os.path.exists(link_path):
            raise Exception("Executable path {} already exists.".format(link_path))
        os.symlink(synced_path, link_path)
        target = link_path
    return extract_resource_path(agentr, tasks_rid, target)


def run_executor_as_subprocess(cmd, user_name, work_dir, env):
    """
    Run ``cmd`` in ``work_dir`` with ``env`` under ``user_name``'s privileges and wait for it.

    :param cmd: command as a list of arguments.
    :param user_name: user whose privileges the child takes; logged as "root" when empty.
    :param work_dir: working directory of the child.
    :param env: environment of the child.
    :return: exit code of the child.
    """
    privileges = common_os.User.Privileges(user_name)
    process = sp.Popen(cmd, cwd=work_dir, env=env, preexec_fn=privileges.__enter__)
    logger.info(
        "Subprocess with PID #%d with %s privileges started by command %r",
        process.pid, user_name or "root", cmd,
    )
    exit_code = process.wait()
    logger.info("Subprocess with PID #%d finished with exit code %d", process.pid, exit_code)
    return exit_code


def env_for_win_binary_task(username, ci, tmp_path):
    """
    Build the environment for running a binary task as Windows user ``username``.

    Profile-related variables (LOGNAME, HOMEDRIVE/HOMEPATH, USERPROFILE, AppData dirs, TMP/TEMP)
    are derived from the user's home directory. Every other variable is inherited from the
    current process, and the result is finished by ``env_for_binary_task``.

    :param username: Windows user name.
    :param ci: creation info passed to ``win_sp.Popen`` (presumably identifies the user to run as).
    :param tmp_path: value for TMP and TEMP.
    :return: environment dict.
    """
    # Spawn cmd.exe to load user profile before craft env
    win_sp.Popen(["cmd.exe", "/C", "exit"], creationinfo=ci).wait()

    user_home_path = wsl.User.win_user_home_path(username)
    if not user_home_path:
        # Log through the module logger like the rest of this module (not the root logger).
        logger.error("Can't determine user home dir. Trying to craft one")
        user_home_path = os.path.join("C:\\Users", username)
    home_drive, home_path = os.path.splitdrive(user_home_path)
    env = {
        "LOGNAME": username,
        "LOCALAPPDATA": os.path.join(user_home_path, "AppData", "Local"),
        "APPDATA": os.path.join(user_home_path, "AppData", "Roaming"),
        "HOMEDRIVE": home_drive,
        "HOMEPATH": home_path,
        "USERPROFILE": user_home_path,
        "TMP": tmp_path,
        "TEMP": tmp_path,
    }
    # Inherited variables never override the user-specific ones above.
    for name, value in os.environ.items():
        env.setdefault(name, value)
    return env_for_binary_task(initial_env=env)
