#!/home/zomb-sandbox/venv/bin/python

"""
Sandbox task executor.
All tasks are executed within this process.
"""

import os
import sys
import shlex
import signal
import logging
import datetime as dt

import six
from six.moves.urllib import parse as urlparse

if six.PY2:
    import subprocess32 as sp
else:
    import subprocess as sp

# kernel.util cannot be imported on Windows: kernel/util/__init__.py contains imports that do not work there
if sys.platform == "win32":
    # noinspection PyUnresolvedReferences
    import kernel
    # Create an empty module object for kernel.util (``type(kernel)`` is the module type) so that the
    # interpreter does not load kernel/util/__init__.py from disk when importing its submodules
    kernel.util = type(kernel)("util")
    # Point the stub package at the real kernel/util directory so its submodules
    # (e.g. kernel.util.pickle and kernel.util.console below) can still be found on disk
    kernel.util.__path__ = ["/".join((kernel.__path__[0], kernel.util.__name__))]
    # Register the stub so that ``import kernel.util.<submodule>`` reuses it instead of importing the package
    sys.modules["kernel.util"] = kernel.util
# Setup custom picklers explicitly (SANDBOX-7356)
# noinspection PyUnresolvedReferences
import kernel.util.pickle  # noqa
# noinspection PyUnresolvedReferences
import kernel.util.console  # noqa

# Directory two levels above this file (presumably the sandbox source root — TODO confirm)
SANDBOX_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # noqa
# Prepend /skynet, the parent of SANDBOX_DIR and SANDBOX_DIR itself to the import path.
# This must run before the ``sandbox.*`` imports below; sys.path is rebound to a new list, not mutated in place.
sys.path = ["/skynet", os.path.dirname(SANDBOX_DIR), SANDBOX_DIR] + sys.path  # noqa

from sandbox.executor.commands import task as task_commands
from sandbox.executor.commands.task import utils as task_utils
from sandbox.executor.commands.task import adapters
from sandbox.executor.common import utils as exec_utils
from sandbox.executor.common import constants

from sandbox.common import os as common_os
from sandbox.common import log as common_log
from sandbox.common import auth as common_auth
from sandbox.common import rest as common_rest
from sandbox.common import proxy as common_proxy
from sandbox.common import errors as common_errors
from sandbox.common import config as common_config
from sandbox.common import format as common_format
from sandbox.common import platform as common_platform
from sandbox.common import itertools as common_itertools
from sandbox.common import threading as common_threading
from sandbox.common import statistics as common_statistics
from sandbox.common import projects_handler

from sandbox.common.types import misc as ctm
from sandbox.common.types import task as ctt
from sandbox.common.types import client as ctc
from sandbox.common.types import statistics as ctss

from sandbox.agentr import client as agentr_client
from sandbox.agentr import errors as agentr_errors

from sandbox import sdk2


# Shared configuration registry for this module.
settings = common_config.Registry()
# Executor logger; ``main()`` rebinds this global once per-task logging has been configured.
logger = logging.getLogger("executor")


def run_tcpdump_subprocess(task):
    """
    Start a background tcpdump capture for the task, if the task requests one.

    The capture is written to ``task.log_path("tcpdump")``; tcpdump's own output goes to
    ``task.log_path("tcpdump.stdout")`` and ``task.log_path("tcpdump.stderr")``.

    :param task: task object providing ``tcpdump_args`` (a shell-like argument string) and ``log_path()``
    :return: a running ``Popen`` handle, or None if the capture was not started (Windows, no
        ``tcpdump_args``, dump file already exists, tcpdump binary missing, or the process did not
        come up / already exited)
    :raises subprocess.CalledProcessError: if ``tcpdump --version`` exits with a non-zero code
    """
    # Windows has no tcpdump
    if common_platform.on_windows():
        return None

    if not task.tcpdump_args:
        return None

    dump_path = task.log_path("tcpdump")
    if os.path.exists(dump_path):
        logger.debug("Tcpdump output already exists")
        return None

    tcpdump_binary = sdk2.paths.which("tcpdump")
    if tcpdump_binary is None:
        logger.warning("Unable to start tcpdump: binary is missing")
        return None

    args = shlex.split(task.tcpdump_args) + ["-w", dump_path]
    # Capture on all interfaces unless an interface was given explicitly
    if not any(arg.startswith("-i") for arg in args):
        args.extend(["-i", "any"])

    stdout_path = task.log_path("tcpdump.stdout")
    stderr_path = task.log_path("tcpdump.stderr")
    # The child gets its own copies of these descriptors, so the parent's handles are closed
    # as soon as the process has been spawned (previously they were leaked).
    with open(stdout_path, "w") as stdout, open(stderr_path, "w") as stderr:
        sp.check_call([tcpdump_binary, "--version"], stdout=stdout, stderr=stderr)

        logger.debug("Starting tcpdump (out=%s)", dump_path)
        process = sp.Popen([tcpdump_binary] + args, stdout=stdout, stderr=stderr)

    # A missing dump file is treated as "tcpdump did not start"
    if not common_itertools.progressive_waiter(0, 1, 10, lambda: os.path.exists(dump_path))[0]:
        process.kill()
        process.wait()  # reap the killed process so it does not linger as a zombie
        logger.warning("Unable to start tcpdump: output missing")
        return None

    # poll() is None while the process is alive and its exit code (which may be 0) afterwards
    return process if process.poll() is None else None


