import os
import sys
import logging

from six.moves import cPickle as pickle

# kernel.util cannot be imported on Windows: kernel/util/__init__.py contains imports that do not work there
if sys.platform == "win32":
    # noinspection PyUnresolvedReferences
    import kernel
    # Create an empty placeholder module kernel.util so that the interpreter does not load
    # kernel/util/__init__.py from disk when sub-modules (e.g. kernel.util.console below) are imported
    kernel.util = type(kernel)("util")
    # Point the placeholder's __path__ at the real kernel/util directory so its sub-modules remain importable
    kernel.util.__path__ = ["/".join((kernel.__path__[0], kernel.util.__name__))]
    # Register the placeholder so later "import kernel.util.<name>" statements reuse it instead of the package
    sys.modules["kernel.util"] = kernel.util


from sandbox.executor.commands import service as service_commands
from sandbox.executor.common import utils as cutils
from sandbox.executor.common import constants
from sandbox.executor.preexecutor import utils

from sandbox.agentr import types as agentr_types
from sandbox.agentr import errors as agentr_errors
from sandbox.agentr import client as agentr_client

from sandbox.common import os as common_os
from sandbox.common import auth as common_auth
from sandbox.common import rest as common_rest
from sandbox.common import proxy as common_proxy
from sandbox.common import config as common_config
from sandbox.common import errors as common_errors
from sandbox.common import format as common_format
from sandbox.common import system as common_system
from sandbox.common import platform as common_platform
from sandbox.common import itertools as common_itertools
from sandbox.common import statistics as common_statistics

from sandbox.common.windows import wsl
from sandbox.common.windows import subprocess as win_sp

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc

from sandbox import sdk2

# noinspection PyUnresolvedReferences
from kernel.util import console
# Setup custom picklers explicitly (SANDBOX-7356)
# noinspection PyUnresolvedReferences
import kernel.util.pickle  # noqa

# Module-wide logger; main() rebinds it (via ``global logger``) to the logger returned by cutils.configure_logger().
logger = logging.getLogger("preexecutor")


def _exec_windows_binary(executable_path, work_dir, tmp_path, command):
    """
    Run a binary task on Windows under the ``LOGNAME`` user and terminate with its exit code.

    Grants the user access to ``work_dir`` before the run when needed, and revokes it afterwards
    for stop/terminate/release commands or after a successful execute.

    :param executable_path: path to the tasks binary
    :param work_dir: task working directory whose ACL is adjusted
    :param tmp_path: temporary directory passed to the binary's environment
    :param command: executor command (``ctc.Command``)
    """
    login = os.environ.get("LOGNAME")

    # EXECUTE/RELEASE always (re)grant; other commands only when the modify right is missing
    need_grant = command in (ctc.Command.EXECUTE, ctc.Command.RELEASE)
    if not need_grant:
        need_grant = "(M)" not in wsl.FS.acl_get_by_user(work_dir, login)
    if need_grant:
        wsl.FS.acl_grant(work_dir, login, "(OI)(CI)(RX,W,M)")

    return_code = 1
    try:
        creation_info = win_sp.CREATIONINFO(
            win_sp.CREATION_TYPE_LOGON,
            lpUsername=login,
            lpPassword=login,  # for now they are equal
        )
        child = win_sp.Popen(
            [executable_path] + sys.argv[1:],
            creationinfo=creation_info,
            universal_newlines=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
            env=utils.env_for_win_binary_task(login, creation_info, tmp_path),
        )
        return_code = child.wait()
    finally:
        executed_successfully = command == ctc.Command.EXECUTE and return_code == 0
        finishing_command = command in (ctc.Command.STOP, ctc.Command.TERMINATE, ctc.Command.RELEASE)
        if finishing_command or executed_successfully:
            wsl.FS.acl_remove(work_dir, login)
    os._exit(return_code)


def _exec_binary_tasks(agentr, tasks_rid, work_dir, command=None):
    """
    Replace the current process with the prepared tasks binary.

    On Windows the binary is run via ``_exec_windows_binary``, which ends the process with ``os._exit``;
    elsewhere the process image is replaced with ``os.execve`` (prefixed by the macOS reattacher if any).

    :param agentr: AgentR session
    :param tasks_rid: tasks resource id
    :param work_dir: task working directory
    :param command: executor command, only used on Windows
    """
    binary = utils.prepare_tasks_binary(agentr, tasks_rid, work_dir)
    logger.info("Execute binary %s", binary)
    temp_dir = utils.configure_tmp(agentr, work_dir)
    environment = utils.env_for_binary_task(tmp_dir=temp_dir)
    if common_platform.on_windows():
        _exec_windows_binary(binary, work_dir, environment["TMP"], command)
    argv = add_reattacher([binary] + sys.argv[1:])
    os.execve(argv[0], argv, environment)


