import os
import abc
import sys
import json
import time
import base64
import signal
import shutil
import socket
import logging
import tempfile
import textwrap
import cProfile as profile
import datetime as dt
import itertools as it
import threading as th
import traceback
import distutils.spawn as distutils_spawn

from sandbox.agentr import utils as agentr_utils
from sandbox.agentr import errors as agentr_errors

from sandbox.executor.common import utils as exec_utils
from sandbox.executor.common import constants as exec_constants

from sandbox.executor.commands import base as base_commands
from sandbox.executor.commands.task import utils as task_utils
from sandbox.executor.commands.task import adapters
from sandbox.executor.commands.task import arc_traceback

from sandbox.common import fs as common_fs
from sandbox.common import os as common_os
from sandbox.common import log as common_log
from sandbox.common import rest as common_rest
from sandbox.common import proxy as common_proxy
from sandbox.common import errors as common_errors
from sandbox.common import config as common_config
from sandbox.common import system as common_system
from sandbox.common import encoding as common_encoding
from sandbox.common import patterns as common_patterns
from sandbox.common import platform as common_platform
from sandbox.common import itertools as common_itertools
from sandbox.common import statistics as common_statistics

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctst

from sandbox.common.windows import wsl

from sandbox import sdk2

import six
from six.moves import cPickle as pickle

if six.PY2:
    import subprocess32 as sp
else:
    import subprocess as sp

# Setup custom picklers explicitly (SANDBOX-7356)
# noinspection PyUnresolvedReferences
import kernel.util.pickle  # noqa


# Module-level logger used by the executor command classes below.
logger = logging.getLogger("executor")