# SANDBOX-7129: list of owners to collect extended statistics for
_EXTENDED_APICALL_OWNERS = frozenset((
    "BOOTCAMP",
    "COLLECTIONS_CI",
    "CONNECT_FRONTEND",
    "FEMIDA_STATIC",
    "FRONTEND",
    "INTRASEARCH_WWW",
    "LEGO_TRENDBOX_CI",
    "MSSNGRFRONT",
    "REPORT_RENDERER",
    "SANDBOX_CI_SEARCH_INTERFACES",
    "SEARCH_INTERFACES",
    "SEARCH_INTERFACES_BROWSERS",
    "SI_BUFFER",
    "STAFF-WWW",
    "STARTREK",
    "TESTPALM",
    "TOOLS",
    "TRENDBOX_CI_TEAM",
    "TYCOON-FRONTEND",
    "UGC_FRONTEND",
    "VELOCITY",
    "WATSON",
    "WIKI",
    "WMFRONT",
    "YANDEX-CONNECT",
))


def _setup_statistics(auth, task, settings):
    """
    Install the statistics signal handlers used while a command runs.

    A ``common_statistics.Signaler`` with a ``ClientSignalHandler`` is always created. When ``task`` is set,
    two more handlers are registered (aggregated task operations and client-side API calls), and a request
    callback is installed on ``common_rest.Client`` that pushes signals for every REST request.

    :param auth: auth object passed to the client signal handler as ``token``
    :param task: task the command runs for, or None
    :param settings: configuration registry (``common_config.Registry()``)
    """
    # collect profiling information (SANDBOX-4622, SANDBOX-5872). entry points are:
    # - sandboxsdk.svn.Svn.svn()
    # - sandboxsdk.svn.Hg._hg()
    # - common.rest.Client._request(), where _accounter callback is used
    # - common.log.ExceptionSignalSenderHandler

    common_statistics.Signaler(
        # exceptions happened during task executions will be put in a separate table
        common_statistics.ClientSignalHandler(
            token=auth,
            task_id=task and task.id
        ),
        component=ctm.Component.EXECUTOR,
        update_interval=settings.client.statistics.update_interval
    )

    if task:
        # disable statistics sending logging so that it does not confuse anyone
        service_logger = logging.getLogger("statistics_sender")
        service_logger.disabled = True

        operations_handler = common_statistics.AggregatingClientSignalHandler(
            aggregate_by=("kind", "method"),  # everything else is common anyway!
            sum_by=("duration", "count"),
            fixed_args=dict(
                # because there are going to be single calls,
                # and all intermediate aggregation takes place in the above class
                count=1,
                task_id=task.id,
                task_type=task.type,
                host=common_config.Registry().this.id,
                owner=task.owner,
            ),
            logger=service_logger
        )
        common_statistics.Signaler().register(operations_handler)

        calls_handler = common_statistics.ApiCallClientSideHandler(
            aggregate_by=("timestamp", "duration"),
            fixed_args=dict(
                task_id=task.id,
                client_id=common_config.Registry().this.id,
            ),
            replace_timestamp=False,
            logger=service_logger
        )
        common_statistics.Signaler().register(calls_handler)

        def request_callback(request):
            """
            Push statistics signals describing a single finished REST request.

            :type request: common.rest.Client.Request
            """
            signals = []

            # the request started ``duration`` milliseconds before now
            utcnow = dt.datetime.utcnow()
            utcthen = utcnow - dt.timedelta(milliseconds=request.duration)

            signals.append(dict(
                type=ctss.SignalType.TASK_OPERATION,
                date=utcthen,
                timestamp=utcthen,
                kind=ctss.OperationType.REST,
                method=request.method.lower(),
                duration=request.duration,
            ))

            if task.owner in _EXTENDED_APICALL_OWNERS:
                status_code = request.response.status_code if request.response else None
                signals.append(dict(
                    type=ctss.SignalType.API_CALL_CLIENT_SIDE,
                    date=utcthen,
                    timestamp=utcthen,
                    method=request.method,
                    response_code=status_code,
                    duration=request.duration,
                    path=request.path,
                    reqid=request.id,
                    query_string=urlparse.urlencode(sorted(request.params.get("params", {}).items()), doseq=True),
                ))

            common_statistics.Signaler().push(signals)
            if settings.common.installation == ctm.Installation.TEST:
                task.ctx["__rest_request_count"] = task.ctx.get("__rest_request_count", 0) + 1

        common_rest.Client.request_callback = staticmethod(request_callback)

    if settings.common.statistics.enabled:
        logging.debug("Statistics collection is set up")