def _exec_privileged_task(ipc_args, exec_type, tasks_rid, agentr, work_dir, tasks_dir):
    """
    Run a privileged task in two stages and exit the process with the resulting return code.

    The on_prepare stage runs the executor as the ``LOGNAME`` user. The on_execute stage runs it as root
    in ``container.name``, via ``portoctl exec`` for porto containers and ``lxc-attach`` otherwise.

    :param ipc_args: executor arguments; dumped to ``sys.argv[1]`` for the child processes
    :param exec_type: tasks image type (``ctt.ImageType``)
    :param tasks_rid: tasks resource id, used to prepare the binary for binary tasks
    :param agentr: AgentR session
    :param work_dir: task working directory
    :param tasks_dir: tasks code directory, added to ``PYTHONPATH`` for non-binary tasks
    :raises SystemExit: always -- with the failed first stage's code, or with the second stage's code
    """
    with open(sys.argv[1], "wb") as fh:
        fh.write(pickle.dumps(ipc_args, 2))

    container = cutils.container_info(ipc_args)
    if exec_type == ctt.ImageType.BINARY:
        env = utils.env_for_binary_task()
        cmd = [utils.prepare_tasks_binary(agentr, tasks_rid, work_dir)] + sys.argv[1:]
    else:
        env = os.environ.copy()
        env["PYTHONPATH"] = ":".join(
            ["/skynet", os.path.dirname(constants.SANDBOX_DIR), constants.SANDBOX_DIR, tasks_dir]
        )
        cmd = [ipc_args["executable"], "-u", ipc_args["executor_path"]] + sys.argv[1:]

    # on_prepare stage
    return_code = utils.run_executor_as_subprocess(
        cmd=cmd,
        user_name=os.environ.get("LOGNAME"),
        work_dir=work_dir,
        env=env,
    )
    if return_code:
        sys.exit(return_code)

    # In case of successful first stage there will be updated executor_args file by the same path
    ipc_args = cutils.parse_ipc_args(False)

    env_for_container = ipc_args["env"]
    tmp_dir = utils.configure_tmp(agentr, work_dir)
    env_for_container.update({
        "HOME": "/root",
        "USER": "root",
        "LOGNAME": "root",
        "UNPRIVILEGED_USER": env["USER"],
        "TMP": tmp_dir,
        "TEMP": tmp_dir,
        "TMPDIR": tmp_dir,
    })
    # Environment values must be strings: forward the config path only when it is actually set,
    # instead of injecting None (which would become "KEY=None" in porto's env= property).
    config_path = os.environ.get(common_config.Registry.CONFIG_ENV_VAR)
    if config_path is not None:
        env_for_container[common_config.Registry.CONFIG_ENV_VAR] = config_path

    if exec_type != ctt.ImageType.BINARY and container.container_type != "porto":
        ipc_args["executable"] = container.executable
        with open(sys.argv[1], "wb") as fh:
            fh.write(pickle.dumps(ipc_args, 2))

    # on_execute stage
    if container.container_type == "porto":
        # portoctl receives the whole command line as a single "command=" property.
        # NOTE(review): arguments are joined without quoting -- confirm none of them can contain spaces.
        cmd = [
            "/usr/sbin/portoctl", "exec", container.name,
            "isolate=false",
            "enable_porto=isolate",
            "command={}".format(" ".join(sys.argv)),
            "resolv_conf={}".format(container.resolvconf),
            "ulimit=core: unlimited unlimited",
            "cwd={}".format(work_dir),
        ] + (container.properties or [])
        # Environment goes as "env=KEY=VALUE" properties, inserted right before the last list element
        cmd[-1:-1] = ["env={}={}".format(k, v) for k, v in env_for_container.items()]
    else:
        cmd = ["/usr/bin/lxc-attach", "-n", container.name, "--"] + sys.argv

    return_code = utils.run_executor_as_subprocess(
        cmd=cmd,
        user_name=None,  # root
        work_dir=work_dir,
        env=env_for_container,
    )
    sys.exit(return_code)