@six.add_metaclass(abc.ABCMeta)
class Executable(base_commands.ExecutableBase):
    """
    Base executor command bound to a single task object.

    On construction the task is either taken as given or loaded via `safe_load`. The command switches
    task statuses through the REST API (`set_status`). Subclasses wrap calls into user code with
    `UserExceptionHandler` and must implement `execute`.
    """

    class UserExceptionHandler(object):
        """
        Context manager that runs user (task) code and maps its outcome to a target task status.

        The resulting status is stored into `executable.target_status`:
        - `ok_status` if the body completed without an exception;
        - WAIT_TASK / WAIT_TIME / WAIT_OUT for the corresponding wait exceptions;
        - TEMPORARY / NO_RES / STOPPED / FAILURE / EXCEPTION for errors.
        Exceptions derived from `Exception`, and `KeyboardInterrupt`, are swallowed. Any other
        `BaseException` is re-raised.

        After exit:
          `repeat`  -- True when wait conditions turned out to be already met (`NothingToWait`),
                       i.e. the caller should run the user code again;
          `updated` -- True when the resulting target status differs from the task's current status
                       (not computed when `repeat` is set).
        """

        def __init__(self, executable, ok_status, no_failure=False):
            """
            :param executable: command whose `task` and `target_status` are handled
            :param ok_status: status to set when the wrapped code completes without an exception
            :param no_failure: report `TaskFailure` as EXCEPTION instead of FAILURE
            """
            self.executable = executable
            self.repeat = False
            self.updated = False
            self.ok_status = ok_status
            self.no_failure = no_failure

        @property
        def task(self):
            return self.executable.task

        def _save_traceback(self, ei):
            """Store the formatted traceback (passed through VaultFilter) into `ctx["__last_error_trace"]`."""
            data = logging.Formatter().formatException(ei)
            result = common_encoding.force_unicode_safe(data)
            result = common_log.VaultFilter.filter_message_from_logger(logging.getLogger(), result)
            self.task.ctx["__last_error_trace"] = result

        def __enter__(self):
            return self

        def __repeat_after_wait(self, exc_val):
            """
            Call the wait exception object with the task.

            If that raises `NothingToWait`, the wait is already satisfied: mark the handler for
            repeat and return True. Otherwise return False.
            """
            try:
                exc_val(self.task.task)
            except common_errors.NothingToWait:
                self.repeat = True
                return True
            return False

        def __exit__(self, exc_type, exc_val, exc_tb):
            exc_info = (exc_type, exc_val, exc_tb)

            if exc_type is None:
                self.executable.target_status = self.ok_status

            elif issubclass(exc_type, common_errors.NothingToWait):
                self.repeat = True
                return True

            elif issubclass(exc_type, (common_errors.WaitTask, sdk2.WaitTask)):
                self.executable.target_status = ctt.Status.WAIT_TASK
                self.task.wait_targets = {}
                if isinstance(exc_val, sdk2.WaitTask):
                    self.task.wait_targets["tasks"] = exc_val.tasks
                    if self.__repeat_after_wait(exc_val):
                        return True

            elif issubclass(exc_type, (common_errors.WaitTime, sdk2.WaitTime)):
                self.executable.target_status = ctt.Status.WAIT_TIME
                self.task.wait_targets = {}
                if isinstance(exc_val, sdk2.WaitTime):
                    self.task.wait_targets["time"] = exc_val.timeout
                    if self.__repeat_after_wait(exc_val):
                        return True

            elif issubclass(exc_type, sdk2.WaitOutput):
                self.executable.target_status = ctt.Status.WAIT_OUT
                self.task.wait_targets = {}
                if isinstance(exc_val, sdk2.WaitOutput):
                    self.task.wait_targets["output_parameters"] = {str(k): v for k, v in six.iteritems(exc_val.targets)}
                if self.__repeat_after_wait(exc_val):
                    return True

            else:
                try:
                    # try to replace files with arcanum links
                    message = arc_traceback.process_tb(exc_type, exc_val, exc_tb)
                except Exception:
                    # fallback to default message
                    logger.exception("Error while processing tb")
                    message = common_encoding.escape_html(
                        common_encoding.force_unicode(
                            "".join(traceback.format_exception(exc_type, exc_val, exc_tb)),
                            errors="replace"
                        ),
                        quote=False
                    )

                # From here on `message` is already HTML-safe, hence `do_escape=False` below
                message = common_encoding.force_unicode(message, encoding="utf-8", errors="replace")
                message = common_log.VaultFilter.filter_message_from_logger(logger, message)

                if issubclass(exc_type, Exception):
                    logger.exception("", exc_info=exc_info)

                    if issubclass(exc_type, common_errors.TemporaryError):
                        self.task.set_info("Temporary error: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.TEMPORARY
                        self._save_traceback(exc_info)
                    elif issubclass(exc_type, agentr_errors.ResourceNotAvailable):
                        self.task.set_info("Unable to continue: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.NO_RES
                    elif issubclass(exc_type, agentr_errors.InvalidData):
                        self.task.set_info("Invalid resource data: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.EXCEPTION
                        self._save_traceback(exc_info)
                    elif issubclass(exc_type, agentr_errors.CopierError):
                        self.task.set_info(message, do_escape=False)
                        self.executable.target_status = ctt.Status.EXCEPTION
                    elif issubclass(exc_type, common_errors.TaskStop):
                        self.task.set_info("Stop request: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.STOPPED
                    elif issubclass(exc_type, common_errors.TaskFailure):
                        self.task.set_info("Task failed: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.EXCEPTION if self.no_failure else ctt.Status.FAILURE
                    elif issubclass(exc_type, common_errors.TaskError):
                        self.task.set_info("Task error: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.EXCEPTION
                    else:
                        self.task.set_info("Unhandled exception: {}".format(message), do_escape=False)
                        self.executable.target_status = ctt.Status.EXCEPTION
                        self._save_traceback(exc_info)

                    # Exceptions may carry extra data for the task context / task info
                    if hasattr(exc_val, "get_task_context_value") and hasattr(exc_val, "task_context_field"):
                        self.task.ctx[str(exc_val.task_context_field)] = exc_val.get_task_context_value()
                    if hasattr(exc_val, "get_task_info"):
                        self.task.set_info(exc_val.get_task_info(), do_escape=False)

                elif issubclass(exc_type, KeyboardInterrupt):
                    # Treat KeyboardInterrupt as user-exception. Some tasks rely on `thread.interrupt_main`
                    # to cause an error in main thread and expect to go to EXCEPTION.
                    # Clear example -- `yt.wrapper` library, it causes KeyboardInterrupt on failed transaction ping.
                    # `message` is already HTML (see above), so don't escape it a second time.
                    self.task.set_info("Unhandled exception: {}".format(message), do_escape=False)
                    self.executable.target_status = ctt.Status.EXCEPTION
                    self._save_traceback(exc_info)

                else:
                    logger.exception("User-code raised BaseException!", exc_info=exc_info)
                    return None  # don't handle it, re-raise

            # Suspend the task if the user asked for it on reaching this target status
            parameters_storage = self.task.Parameters if isinstance(self.task, adapters.SDK2TaskAdapter) else self.task
            if (
                self.executable.target_status in getattr(parameters_storage, "suspend_on_status", []) and
                ctt.Status.can_switch(self.task.status, ctt.Status.SUSPENDING)
            ):
                self.task.suspend()

            self.updated = self.task.status != self.executable.target_status
            return True

    def __init__(self, task_id=None, logdir=None, agentr=None, task=None):
        """
        :param task_id: id of the task to load when `task` is not given
        :param logdir: task logs directory
        :param agentr: AgentR session object
        :param task: already constructed task object (skips loading)
        """
        if task:
            if isinstance(task, adapters.SDK2TaskAdapter):
                task.set_cmd(self)
            self.task = task
        super(Executable, self).__init__(task_id=task_id, logdir=logdir, agentr=agentr, task=task)
        if not hasattr(self, "task"):
            self.task = self.safe_load(task_id, logdir, agentr) if task_id else None
        if self.task:
            self.task.ctx.setdefault("__GSID", "SB:{}:{}".format(self.task.type, self.task.id))

    def safe_load(self, task_id, logdir, agentr):
        """
        Safe creation of task object (it will also register new TASK_LOGS resource for it).

        The procedure is different for different SDK versions.
        For sdk1, a task object is loaded via XMLRPC call (pickling happens here under the hood);
        for sdk2, the task is loaded from AgentR's local database.

        :raises: any loading error, after logging it
        """

        try:
            # FIXME: fix tests to do not catch exception common.rest.Client.HTTPError
            try:
                current_task = agentr.meta
                sdk1_task = current_task["sdk_version"] == 1
                parent_id = (current_task.get("parent") or {}).get("id")

                # Fill current task object with the REST API response we have already for SDK2 tasks
                # except if it's timeout Terminate stage, use the standard REST client to retrieve current data
                if getattr(self, "status", None) == ctt.Status.TIMEOUT:
                    client = common_rest.Client
                else:
                    # Stub client: any attribute access returns itself and any item lookup returns
                    # the already fetched task metadata, so no real REST request is made.
                    client = type("FakeRestClient", (object,), {
                        "__getattr__": lambda this, _: this,
                        "__getitem__": lambda *_: current_task,
                        "__init__": lambda *_, **__: None
                    })

                with common_rest.DispatchedClient as dispatch:
                    dispatch(client)
                    current_task = sdk2.Task.current

                if current_task is not None:
                    current_task.agentr = agentr
                if sdk1_task:
                    raise common_errors.UnknownTaskType
                task = adapters.SDK2TaskAdapter(current_task, parent_id, self)

            except common_errors.UnknownTaskType:
                if common_system.inside_the_binary():
                    raise
                from sandbox.yasandbox import manager
                task = manager.task_manager.load(task_id)
                if task.type == "TASK":
                    raise TypeError("Unknown task type")

            task.agentr = agentr
            logres_data = agentr.log_resource
            if logres_data:
                if isinstance(task, adapters.SDK2TaskAdapter):
                    task.log_resource = sdk2.Resource.restore(logres_data)
                else:
                    import sandbox.yasandbox.proxy.resource as resource_proxy
                    task._log_resource = resource_proxy.Resource._restore_from_json(logres_data)
            else:
                # TODO: remove this fallback code (SANDBOX-4842)
                logger.warning("agentr.log_resource is empty, run fallback code to create resource")
                if isinstance(task, adapters.SDK2TaskAdapter):
                    task.log_resource = sdk2.service_resources.TaskLogs(current_task, "Task logs", agentr.logdir)
                    try:
                        self.rest_client.resource[task.log_resource.id].source()
                    except Exception as ex:
                        logger.error("Error when adding current host to the resource %s: %s", task.log_resource, ex)
                else:
                    import sandbox.sandboxsdk.task
                    task._log_resource = sandbox.sandboxsdk.task.SandboxTask._create_resource(
                        task, "Task logs", logdir, sdk2.service_resources.TaskLogs
                    )
                    task._log_resource.add_host()

            return task
        except Exception:
            logger.exception("Unable to load task #%s", task_id)
            raise

    def set_status(self, status, message=None, force=False):
        """
        Switch the task to `status` via the REST audit call and mirror the new status locally.

        For WAIT_* statuses, pending `task.wait_targets` are sent along and then reset to None.
        The previous status is kept in `self.prev_status`.

        :return: `status`
        """
        self.prev_status = self.task.status
        logger.debug("Switch task status %s -> %s", self.prev_status, status)
        wait_targets = None
        if status in ctt.Status.Group.WAIT and hasattr(self.task, "wait_targets"):
            wait_targets, self.task.wait_targets = self.task.wait_targets, None
        self.rest_client.task.current.audit(
            {
                "status": status,
                "message": message,
                "force": force,
                "expected_status": self.task.status,
                "wait_targets": wait_targets,
            }
        )
        self.task.status = status
        # SDK2 tasks are cached in AgentR, so we should update its state
        if isinstance(self.task, adapters.SDK2TaskAdapter):
            task_utils.update_agentr_meta(self.task)

        return status

    @abc.abstractmethod
    def execute(self):
        """Run the command; implemented by subclasses."""
        pass


class ExecuteTask(Executable):
    """
    Executor command that runs a task in the current process.

    `execute` calls `_prepare`, which switches the task to PREPARING and runs `on_prepare()`. It then
    switches the task to EXECUTING and calls `_execute`, which runs `on_execute()` and `postprocess()`
    and performs post-execution housekeeping (ramdrive usage dump, docker backup/cleanup).
    """

    class TimeoutChecker(th.Thread):
        """
        Background thread that calls the task's `on_before_timeout(seconds)` hook at checkpoints.

        Checkpoints come from `task.timeout_checkpoints()` and mean "seconds left until
        `expiration_time`". Before firing a checkpoint, the thread re-reads the expiration time from
        AgentR state and reschedules if it was increased. The thread is only started when an
        expiration time, the hook and at least one checkpoint are all present.
        """

        # NOTE(review): class attribute, i.e. shared by all TimeoutChecker instances; once set by
        # `__exit__` it stays set. Fine for a single checker per process -- confirm before reusing.
        stopping = th.Event()

        def __init__(self, expiration_time, task):
            self.fn = getattr(task, "on_before_timeout", None)
            self.agentr = task.agentr
            self.checkpoints_fn = getattr(task, "timeout_checkpoints", None)
            checkpoints = self.checkpoints_fn and self.checkpoints_fn()
            self.active = bool(expiration_time is not None and self.fn is not None and checkpoints)
            logger.debug(
                "[TC] initialized, active=%s, exp time=%s",
                self.active,
                # `expiration_time` may be None here; `fromtimestamp(None)` would raise TypeError
                dt.datetime.fromtimestamp(expiration_time) if expiration_time is not None else None
            )
            if self.active:
                # Ascending by seconds-left, i.e. descending by wall-clock time: `pop()` yields the earliest
                self.checkpoints = [(ch, expiration_time - ch) for ch in sorted(checkpoints)]
                super(ExecuteTask.TimeoutChecker, self).__init__(name="Timeout Checker")

        def update_expiration_time(self, expiration_time):
            """Rebuild the checkpoint list for a new expiration time, dropping checkpoints already in the past."""
            logger.debug("[TC] new expiration time: %s", dt.datetime.fromtimestamp(expiration_time))
            now = time.time()
            self.checkpoints = []
            for ch in sorted(self.checkpoints_fn()):
                checkpoint_time = expiration_time - ch
                if checkpoint_time > now:
                    self.checkpoints.append((ch, checkpoint_time))

        def run(self):
            fn_called = False
            while not self.stopping.is_set():
                try:
                    secs, timestamp = self.checkpoints.pop()
                except IndexError:
                    # No checkpoints left: idle until the checker is stopped
                    self.stopping.wait()
                    break
                until = timestamp - time.time()
                # Checkpoints already passed before the hook ever fired are skipped;
                # once it has fired, later (possibly overdue) checkpoints fire too.
                if fn_called or until > 0:
                    logger.debug("[TC] wait for %ss until %ss remain", until, secs)
                    self.stopping.wait(until)
                    if not self.stopping.is_set():
                        # check that kill timeout wasn't increased
                        state = self.agentr.state["__state"]
                        current_exp_time = state["args"].get("expiration_time")
                        # `timestamp + secs` is the expiration time this checkpoint was computed from;
                        # a missing value must not be compared with a number (TypeError on py3)
                        if current_exp_time is not None and current_exp_time >= timestamp + secs + 1:
                            self.update_expiration_time(current_exp_time)
                            continue
                        self.fn(secs)
                        fn_called = True

        def __enter__(self):
            if self.active:
                self.start()
            return self

        def __exit__(self, *_):
            if self.active:
                self.stopping.set()
                logger.debug("[TC] waiting for thread stop")
                self.join()
                logger.debug("[TC] thread stopped")

    class Phazotron(th.Thread):
        """
        UNIX-socket RPC server that lets helper scripts call methods registered via `common_patterns.Api`.

        Request: SIZE_ST length + "Class.method" name, SIZE_ST length + JSON positional args,
        SIZE_ST length + JSON keyword args. Reply: RET_ST status (0 = ok, 1 = error) followed by the
        JSON result or the error text. A zero-length method name stops the server.
        """
        SIZE_ST = sdk2.phazotron.Phazotron.SIZE_ST
        RET_ST = sdk2.phazotron.Phazotron.RET_ST

        def __init__(self, task):
            self.task = task
            self.logger = logger.getChild("phazotron")
            self.sock_name = common_config.Registry().client.phazotron.sockname
            # Make the socket name task-specific: "<name>.<task id>.<ext>"
            basename = os.path.basename(self.sock_name).split(".")
            basename.insert(-1, str(task.id))
            self.sock_name = os.path.join(os.path.dirname(self.sock_name), ".".join(basename))

            try:
                os.unlink(self.sock_name)
            except OSError:
                if os.path.exists(self.sock_name):
                    raise

            self.sock = socket.socket(socket.AF_UNIX)
            self.sock.bind(self.sock_name)
            self.sock.settimeout(None)
            self.sock.listen(10)
            self.peers = set()
            self._scripts = set()

            super(ExecuteTask.Phazotron, self).__init__(name="Phazotron")

        def script(self, cls):
            """Render `cls.SCRIPT_TMPL` into an executable script in the task directory and return its path."""
            assert issubclass(cls, sdk2.phazotron.Phazotron)
            script_name = self.task.abs_path(cls.__name__.lower())
            with open(script_name, "w") as fh:
                fh.write(textwrap.dedent(cls.SCRIPT_TMPL).lstrip().format(
                    executable=sys.executable + (" interpret" if common_system.inside_the_binary() else ""),
                    sockname=self.sock_name,
                    size_st=self.SIZE_ST.format,
                    ret_st=self.RET_ST.format,
                    code_root=exec_constants.SANDBOX_DIR,
                    iteration=self.task.agentr.iteration,
                    session_token=self.task.agentr.token,
                    config_env_var=common_config.Registry().CONFIG_ENV_VAR,
                    config_path=str(common_config.Registry().custom),
                ))
            os.chmod(script_name, 0o555)
            self.logger.debug("[PT] Script placed at %r", script_name)
            self._scripts.add(script_name)
            return script_name

        def _call(self, peer, method_name, method, args, kws):
            """
            Invoke `method` on behalf of `peer` and send the reply back.

            Replies with RET_ST(0) + JSON result on success or RET_ST(1) + error text on failure.
            The peer connection is always closed afterwards.
            """
            self.peers.add(peer)
            no = th.current_thread().name
            try:
                try:
                    self.logger.info(
                        "[PT.%s/%d] Requested call of %s(*%r, **%r).",
                        no, len(self.peers), method_name, args, kws
                    )
                    # Prefer the undecorated implementation if the API method exposes `__orig_method__`
                    ret = getattr(getattr(method, "__func__", method), "__orig_method__", method)(*args, **kws)
                    # Socket payloads must be bytes (concatenating bytes and str fails on py3)
                    peer.sendall(self.RET_ST.pack(0) + six.ensure_binary(json.dumps(ret)))
                    self.logger.debug("[PT.%s/%d] Method called successfully.", no, len(self.peers) - 1)
                except Exception as ex:
                    self.logger.error("[PT.{}/{}] ".format(no, len(self.peers) - 1) + str(ex))
                    peer.sendall(self.RET_ST.pack(1) + six.ensure_binary(str(ex)))
            finally:
                peer.close()
                self.peers.discard(peer)

        @staticmethod
        def _recv(peer, size, data_name):
            """
            Read exactly `size` bytes from `peer`.

            `recv()` on a stream socket may return fewer bytes than requested, so keep reading until
            `size` bytes are collected or the peer closes the connection.

            :raises ValueError: if the connection ends before `size` bytes arrive
            """
            chunks = []
            received = 0
            while received < size:
                chunk = peer.recv(size - received)
                if not chunk:
                    break
                chunks.append(chunk)
                received += len(chunk)
            data = b"".join(chunks)
            if len(data) != size:
                raise ValueError(
                    "Error reading {} from the socket. Data fetched: {!r}".format(data_name, data)
                )
            return data

        def run(self):
            accepts = 0
            self.logger.info("[PT] Thread started.")
            while True:
                self.logger.debug("[PT] Waiting for incoming connections.")
                try:
                    accepts += 1
                    peer, address = self.sock.accept()
                    self.logger.debug("[PT] New connection accepted, currently active %d.", len(self.peers))
                except Exception as ex:
                    self.logger.error("[PT] " + str(ex))
                    break
                try:
                    size = self.SIZE_ST.unpack(self._recv(peer, self.SIZE_ST.size, "size of method name"))[0]
                    if not size:
                        # Zero-length method name is the shutdown request sent by `__exit__`
                        peer.close()
                        break
                    # Decode to native `str` so that `partition` and the API registry lookup work on py3
                    full_method_name = six.ensure_str(self._recv(peer, size, "method name"))
                    class_name, sep, method_name = full_method_name.partition(".")
                    if sep != ".":
                        raise ValueError("Error parsing method name: {!r}".format(full_method_name))
                    try:
                        cls = common_patterns.Api[class_name]
                    except KeyError:
                        raise ValueError("Class {!r} is not registered as API".format(class_name))
                    try:
                        method = getattr(cls, method_name)
                    except AttributeError:
                        raise ValueError("Class {!r} has no method {!r}".format(class_name, method_name))
                    size = self.SIZE_ST.unpack(
                        self._recv(peer, self.SIZE_ST.size, "size of positional arguments")
                    )[0]
                    method_args = json.loads(self._recv(peer, size, "positional arguments"))
                    size = self.SIZE_ST.unpack(
                        self._recv(peer, self.SIZE_ST.size, "size of keyword arguments")
                    )[0]
                    method_kws = json.loads(self._recv(peer, size, "keyword arguments"))
                    # The worker thread owns `peer` from here on and closes it when done
                    w = th.Thread(
                        name=str(accepts),
                        target=self._call,
                        args=(peer, full_method_name, method, method_args, method_kws),
                    )
                    w.daemon = True
                    w.start()
                except (ValueError, socket.error) as ex:
                    # A single misbehaving client must not stop the server: reply (best effort) and drop it
                    self.logger.error("[PT] " + str(ex))
                    try:
                        peer.sendall(self.RET_ST.pack(1) + six.ensure_binary(str(ex)))
                    except socket.error as send_ex:
                        self.logger.warning("[PT] Unable to send error reply: %s", send_ex)
                    finally:
                        peer.close()
            self.logger.info("[PT] Thread stopped.")

        def __enter__(self):
            self.start()
            return self

        def __exit__(self, *_):
            try:
                # Wake up the accept loop with the zero-size shutdown request
                sock = socket.socket(socket.AF_UNIX)
                try:
                    sock.connect(self.sock_name)
                    sock.sendall(self.SIZE_ST.pack(0))
                finally:
                    sock.close()
                self.join(5)
                os.unlink(self.sock_name)
                for script in self._scripts:
                    os.unlink(script)
            except Exception as ex:
                self.logger.warning("[PT] Exception occurred while shutting down: %s", ex)
            finally:
                # Release the listening socket, unless the server thread is still blocked on it
                if not self.is_alive():
                    self.sock.close()

    def __init__(self, task_id, logdir, vault_key=None, ramdrive=None, agentr=None, ipc_args=None, **kwargs):
        """
        :param vault_key: base64-encoded vault key, decoded here
        :param ramdrive: dict with "type", "size" and "path" keys, or None
        :param ipc_args: arguments received from the parent process
        """
        super(ExecuteTask, self).__init__(task_id, logdir, agentr)
        self.ipc_args = ipc_args
        self.vault_key = base64.b64decode(vault_key) if vault_key else None
        self.task.ramdrive = (
            ctm.RamDrive(*map(ramdrive.get, ("type", "size", "path")))
            if ramdrive else
            None
        )
        self.task.container = exec_utils.container_info(ipc_args)

    @property
    def timeout_checker(self):
        """A new `TimeoutChecker` for the expiration time passed in `ipc_args`."""
        return self.TimeoutChecker(self.ipc_args["args"].get("expiration_time"), self.task)

    @property
    def platform(self):
        return common_platform.platform()

    def _prepare_environment(self):
        """Set up the ya cache dir, log the process environment and prepare/touch task environments."""
        logger.debug("Prepare environment for task.")
        os.environ["YA_CACHE_DIR"] = os.path.join(sdk2.environments.SandboxEnvironment.build_cache_dir, "ya")
        logger.debug("Current environment is:\n" + "\n".join((
            "\t{}: {}".format(_, os.environ[_]) for _ in sorted(os.environ)
        )))
        if self.task.environment:
            for environment in self.task.environment:
                environment.prepare()
                environment.touch()

    def _prepare(self):
        """
        Switch the task to PREPARING and run environment preparation and `on_prepare()`.

        :return: None if there is no task, the current status if the task is not ASSIGNED (STOPPED
                 for STOPPING), otherwise the target status produced by user code
        """
        if self.task is None:
            return
        logger.info("Preparing task %s", self.task)
        task_utils.initialize_sdk(self.vault_key, self.task)
        if self.task.status == ctt.Status.STOPPING:
            return ctt.Status.STOPPED
        if self.task.status != ctt.Status.ASSIGNED:
            return self.task.status
        self.set_status(
            ctt.Status.PREPARING,
            "Executing on {} ({})".format(
                common_config.Registry().this.id, common_platform.get_platform_alias(self.platform)
            )
        )

        self.task.timestamp_start = self.task.updated = time.time()
        self.task.ctx.pop("__last_error_trace", None)
        self.task.ctx.pop("__no_timeout", None)
        self.task.platform = self.platform
        self.task.host = common_config.Registry().this.id

        workdir = self.task.abs_path()
        common_fs.chmod_for_path(workdir, "a+w", recursively=True)
        os.chdir(workdir)

        self.task.initLogger()

        if (
            self.task.status != ctt.Status.PREPARING or
            self.task.host != common_config.Registry().this.id or self.task.host == "localhost"
        ):
            # Drop session
            raise common_proxy.ReliableServerProxy.SessionExpired(
                "Wrong task #{} state. Current status: '{}', host: '{}'".format(
                    self.task.id, self.task.status, self.task.host
                )
            )

        os.environ["USER"] = os.environ["USERNAME"] = os.environ["LOGNAME"]  # Cygwin!
        common_fs.chmod_for_path(os.path.join(self.task.abs_path(), "tmp"), "0755", recursively=True)
        tempfile.tempdir = None

        first_call = True
        # NOTE(review): a failure here sets `target_status`, but a successful `on_prepare()` below
        # resets it to PREPARING -- confirm that ignoring environment errors is intended.
        with self.UserExceptionHandler(self, ctt.Status.PREPARING):
            self._prepare_environment()
        # Re-run on_prepare() while it raises waits whose conditions are already met
        for _ in common_itertools.progressive_yielder(0.1, 1800, float("inf"), False):
            with self.UserExceptionHandler(self, ctt.Status.PREPARING) as ux:
                self.task.on_prepare()
            if not ux.repeat:
                break

            message = "Wakeup conditions are already met, running on_prepare() again"
            logger.info(message)
            if first_call:
                self.rest_client.task.current.audit(message="{} (notifying about this only once)".format(message))
                first_call = False

        self.__main_thread = (th.current_thread().ident, th.current_thread().name)

        if not common_system.inside_the_binary():
            from sandbox import projects
            self.task.ctx["tasks_version"] = getattr(projects, "__revision__", 0)
            self.task.ctx["task_version"] = projects.TYPES[self.task.type].revision
        os.environ["GSID"] = str(self.task.ctx["__GSID"])
        if not common_platform.on_windows():
            logger.debug("Ulimits:\n%s", six.ensure_str(sp.check_output(["bash", "-c", "ulimit -a"]).strip()))
        logger.debug("\n=== Task input context ===\n%s\n==================\n", self.task.ctx)

        task_utils.update_task(self.task, False)
        return self.target_status

    def _on_terminate(self, *_):
        """SIGTERM handler: run the task's `on_terminate()`, sync AgentR metadata and exit immediately."""
        self.task.on_terminate()
        task_utils.update_agentr_meta(self.task)
        # noinspection PyProtectedMember
        os._exit(0)

    def _execute(self):
        """Run `on_execute()` and `postprocess()`, then post-execution housekeeping; return the target status."""
        # Read the flag once so enabling and dumping the profiler always agree
        profiler = None
        if common_config.Registry().client.executor.profile:
            profiler = profile.Profile()
            profiler.enable()

        first_call = True
        logger.debug("Executing task %s", self.task)
        # Re-run on_execute() while it raises waits whose conditions are already met
        for _ in common_itertools.progressive_yielder(0.1, 300, float("inf"), False):
            with self.UserExceptionHandler(self, ctt.Status.SUCCESS) as ux:
                self.task.on_execute()
            if not ux.repeat:
                break

            message = "Wakeup conditions are already met, running on_execute() again"
            logger.info(message)
            if first_call:
                self.rest_client.task.current.audit(message="{} (notifying about this only once)".format(message))
                first_call = False

        logger.debug("Task.on_execute finished, target status: %s", self.target_status)

        # target status can be anything here (WAIT_*, EXCEPTION, etc)
        with self.UserExceptionHandler(self, self.target_status):
            self.task.postprocess()
        logger.debug("Task.postprocess finished, target status: %s", self.target_status)

        logger.debug("Task %s execution finished.", self)

        if profiler is not None:
            profiler.disable()
            profiler.dump_stats(self.task.log_path("task.prof"))
        task_utils.update_task(self.task, False)

        self.post_execute()
        common_fs.FSJournal().clear()  # Be a paranoid
        logger.debug("\n=== Task output context ===\n%s\n=================\n", self.task.ctx)
        return self.target_status

    def post_execute(self):
        """
        Post-execution housekeeping.

        Dumps ramdrive disk usage (unless `__do_not_dump_ramdrive` is set in the context). If a docker
        binary is found, backs up external images and cleans up docker data; docker errors are
        logged, not raised.
        """
        if self.task.ramdrive and not self.task.ctx.get("__do_not_dump_ramdrive"):
            logger.debug("dumping ramdrive: %s", self.task.ramdrive)
            agentr_utils.dump_dir_disk_usage_scandir(
                str(self.task.ramdrive.path),
                self.task.log_path("ramdrive_usage.yaml")
            )
        if distutils_spawn.find_executable("docker"):
            logger.debug("Start docker post execute")
            utcnow = dt.datetime.utcnow()
            start_time = time.time()
            docker_config = self.task.abs_path("sandbox_docker_config")
            try:
                docker_client = task_utils.DockerClient(self.task, docker_config)
                self._backup_docker_images(docker_client)
                self._cleanup_docker_files(docker_client)
            except Exception as err:
                logger.warning("Unexpected error on docker processing: %s", str(err))
                # TODO: destroy_container
            common_statistics.Signaler().push(dict(
                type=ctst.SignalType.TASK_OPERATION,
                kind=ctst.OperationType.DOCKER,
                date=utcnow,
                timestamp=utcnow,
                method="post_execute",
                duration=int((time.time() - start_time) * 1000)
            ))
            shutil.rmtree(docker_config, ignore_errors=True)

    def _backup_docker_images(self, docker_client):  # type: (task_utils.DockerClient) -> None
        """Push external docker images used by the task to the internal registry, unless already there."""
        docker_registry = "registry.yandex.net"
        docker_registry_token = self.ipc_args.get("docker_registry_token")
        if not docker_registry_token:
            logger.debug("No docker token on host")
            return
        # presumably maps image id -> image name (see `.values()` and `tag_image(image_id, ...)`) -- confirm
        docker_images = docker_client.find_external_docker_images(task_created_time=self.task.timestamp_start)
        if not docker_images:
            logger.debug("No external docker images found")
            return
        logger.debug("Found external docker images: %s", ", ".join(sorted(docker_images.values())))
        err = docker_client.login_to_registry(docker_registry, docker_registry_token)
        if err:
            logger.error("Failed to login to %s: %s", docker_registry, err)
            return
        rest = common_rest.Client("https://" + docker_registry, auth=docker_registry_token)
        # Iterating the mapping directly would yield keys only and break the (id, name) unpacking
        for image_id, image in six.iteritems(docker_images):
            utcnow = dt.datetime.utcnow()
            start = time.time()
            repository, tag = image.rsplit(":", 1)
            try:
                rest.v2[repository].manifests[tag].read()
                logger.debug("Docker image '%s' already backed up", image)
                continue
            except rest.HTTPError as err:
                if err.status != 404:
                    logger.error("Failed to check image backup: %s", str(err))
                    continue
                else:
                    logger.debug("Trying to backup image '%s'", image)
            target_image_name = "{}/{}".format(docker_registry, image)
            err = docker_client.tag_image(image_id, target_image_name)
            if err:
                logger.error("Failed to tag image '%s': %s", image, str(err))
                continue
            err = docker_client.push_image(target_image_name)
            if err:
                logger.error("Failed to push image '%s': %s", target_image_name, str(err))
                continue
            duration = int((time.time() - start) * 1000)
            common_statistics.Signaler().push(dict(
                type=ctst.SignalType.TASK_OPERATION,
                kind=ctst.OperationType.DOCKER,
                date=utcnow,
                timestamp=utcnow,
                method="backup_image",
                duration=duration
            ))

    def _cleanup_docker_files(self, docker_client):  # type: (task_utils.DockerClient) -> None
        """Stop the task's running containers and remove unused docker data, logging the results."""
        stopped_containers, error_messages = docker_client.stop_containers()
        for image in stopped_containers:
            logger.debug("Stopped running container(s) of image %s", image)
        for err in error_messages:
            logger.error(err)
        logger.debug(docker_client.get_disk_usage())
        logger.debug(docker_client.remove_unused_data())

    def execute(self):
        """Prepare and run the task; return the resulting target status."""
        signal.signal(signal.SIGTERM, self._on_terminate)
        status = self._prepare()
        if status != ctt.Status.PREPARING:
            return status

        # Defined for its registration side effect: makes `Task.sync_resource` resolvable through
        # `common_patterns.Api`, which is how Phazotron looks up callable methods.
        @six.add_metaclass(common_patterns.Api)
        class Task(object):

            @common_patterns.Api.register
            @staticmethod
            def sync_resource(rid):
                return self.task.sync_resource(rid)

        if common_platform.on_windows():  # TODO: support Phazotron on Windows [SANDBOX-7656]
            with self.timeout_checker:
                self.set_status(ctt.Status.EXECUTING)
                return self._execute()
        with self.Phazotron(self.task) as pt, self.timeout_checker:
            self.set_status(ctt.Status.EXECUTING)
            self.task.arcaphazotron = sdk2.phazotron.Arcaphazotron(pt)
            self.task.synchrophazotron = sdk2.phazotron.Synchrophazotron(pt)
            return self._execute()


class ExecutePrivilegedTask(ExecuteTask):
    """
    Non-privileged half of a privileged task run.

    Prepares every task environment except pip ones. Then, instead of running the task, it writes
    the ipc arguments for the ``EXECUTE_PRIVILEGED_CONTAINER`` command into the file named by
    ``sys.argv[1]`` and hard-exits the process.
    """

    @property
    def platform(self):
        """Platform reported by the task's container."""
        container = self.task.container
        return container.platform

    def _prepare_environment(self):
        """Set ``YA_CACHE_DIR``, create the environments folder and prepare all non-pip environments."""
        logging.debug("Prepare non-privileged environment for task.")
        ya_cache_dir = os.path.join(sdk2.environments.SandboxEnvironment.build_cache_dir, "ya")
        os.environ["YA_CACHE_DIR"] = ya_cache_dir
        sdk2.paths.make_folder(common_config.Registry().client.tasks.env_dir)
        environment_dump = "\n".join(
            "\t{}: {}".format(name, os.environ[name]) for name in sorted(os.environ)
        )
        logging.debug("Current environment is:\n" + environment_dump)
        # pip environments are prepared later, by the privileged (in-container) executor.
        non_pip_environments = (
            env for env in (self.task.environment or [])
            if not isinstance(env, sdk2.environments.PipEnvironment)
        )
        for env in non_pip_environments:
            env.prepare()
            env.touch()

    def execute(self):
        """
        Prepare the task, then hand it over to ``EXECUTE_PRIVILEGED_CONTAINER`` via the ipc file.

        Returns the preparation status if it is not PREPARING. Otherwise it never returns:
        the process ends with ``os._exit(0)``.
        """
        signal.signal(signal.SIGTERM, self._on_terminate)
        prepare_status = self._prepare()
        if prepare_status != ctt.Status.PREPARING:
            return prepare_status

        # NOTE(review): ``Task`` is never referenced below; presumably defining it under the
        # ``common_patterns.Api`` metaclass is what exposes ``sync_resource`` -- keep it.
        @six.add_metaclass(common_patterns.Api)
        class Task(object):

            @common_patterns.Api.register
            @staticmethod
            def sync_resource(rid):
                return self.task.sync_resource(rid)

        if not isinstance(self.task, adapters.SDK2TaskAdapter):
            # Ship the whole task object plus the root logger's vault filter, pickled with protocol 2.
            self.ipc_args["task"] = pickle.dumps(self.task, 2)
            root_vault_filter = common_log.VaultFilter.filter_from_logger(logging.getLogger())
            self.ipc_args["vault_filter"] = pickle.dumps(root_vault_filter, 2)
        else:
            # SDK2 tasks are passed by id and context only.
            self.ipc_args["task_id"] = self.task.id
            self.ipc_args["task_context"] = self.task.ctx
        self.ipc_args["command"] = ctc.Command.EXECUTE_PRIVILEGED_CONTAINER
        self.ipc_args["env"] = os.environ.copy()
        # The agentr argument is not passed along to the next stage.
        self.ipc_args["args"].pop("agentr", None)
        with open(sys.argv[1], "wb") as ipc_file:
            ipc_file.write(pickle.dumps(self.ipc_args, 2))

        # sys.exit() would raise SystemExit, which the outer try/except swallows (it guards against
        # user code calling sys.exit), so terminate the process with the low-level _exit instead.
        # noinspection PyProtectedMember
        os._exit(0)


class ExecutePrivilegedTaskInContainer(ExecuteTask):
    """
    Executor for the ``EXECUTE_PRIVILEGED_CONTAINER`` command.

    ``ipc_args`` either carry the task object itself (key ``"task"``, together with the
    ``"vault_filter"`` state) or leave the task to be identified by the ``task_id`` keyword.
    Only pip environments are prepared here. ``ExecutePrivilegedTask._prepare_environment``
    prepares the remaining ones. The task then runs inside a Phazotron context.
    """

    def __init__(self, *args, **kwargs):
        # Stays None when the task object is shipped in ipc_args.
        # NOTE(review): presumably the base class then uses the pre-assigned ``self.task`` instead
        # of loading a task by id -- confirm against Executable.
        task_id = None
        ipc_args = kwargs["ipc_args"]
        if "task" in ipc_args:
            # The task object is provided directly, so any task_id keyword is discarded.
            kwargs.pop("task_id", None)
            task = ipc_args["task"]
            if isinstance(task, six.string_types):
                # TODO: Backward compatibility code. Always unpickle after deployment. SANDBOX-5906
                # NOTE(review): pickle.loads assumes ipc_args come from a trusted source.
                task = pickle.loads(task)
            # Assigned before the base __init__ runs (see the note on task_id above).
            self.task = task
            # Vault filter installed on the root logger of this process (None if there is none).
            vault_filter = common_log.VaultFilter.filter_from_logger(logging.getLogger())
            if vault_filter is not None:
                # Merge in the filter state shipped in ipc_args["vault_filter"]
                # (presumably captured by the sending process).
                vault_filter_state = ipc_args["vault_filter"]
                if isinstance(vault_filter_state, six.string_types):
                    # TODO: Backward compatibility code. Always unpickle after deployment. SANDBOX-5906
                    vault_filter_state = pickle.loads(vault_filter_state)
                vault_filter.update_records(vault_filter_state)
            # Attach the agentr client (if any) to the current SDK2 task context.
            sdk2.Task.current.agentr = kwargs.get("agentr")
        else:
            task_id = kwargs.pop("task_id", None)
        # The first positional argument is dropped and replaced by task_id.
        # NOTE(review): presumably args[0] is a positional task_id slot -- confirm with the caller.
        super(ExecutePrivilegedTaskInContainer, self).__init__(task_id, *args[1:], **kwargs)

    def _prepare_environment(self):
        """Create the environments folder, unset ``YA_CACHE_DIR`` and prepare only pip environments."""
        logging.debug("Prepare privileged environment for task.")
        sdk2.paths.make_folder(common_config.Registry().client.tasks.env_dir)
        # ExecutePrivilegedTask._prepare_environment sets YA_CACHE_DIR; make sure it is not set here.
        os.environ.pop("YA_CACHE_DIR", None)
        logging.debug("Current environment is:\n" + "\n".join((
            "\t{}: {}".format(_, os.environ[_]) for _ in sorted(os.environ)
        )))
        for environment in self.task.environment or []:
            if isinstance(environment, sdk2.environments.PipEnvironment):
                environment.prepare()

    def execute(self):
        """
        Run the task in this (privileged) process.

        There is no ``_prepare()`` step here, unlike the other executors. SDK initialization, logger
        setup and the switch to EXECUTING happen directly. The task then runs inside Phazotron
        under the timeout checker.

        :return: whatever ``self._execute()`` returns
        """
        signal.signal(signal.SIGTERM, self._on_terminate)
        self.task.platform = self.platform
        task_utils.initialize_sdk(self.vault_key, self.task)

        self.task.initLogger()
        # Log resource limits for diagnostics; check_output raises if the bash command fails.
        logger.debug("Ulimits:\n%s", six.ensure_str(sp.check_output(["bash", "-c", "ulimit -a"]).strip()))
        self.set_status(ctt.Status.EXECUTING)

        with self.Phazotron(self.task) as pt:
            self.task.arcaphazotron = sdk2.phazotron.Arcaphazotron(pt)
            self.task.synchrophazotron = sdk2.phazotron.Synchrophazotron(pt)
            # Route the task's resource synchronization through synchrophazotron.
            self.task._sync_resource = self.task._sync_resource_via_synchrophazotron

            # Errors raised while preparing environments go to UserExceptionHandler (EXECUTING stage).
            with self.UserExceptionHandler(self, ctt.Status.EXECUTING):
                self._prepare_environment()
            with self.timeout_checker:
                return self._execute()


class TerminateTask(Executable):
    """
    Finishes a task run.

    Moves the task to its terminal transition status and runs the user hooks that match the
    requested terminal status.
    """

    def __init__(self, task_id, status, logdir, vault_key=None, agentr=None, task=None, **kwargs):
        # Requested terminal status (e.g. SUCCESS, FAILURE, TIMEOUT, or a BREAK/WAIT group status).
        self.status = status
        super(TerminateTask, self).__init__(task_id, logdir, agentr, task)
        # vault_key arrives base64-encoded.
        self.vault_key = base64.b64decode(vault_key) if vault_key else None
        ramdrive = kwargs.get("ramdrive")
        if not self.subsequent:
            # ``ramdrive`` is an optional mapping with "type", "size" and "path" keys.
            self.task.ramdrive = (
                ctm.RamDrive(*map(ramdrive.get, ("type", "size", "path")))
                if ramdrive else
                None
            )

    def _term_status_and_hook(self):
        """
        Decide how to terminate the task for the requested ``self.status``.

        Returns a dict with any subset of the keys:
          * ``"status"`` -- transition status to set before running hooks (absent/None: keep current);
          * ``"hook"`` -- first hook to run;
          * ``"subhook"`` -- second hook, run only after ``hook``;
          * ``"message"`` -- termination message.
        Statuses not covered below yield an empty dict.

        Side effect: on TIMEOUT, when the ``__no_timeout`` context flag is set and the task is
        EXECUTING, ``self.status`` is switched to TEMPORARY.
        """
        if self.status in (ctt.Status.SUCCESS, ctt.Status.FAILURE):
            # Finishing: on_finish, then on_success or on_failure.
            return {
                "status": ctt.Status.FINISHING,
                "hook": self.task.on_finish,
                "subhook": (self.task.on_success if self.status == ctt.Status.SUCCESS else self.task.on_failure)
            }
        elif self.status == ctt.Status.TIMEOUT:
            if self.task.ctx.get("__no_timeout") and self.task.status == ctt.Status.EXECUTING:
                self.status = ctt.Status.TEMPORARY
                return {
                    "hook": self.task.on_break,
                    "message": "Timeout when no_timeout flag is active, switching to TEMPORARY"
                }
            else:
                # A task that is already FINISHING keeps its status; otherwise it goes to STOPPING.
                return {
                    "status": None if self.task.status == ctt.Status.FINISHING else ctt.Status.STOPPING,
                    "hook": self.task.on_break,
                    "subhook": self.task.on_timeout
                }
        elif self.status in it.chain(ctt.Status.Group.BREAK, ctt.Status.Group.WAIT):
            return {
                "status": ctt.Status.STOPPING,
                "hook": (self.task.on_wait if self.status in ctt.Status.Group.WAIT else self.task.on_break),
            }
        return {}

    def execute(self):
        """
        Run the termination sequence.

        :return: ``(self.target_status, message)``
        """
        logger.debug("TerminateTask #%s", self.task)
        # Computed before SDK/logger initialization; may switch self.status to TEMPORARY.
        terminate_info = self._term_status_and_hook()
        status, hook, subhook, message = map(terminate_info.get, ["status", "hook", "subhook", "message"])
        if not self.subsequent:
            task_utils.initialize_sdk(self.vault_key, self.task)
        self.task.initLogger()
        if status and self.task.status != status:
            self.set_status(status)
        os.chdir(self.task.abs_path())
        # An unknown requested status is reported as EXCEPTION and handled via on_break.
        if self.status not in ctt.Status:
            hook, message = self.task.on_break, "Executor returned wrong status {!r}".format(self.status)
            self.status = ctt.Status.EXCEPTION

        # NOTE(review): self.status is the requested terminal status (SUCCESS, FAILURE, TIMEOUT, ...).
        # Nothing in this class sets it to FINISHING, so no_failure is True unless a caller passed
        # FINISHING explicitly. The transition status computed above is ``status`` -- confirm which
        # of the two was meant here.
        no_failure = self.status != ctt.Status.FINISHING
        with self.UserExceptionHandler(self, self.status, no_failure=no_failure):
            resources = list(self.task.list_resources())
            if not isinstance(self.task, adapters.SDK2TaskAdapter):
                # Non-SDK2 tasks: register metadata through agentr for resources still NOT_READY.
                for resource in resources:
                    if resource.state == ctr.State.NOT_READY:
                        resource_cls = sdk2.Resource[resource.type.name]
                        self.task.agentr.resource_register_meta(
                            resource.abs_path(), resource.__meta__,
                            share=resource_cls.share,
                            service=issubclass(resource_cls, sdk2.ServiceResource)
                        )
            if isinstance(self.task, adapters.SDK2TaskAdapter):
                self.task.list_parent_resources()

            # The subhook runs only when there is a hook.
            if hook is not None:
                hook()
                if subhook is not None:
                    subhook()
            if common_platform.on_windows():
                _, user = common_os.User.service_users
                # TODO: remove after SANDBOX-9583
                wsl.Process.win_kill_user_procs(username=user.login, logger=logger, exclude_self=True)
            if self.status != ctt.Status.STOPPED:
                # Clean up only after the hooks, so they can still save debug files into resources.
                self.task.cleanup()

        # Make the task directory tree read-only for everyone (keeping read and directory traversal).
        common_fs.chmod_for_path(self.task.abs_path(), "a-w+rX", recursively=True)
        try:
            task_utils.update_task(self.task)
        except Exception as ex:
            logger.exception("Error updating task #%r", self.task.id)
            # NOTE(review): target_status is not assigned elsewhere in this class; presumably it is
            # provided by Executable.
            self.target_status = ctt.Status.EXCEPTION
            message = str(ex)

        return self.target_status, message


class StopTask(TerminateTask):
    """
    ``TerminateTask`` variant for the STOP command.

    Extra keyword arguments are accepted but not forwarded. In particular, ``ramdrive`` and
    ``task`` never reach ``TerminateTask.__init__``.
    """

    def __init__(self, task_id, status, logdir, vault_key=None, agentr=None, **kwargs):
        super(StopTask, self).__init__(
            task_id, status, logdir,
            agentr=agentr,
            vault_key=vault_key,
        )


class ReleaseTask(Executable):
    """
    Release a task: run its ``on_release`` hook and record the release through the REST API.

    On any error the release record is deleted and the task ends up NOT_RELEASED.
    """

    def __init__(self, task_id, logdir, release_params=None, vault_key=None, agentr=None, **kwargs):
        # Explicit check instead of ``assert``: assert statements are stripped under ``python -O``.
        # AssertionError (with the same message) is kept so callers see the same exception type.
        if release_params is None:
            raise AssertionError("Parameter 'release_params' cannot be None")
        # "releaser" and "release_status" are read with [] (they must be present for a successful
        # release). "release_comments", "release_subject" and "release_changelog_entry" are optional.
        self.release_params = release_params
        # vault_key arrives base64-encoded.
        self.vault_key = base64.b64decode(vault_key) if vault_key else None
        super(ReleaseTask, self).__init__(task_id, logdir, agentr)

    def execute(self):
        """
        Run the release.

        :return: ``(self.task.status, status_message)``, where the status is RELEASED on success
                 and NOT_RELEASED if any step raised. Resources are marked in either case.
        """
        try:
            task_utils.initialize_sdk(self.vault_key, self.task)
            self.task.initLogger()

            message_body = self.release_params.get("release_comments")
            self.task.ctx["release_changelog"] = message_body
            # Persist the changelog before the user hook runs.
            # NOTE(review): the meaning of the second argument (False) is not visible here.
            task_utils.update_task(self.task, False)
            self.task.on_release(self.release_params)
            task_utils.update_task(self.task)

            author = self.release_params["releaser"]
            release_status = self.release_params["release_status"]
            message_subject = self.release_params.get("release_subject")
            release_changelog_entry = self.release_params.get("release_changelog_entry")
            self.rest_client.task.current.release.create(
                author=author,
                release_status=release_status,
                message_subject=message_subject,
                message_body=message_body,
                changelog=release_changelog_entry
            )
            self.task.status, status_message = ctt.Status.RELEASED, "Released as {}".format(release_status)
        except Exception as ex:
            logger.exception("Error is occurred while calling on_release() of task #%s", self.task.id)
            # Mask vault secrets before the traceback is shown in the task info.
            message = common_log.VaultFilter.filter_message_from_logger(
                logger, "Task is not released: {}".format(traceback.format_exc())
            )
            self.task.set_info(message)
            # Roll back: drop the release record and the stored changelog.
            self.rest_client.task.current.release.delete()
            self.task.ctx.pop("release_changelog", None)
            task_utils.update_task(self.task)
            self.task.status, status_message = ctt.Status.NOT_RELEASED, "Error on release: {}".format(ex)
        finally:
            self.task._mark_resources()
        return self.task.status, status_message


# Executor command -> class implementing it.
CMD2CLASS = dict((
    (ctc.Command.EXECUTE, ExecuteTask),
    (ctc.Command.EXECUTE_PRIVILEGED, ExecutePrivilegedTask),
    (ctc.Command.EXECUTE_PRIVILEGED_CONTAINER, ExecutePrivilegedTaskInContainer),
    (ctc.Command.TERMINATE, TerminateTask),
    (ctc.Command.STOP, StopTask),
    (ctc.Command.RELEASE, ReleaseTask),
))