def _terminate_tcpdump(process):
    """
    Best-effort termination of a tcpdump subprocess; failures are logged, never raised.

    :param process: ``Popen`` handle returned by ``run_tcpdump_subprocess``
    """
    logger.info("Terminating tcpdump subprocess with pid %s.", process.pid)
    try:
        process.terminate()
    except Exception:
        logger.warning("Can't terminate tcpdump subprocess with pid %s.", process.pid, exc_info=True)


def execute_command(ipc_args, executor_ctx, command, command_args, subsequent=False, task=None):
    """
    Execute a single executor command and record its outcome in ``executor_ctx``.

    :param ipc_args: arguments received from the parent process; this function reads ``oauth_token``,
        ``session_token``, ``container``, ``task_context`` and ``force_status`` from it
    :param executor_ctx: execution context, updated in place: ``status`` is set on return (None means
        executor failure), ``status_message`` and ``wait_targets`` are set when available
    :param command: command name, a key of ``task_commands.CMD2CLASS``; the process exits with code 1
        if it is unknown
    :param command_args: keyword arguments for the command class; an ``agentr`` session is added to it
        when ``iteration`` is present
    :param subsequent: True when the command runs after another one in the same process
    :param task: task object left by the previous command (used when ``subsequent`` is True)
    :return: the command's task object, or the ``task`` argument if the command object was not created
    """
    task_id = command_args.get("task_id")
    if subsequent:
        kernel.util.console.setProcTitle(
            "[sandbox] Executor: " + command + (" task #" + str(task_id) if task_id else "")
        )
        task.agentr.monitoring_finish()

    settings = common_config.Registry()
    oauth_token = ipc_args.get("oauth_token")
    session_token = ipc_args.get("session_token")

    if session_token:
        auth = common_auth.Session(None, session_token)
        logger.debug("Task session token: %s", common_format.obfuscate_token(session_token))
        logger.info(
            "Current client ID: %r, platform: %r (%s), container: %r",
            settings.this.id, settings.this.system.family, common_platform.platform(), ipc_args.get("container")
        )
    else:
        auth = common_auth.OAuth(oauth_token) if oauth_token else common_auth.NoAuth()

    if command not in task_commands.CMD2CLASS:
        logger.error("Unknown command %r", command)
        sys.exit(1)

    # Class-level defaults: every REST client / server proxy created in this process uses these
    common_rest.Client._default_component = common_proxy.ReliableServerProxy._default_component = ctm.Component.EXECUTOR
    common_rest.Client._external_auth = common_proxy.ReliableServerProxy._external_auth = auth

    logger.info("[%d] handle command %r with %d arguments", os.getpid(), command, len(ipc_args))
    projects_handler.load_project_types(reuse=True)
    executable, status = None, None
    tcpdump_process = None  # tcpdump started for a privileged container; terminated in ``finally``
    try:
        iteration = command_args.get("iteration")
        if iteration is not None:
            command_args["agentr"] = agentr_client.Session(session_token, iteration + 1, 0, logger)
        exec_class = task_commands.CMD2CLASS[command]

        executable = exec_class(ipc_args=ipc_args, task=task, **command_args)
        task = getattr(executable, "task")

        _setup_statistics(auth, task, settings)

        if command == ctc.Command.EXECUTE_PRIVILEGED_CONTAINER and isinstance(task, adapters.SDK2TaskAdapter):
            task.set_cmd(executable)
            task.ctx.update(ipc_args.get("task_context", {}))

        forced_status, message = ipc_args.get("force_status", (None, None))
        if task and forced_status:
            # log the real target status (``status`` itself is still None here)
            logger.debug("Force switch task status %s -> %s", executable.prev_status, forced_status)
            status = forced_status
            executor_ctx["status_message"] = message
        else:
            if task:
                task.agentr = command_args.get("agentr")
                task.agentr(None)  # Ensure connection associated with the task session
                if command == ctc.Command.EXECUTE:
                    # registered with agentr instead of being terminated in ``finally`` below
                    unprivileged_tcpdump = run_tcpdump_subprocess(task)
                    if unprivileged_tcpdump is not None:
                        task.agentr.register_tcpdump(unprivileged_tcpdump.pid)
                elif command == ctc.Command.EXECUTE_PRIVILEGED_CONTAINER:
                    tcpdump_process = run_tcpdump_subprocess(task)
            # NOTE(review): presumably ``chain`` flattens so that ``execute()`` may return either a bare
            # status or a (status, message) pair, the trailing None filling a missing message — confirm
            status, message = list(common_itertools.chain(executable.execute(), None))[:2]
            executor_ctx["status_message"] = message
            if status in ctt.Status.Group.WAIT and hasattr(task, "wait_targets"):
                executor_ctx["wait_targets"] = task.wait_targets
        if status not in ctt.Status:
            raise Exception("Command {!r} returned wrong status {!r}".format(command, status))
    except common_errors.TaskContextError as ex:
        logger.exception("Task context is not serializable")
        status = ctt.Status.EXCEPTION
        # executable and executable.task are always not None,
        # can't have a problem with context without active Task.
        # Python 3 exceptions have no ``message`` attribute, so fall back to the string form.
        info = getattr(ex, "message", None)
        task_utils.task_set_info(executable.task, info if info is not None else str(ex))
    except (
        common_proxy.ReliableServerProxy.SessionExpired,
        common_rest.Client.SessionExpired,
        agentr_errors.NoTaskSession,
    ):
        logger.exception("Task session expired")
        status = "SessionExpired"
    except common_errors.TaskError as ex:
        logger.exception("Task error occurred")
        # NOTE(review): task errors are reported with the same status as an expired session — confirm intended
        status = "SessionExpired"
        executor_ctx["status_message"] = str(ex)
    except common_errors.TemporaryError as ex:
        logger.exception(ex)
        status = ctt.Status.TEMPORARY
        executor_ctx["status_message"] = "Internal exception while executing command: {}".format(ex)
    except (Exception, KeyboardInterrupt, SystemExit) as ex:
        status = None  # Empty status equals to executor failure
        message = "Unhandled exception while executing command {!r}".format(command)
        executor_ctx["status_message"] = "{}: {}".format(message, ex)
        logger.exception(message)
    finally:
        if tcpdump_process is not None:
            _terminate_tcpdump(tcpdump_process)

    executor_ctx["status"] = status
    logger.debug("Execution result: %r", status)
    logger.debug("Execution context: %s", executor_ctx)
    return task