def execute_command(ipc_args, executor_ctx, command, command_args, subsequent=False, task=None):
    """
    Run a service command in the current process and record its outcome in ``executor_ctx``.

    Chooses authentication from the tokens in ``ipc_args``/``command_args``. Installs it as the class-level
    default of the REST and XML-RPC clients. Instantiates ``service_commands.CMD2CLASS[command]`` and runs
    its ``execute()``.

    :param ipc_args: executor arguments; ``oauth_token``, ``session_token`` and ``container`` are read here
    :param executor_ctx: mutable dict; the ``status`` and ``status_message`` keys are written
    :param command: command name, a key of ``service_commands.CMD2CLASS``
    :param command_args: keyword arguments for the command class. ``service_auth``, ``iteration`` and
        ``kill_timeout`` are consulted; ``agentr`` is injected when ``iteration`` is set
    :param subsequent: not used by this function (kept for interface compatibility)
    :param task: returned unchanged
    :return: the ``task`` argument
    """
    settings = common_config.Registry()
    oauth_token = ipc_args.get("oauth_token")
    session_token = ipc_args.get("session_token")
    service_auth = command_args.get("service_auth")
    is_service_command = session_token == ctc.ServiceTokens.SERVICE_TOKEN

    if session_token:
        # Service commands may bring a dedicated token; otherwise authenticate with the task session
        if is_service_command and service_auth:
            logger.debug("Using service auth token: %s", common_format.obfuscate_token(service_auth))
            auth = common_auth.Session(None, service_auth)
        else:
            auth = common_auth.Session(None, session_token)
        logger.debug("Task session token: %s", common_format.obfuscate_token(session_token))
        logger.info(
            "Current client ID: %r, platform: %r (%s), container: %r",
            settings.this.id, settings.this.system.family, common_platform.platform(), ipc_args.get("container")
        )
    else:
        auth = common_auth.OAuth(oauth_token) if oauth_token else common_auth.NoAuth()

    # Class-level defaults shared by both client implementations
    common_rest.Client._default_component = common_proxy.ReliableServerProxy._default_component = ctm.Component.EXECUTOR
    common_rest.Client._external_auth = common_proxy.ReliableServerProxy._external_auth = auth

    logger.info("[%d] handle command %r with %d arguments", os.getpid(), command, len(ipc_args))

    try:
        iteration = command_args.get("iteration")
        if iteration is not None:
            command_args["agentr"] = agentr_client.Session(session_token, iteration + 1, 0, logger)
        exec_class = service_commands.CMD2CLASS[command]
        kill_timeout = command_args.get("kill_timeout")
        if kill_timeout and kill_timeout < 600:
            # Short kill timeout: lower the default request timeout of both clients
            common_rest.Client.DEFAULT_TIMEOUT = common_proxy.ReliableServerProxy.DEFAULT_TIMEOUT = 10

        executable = exec_class(ipc_args=ipc_args, **command_args)

        # collect profiling information (SANDBOX-4622, SANDBOX-5872). entry points are:
        # - sandboxsdk.svn.Svn.svn()
        # - sandboxsdk.svn.Hg._hg()
        # - common.rest.Client._request(), where _accounter callback is used
        # - common.log.ExceptionSignalSenderHandler

        common_statistics.Signaler(
            # exceptions happened during task executions will be put in a separate table
            common_statistics.ClientSignalHandler(
                token=auth,
                task_id=None
            ),
            component=ctm.Component.EXECUTOR,
            update_interval=settings.client.statistics.update_interval
        )

        if settings.common.statistics.enabled:
            logger.debug("Statistics collection is set up")

        # Pad the result with None so a missing message unpacks as None
        # (assumes common_itertools.chain flattens iterables and accepts scalars -- TODO confirm)
        status, message = list(common_itertools.chain(executable.execute(), None))[:2]
        executor_ctx["status_message"] = message
        if status not in ctt.Status:
            raise Exception("Command {!r} returned wrong status {!r}".format(command, status))
    except (
        common_proxy.ReliableServerProxy.SessionExpired,
        common_rest.Client.SessionExpired,
        agentr_errors.NoTaskSession,
    ):
        logger.exception("Task session expired")
        status = "SessionExpired"
    except common_errors.TaskError as ex:
        logger.exception("Task error occurred")
        # NOTE(review): task errors are reported with the "SessionExpired" status -- confirm this is intended
        status = "SessionExpired"
        executor_ctx["status_message"] = str(ex)
    except (Exception, KeyboardInterrupt, SystemExit) as ex:
        status = None  # Empty status equals to executor failure
        message = "Unhandled exception while executing command {!r}".format(command)
        executor_ctx["status_message"] = "{}: {}".format(message, ex)
        logger.exception(message)

    executor_ctx["status"] = status
    logger.debug("Execution result: %r", status)
    logger.debug("Execution context: %s", executor_ctx)
    return task


def add_reattacher(args):
    """
    Prefix a command line with ``reattach-to-user-namespace`` on macOS.

    On OS X we should try to reattach to user namespace
    (https://github.com/ChrisJohnsen/tmux-MacOSX-pasteboard) so that features like pbcopy/pbpaste work.
    The binary is looked up in the service directory from the config. If it is missing or the lookup fails,
    a warning is logged and ``args`` is returned unchanged. On other platforms ``args`` is returned as is.

    :param args: command line as a list, program first
    :return: the (possibly prefixed) command line
    """
    if not sys.platform.startswith("darwin"):
        return args

    missing_message = (
        "Failed to find 'reattach-to-user-namespace' binary, "
        "some features (pbcopy/pbpaste) might be unavailable."
    )
    try:
        reattacher = os.path.join(
            common_config.Registry().common.dirs.service,
            "reattach-to-user-namespace"
        )
        if os.path.exists(reattacher):
            args = [reattacher] + args
        else:
            logger.warning(missing_message)
    except Exception:
        # Best effort only: a broken config must not prevent the executor from starting,
        # but interpreter-exit exceptions (KeyboardInterrupt/SystemExit) are no longer swallowed
        logger.exception(missing_message)
    return args


def main():
    """
    Executor entry point.

    Parses the executor arguments and configures logging. Performs the root-only preparation steps: root
    hooks, atop, std streams redirection, resolv.conf, container consistency, cgroups and tasks image mounting.
    Then it does one of the following:
    - re-executes the process with dropped privileges;
    - runs a privileged task;
    - runs a binary task;
    - runs the requested service command in-process, saves the result to ``result_filename`` and
      terminates via ``os._exit(0)``.
    """
    global logger

    ipc_args = cutils.parse_ipc_args()

    tasks_dir = ipc_args["tasks_dir"]
    sys.path = [tasks_dir] + sys.path

    result_filename = ipc_args.get("result_filename")
    if not result_filename:
        raise Exception("result filename is not found in arguments")
    master_host = ipc_args.pop("master_host", None)
    # set container's node_id to node_id of the master host
    if master_host:
        common_config.Registry().this.id = master_host

    command = ipc_args["command"]
    command_args = ipc_args["args"]
    task_id = command_args.get("task_id")
    iteration = command_args.get("iteration")
    executor_ctx = command_args.get("ctx", {})
    exec_type = command_args.get("exec_type")
    tasks_rid = command_args.get("tasks_rid")

    console.setProcTitle("[sandbox] Executor: " + command + (" task #" + str(task_id) if task_id else ""))
    # We should remove Y_PYTHON_ENTRYPOINT because it broke arcadia builded binaries
    if exec_type == ctt.ImageType.BINARY:
        os.environ.pop("Y_PYTHON_ENTRY_POINT", None)

    has_root = common_os.User.has_root
    username = os.environ.get("LOGNAME") if command != ctc.Command.EXECUTE_PRIVILEGED_CONTAINER else None
    with common_os.User.Privileges(username):
        agentr, logger, logdir = cutils.configure_logger(task_id, iteration, ipc_args.get("session_token"))
        task_dir = os.path.dirname(logdir)
        if not has_root or ipc_args.get("privileges_set"):
            current_pid = cutils.get_executor_pid()
            agentr.register_executor(current_pid)

        if exec_type != ctt.ImageType.BINARY:
            status_message = None
            if common_platform.on_windows():
                status_message = "Only binary tasks supported on Windows"
            if ctc.Tag.M1 in common_config.Registry().client.tags:
                status_message = "Only binary tasks supported on M1"
            if status_message:
                executor_ctx["status"] = ctt.Status.EXCEPTION
                executor_ctx["status_message"] = status_message
                # NOTE(review): execution continues below unless cutils.save_result() exits -- confirm
                cutils.save_result(result_filename, executor_ctx, logger)

        if (
            command == ctc.Command.EXECUTE and
            exec_type != ctt.ImageType.BINARY
        ):
            os.environ["TMP"] = os.environ["TEMP"] = os.environ["TMPDIR"] = utils.configure_tmp(agentr, task_dir)

        if (
            sys.platform.startswith("linux") and
            not ipc_args.get("privileges_set") and
            common_config.Registry().common.installation != ctm.Installation.TEST
        ):
            hook_error = None
            if command in (ctc.Command.STOP, ctc.Command.TERMINATE):
                if has_root and ipc_args.get("container") and os.path.exists(constants.ON_TASK_STOP_ROOT_HOOK):
                    hook_error = utils.execute_root_hook(False, [], logger)
            elif command == ctc.Command.EXECUTE_PRIVILEGED:
                atop_pid = utils.run_atop_subprocess(
                    logdir,
                    ipc_args.get("cgroup"),
                    ipc_args.get("container"),
                    logger=logger
                )
                if atop_pid:
                    agentr.register_atop(atop_pid)
                if has_root and ipc_args.get("container") and os.path.exists(constants.ON_TASK_START_ROOT_HOOK):
                    hook_error = utils.execute_root_hook(True, [task_dir], logger)
            elif command == ctc.Command.EXECUTE:
                atop_pid = utils.run_atop_subprocess(logdir, ipc_args.get("cgroup"))
                if atop_pid:
                    agentr.register_atop(atop_pid)
                if has_root and ipc_args.get("container") and os.path.exists(constants.ON_TASK_START_ROOT_HOOK):
                    hook_error = utils.execute_root_hook(True, [task_dir], logger)
            if hook_error:
                ipc_args["force_status"] = (ctt.Status.TEMPORARY, hook_error)

        if has_root and not ipc_args.get("privileges_set") and task_id and iteration is not None:
            try:
                utils.redirect_std_file(sys.stdout, os.path.join(logdir, agentr_types.STDOUT_FILENAME))
                utils.redirect_std_file(sys.stderr, os.path.join(logdir, agentr_types.STDERR_FILENAME))
            except OSError:
                logger.exception("Failed to redirect std files to files in log directory.")

        if has_root and not ipc_args.get("privileges_set"):
            if (
                not utils.update_resolv_conf(ipc_args.get("resolv.conf"), ipc_args.get("dns"), task_id, logger) and
                # avoid fails when dns=local and command is executed in host system
                command not in (
                    ctc.Command.STOP, ctc.Command.DELETE, ctc.Command.RELEASE,
                    ctc.Command.TERMINATE, ctc.Command.EXECUTE_PRIVILEGED
                )
            ):
                ipc_args["force_status"] = (
                    ctt.Status.EXCEPTION, "Failed to update resolv.conf. See logs for details."
                )

            if (
                command in (ctc.Command.EXECUTE, ctc.Command.EXECUTE_PRIVILEGED_CONTAINER) and
                common_config.Registry().common.installation in ctm.Installation.Group.NONLOCAL
            ):
                if not utils.check_container_consistency():
                    executor_ctx["status"] = "ContainerError"
                    executor_ctx["status_message"] = "Error in container, restart task, destroy container"
                    # NOTE(review): execution continues below unless cutils.save_result() exits -- confirm
                    cutils.save_result(result_filename, executor_ctx, logger)

        if has_root and command != ctc.Command.EXECUTE_PRIVILEGED and not ipc_args.get("privileges_set"):
            utils.switch_cgroup(ipc_args, logger=logger)

            ipc_args["privileges_set"] = True
            with open(sys.argv[1], "wb") as fh:
                fh.write(pickle.dumps(ipc_args, 2))

            # Drop privileges completely on OSX. Darwin doesn't have saved uid,
            # therefore real uid remains unchanged to be able to restore privileges.
            # We won't be able to elevate back to root, so executor should exec
            # before exiting `common.os.User.Privileges` context manager.
            if sys.platform.startswith("darwin"):
                common_os.User.Privileges(os.environ["LOGNAME"], store=False).__enter__()

            # restarting with appropriate privileges and cgroup

            if exec_type == ctt.ImageType.BINARY:
                _exec_binary_tasks(agentr, tasks_rid, task_dir, command)
            elif not common_platform.on_windows() and ctc.Tag.M1 not in common_config.Registry().client.tags:

                if tasks_rid and exec_type and exec_type not in ctt.ImageType.Group.EXTRACTABLE:
                    # We should mount tasks image before the task start
                    tasks_resource_path = agentr.resource_sync(tasks_rid, fastbone=False)

                    if ctc.Tag.PORTOD in common_config.Registry().client.tags and command in (
                        ctc.Command.EXECUTE_PRIVILEGED_CONTAINER, ctc.Command.EXECUTE,
                        ctc.Command.STOP, ctc.Command.TERMINATE,
                    ):
                        try:
                            agentr.umount(tasks_dir)
                        except ValueError as e:
                            # BaseException.message is deprecated; formatting the exception gives the same text
                            logger.debug("%s for task #%s", e, task_id)
                        agentr.mount_image(tasks_resource_path, tasks_dir)
                    else:
                        utils.mount_tasks_image(
                            tasks_resource_path,
                            tasks_dir,
                            # "container" may be present but None (host execution)
                            "Task: {}, container: {}".format(
                                task_id, (ipc_args.get("container") or {}).get("name")
                            ),
                            logger=logger
                        )

                # Exec into itself by default
                args = (
                    sys.argv
                    if command in service_commands.CMD2CLASS else
                    [ipc_args["executable"], "-u", ipc_args["executor_path"]] + sys.argv[1:]
                )

                args = add_reattacher(args)

                os.environ["PYTHONPATH"] = ":".join(
                    ["/skynet", os.path.dirname(constants.SANDBOX_DIR), constants.SANDBOX_DIR, tasks_dir]
                )
                os.execv(args[0], args)  # TODO tyt ne tak bydet

        if has_root and command == ctc.Command.EXECUTE_PRIVILEGED:
            _exec_privileged_task(ipc_args, exec_type, tasks_rid, agentr, task_dir, tasks_dir)

        if (
            not common_system.inside_the_binary() and exec_type == ctt.ImageType.BINARY and
            not has_root and common_config.Registry().common.installation in ctm.Installation.Group.LOCAL
        ):
            # Execution of binary in case of absent root privileges in local installation.
            with open(sys.argv[1], "wb") as fh:
                fh.write(pickle.dumps(ipc_args, 2))
            _exec_binary_tasks(agentr, tasks_rid, task_dir)

        # Start ssh-agent if ssh key is provided
        key = ipc_args.get("ssh_private_key", None)
        if key:
            # Don't inherit TMPDIR from parent, we may have no permissions to write there.
            # https://github.com/openssh/openssh-portable/blob/25cf9105b849932fc3b141590c009e704f2eeba6/misc.c#L1406
            os.environ.pop("TMPDIR", None)

            try:
                ssh_agent = sdk2.ssh.SshAgent()
            except sdk2.ssh.SshAgentNotAvailable as e:
                logger.error("Cannot start ssh-agent, private ssh-key won't be available: %s", e)
            else:
                try:
                    ssh_agent.add(key)
                except common_errors.TaskError as ex:
                    logger.error("Cannot add ssh key, private ssh-key won't be available: %s", ex)

        command_args["logdir"] = os.path.split(logdir)[-1]
        if command in service_commands.CMD2CLASS:
            task = execute_command(ipc_args, executor_ctx, command, command_args)
        else:
            ipc_args["privileges_set"] = True
            with open(sys.argv[1], "wb") as fh:
                fh.write(pickle.dumps(ipc_args, 2))
            if exec_type == ctt.ImageType.BINARY:
                _exec_binary_tasks(agentr, tasks_rid, task_dir)
            else:
                args = [ipc_args["executable"], "-u", ipc_args["executor_path"]] + sys.argv[1:]
                # TMPDIR (not "TEMPDIR") is the variable tools and Python's tempfile actually read
                os.environ["TMP"] = os.environ["TEMP"] = os.environ["TMPDIR"] = utils.configure_tmp(agentr, task_dir)
                os.environ["PYTHONPATH"] = ":".join(
                    ["/skynet", os.path.dirname(constants.SANDBOX_DIR), constants.SANDBOX_DIR, tasks_dir]
                )
                os.execv(args[0], args)

        subsequent = command_args.get("subsequent")
        status = executor_ctx["status"]
        if subsequent and status in ctt.Status and not (
            # if task has stop hook, it cannot be executed in the same process
            ipc_args.get("container") and os.path.exists(constants.ON_TASK_STOP_ROOT_HOOK)
        ):
            subsequent["args"]["status"] = status
            subsequent["args"]["logdir"] = command_args["logdir"]
            execute_command(
                ipc_args, executor_ctx,
                subsequent["command"], subsequent["args"],
                subsequent=True, task=task
            )
            executor_ctx["ran_subsequent"] = True

        logger.debug("Sending operations profiling signals")
        common_statistics.Signaler().wait()
        logger.debug("All statistics sent!")
        cutils.save_result(result_filename, executor_ctx)
        os._exit(0)


if __name__ == "__main__":
    main()