def _start_ssh_agent(key):
    """
    Start an ssh-agent and load the given private key into it.

    Failures are logged, not raised: the task keeps running without the key.

    :param key: private ssh key passed to ``sdk2.ssh.SshAgent.add``
    :return: the ``sdk2.ssh.SshAgent`` instance (even if adding the key failed), or None if the agent
        could not be started. The caller keeps the returned object referenced, in case the agent's
        lifetime is tied to it — TODO confirm.
    """
    # Don't inherit TMPDIR from parent, we may have no permissions to write there.
    # https://github.com/openssh/openssh-portable/blob/25cf9105b849932fc3b141590c009e704f2eeba6/misc.c#L1406
    os.environ.pop("TMPDIR", None)

    try:
        ssh_agent = sdk2.ssh.SshAgent()
    except sdk2.ssh.SshAgentNotAvailable as e:
        logger.error("Cannot start ssh-agent, private ssh-key won't be available: %s", e)
        return None

    try:
        ssh_agent.add(key)
    except common_errors.TaskError as ex:
        logger.error("Cannot add ssh key, private ssh-key won't be available: %s", ex)
    return ssh_agent


def main():
    """
    Executor entry point.

    Parses the IPC arguments, prepares the environment and logging, runs the requested command
    (followed by its subsequent command when allowed) and saves the execution context to the result file.

    :raises Exception: if the IPC arguments contain no ``result_filename``
    """
    global logger

    ipc_args = exec_utils.parse_ipc_args()

    tasks_dir = ipc_args["tasks_dir"]
    sys.path = [tasks_dir] + sys.path
    if common_platform.on_linux() or common_platform.on_osx():
        # expose the same import roots to Python subprocesses
        os.environ["PYTHONPATH"] = ":".join(["/skynet", os.path.dirname(SANDBOX_DIR), SANDBOX_DIR, tasks_dir])

    result_filename = ipc_args.get("result_filename")
    if not result_filename:
        raise Exception("result filename is not found in arguments")
    master_host = ipc_args.pop("master_host", None)
    # set container's node_id to node_id of the master host
    if master_host:
        settings.this.id = master_host

    command = ipc_args["command"]
    command_args = ipc_args["args"]
    task_id = command_args.get("task_id")
    iteration = command_args.get("iteration")
    executor_ctx = command_args.get("ctx", {})
    exec_type = command_args.get("exec_type")
    if not common_platform.on_windows():
        # dump the stacks of all threads on SIGUSR2; ``logger`` is looked up when the signal arrives
        signal.signal(signal.SIGUSR2, lambda *_: common_threading.dump_threads(logger))

    kernel.util.console.setProcTitle("[sandbox] Executor: " + command + (" task #" + str(task_id) if task_id else ""))
    # Remove Y_PYTHON_ENTRY_POINT for binary tasks: an inherited value breaks Arcadia-built binaries
    if exec_type == ctt.ImageType.BINARY:
        os.environ.pop("Y_PYTHON_ENTRY_POINT", None)

    has_root = common_os.User.has_root
    # ``Privileges`` gets the LOGNAME user, except for privileged container commands which get None
    username = os.environ.get("LOGNAME") if command != ctc.Command.EXECUTE_PRIVILEGED_CONTAINER else None
    with common_os.User.Privileges(username):
        agentr, logger, logdir = exec_utils.configure_logger(task_id, iteration, ipc_args.get("session_token"))
        if not has_root or ipc_args.get("privileges_set"):
            current_pid = exec_utils.get_executor_pid()
            agentr.register_executor(current_pid)

        # Start ssh-agent if ssh key is provided; the agent object stays referenced until main() returns
        key = ipc_args.get("ssh_private_key", None)
        ssh_agent = _start_ssh_agent(key) if key else None  # noqa: F841

        common_arc_token = ipc_args.get("common_arc_token", None)
        if common_arc_token:
            # presumably registers the token with the log filter so it is masked in logs
            vault_filter = common_log.VaultFilter.filter_from_logger(logger)
            if vault_filter:
                vault_filter.add_record("COMMON_ARC_TOKEN", common_arc_token)
            sdk2.environments.ArcEnvironment.set_common_arc_token(common_arc_token)
        command_args["logdir"] = os.path.split(logdir)[-1]
        task = execute_command(ipc_args, executor_ctx, command, command_args)
        subsequent = command_args.get("subsequent")
        status = executor_ctx["status"]
        if subsequent and status in ctt.Status and not (
            # if task has stop hook, it cannot be executed in the same process
            ipc_args.get("container") and os.path.exists(constants.ON_TASK_STOP_ROOT_HOOK)
        ):
            subsequent["args"]["status"] = status
            subsequent["args"]["logdir"] = command_args["logdir"]
            execute_command(
                ipc_args, executor_ctx,
                subsequent["command"], subsequent["args"],
                subsequent=True, task=task
            )
            executor_ctx["ran_subsequent"] = True

        logging.debug("Sending operations profiling signals")
        common_statistics.Signaler().wait()
        logging.debug("All statistics sent!")
        exec_utils.save_result(result_filename, executor_ctx)


if __name__ == "__main__":
    # Run the executor only when this file is executed as a script, not when it is imported.
    main()
