import os
import copy
import json
import time
import errno
import cPickle
import logging
import datetime as dt
import itertools as it
import threading
import contextlib

import psutil

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc

import sandbox.agentr.client
import sandbox.agentr.types as ar_types
import sandbox.agentr.errors as ar_errors

from sandbox.common import fs as common_fs
from sandbox.common import log as common_log
from sandbox.common import rest as common_rest
from sandbox.common import config as common_config
from sandbox.common import format as common_format
from sandbox.common import patterns as common_patterns
from sandbox.common import threading as common_threading

from sandbox.client import base, errors, system


logger = logging.getLogger(__name__)


# Name prefix of the files through which the client hands task arguments over to the executor.
EXECUTOR_ARGS_PREFIX = "executor_args"
# Name prefix of the files through which the executor reports the task status back to the client.
EXECUTOR_RESULT_PREFIX = "executor_result"


class Command(threading.Thread, base.Serializable):
    """
    Base command class.

    Every subclass that defines a non-empty ``command`` class attribute is registered by the metaclass
    in ``Command.commands``, so ``Command(name, args)`` instantiates the subclass registered under ``name``.
    Each command runs in its own thread (see ``run``). ``save_state`` stores ``self.encode()`` (provided by
    ``base.Serializable`` -- presumably built from ``SERIALIZABLE_ATTRS``) as the AgentR session state,
    and ``__setstate__`` restores an instance from such state.
    """

    # noinspection PyPep8Naming
    class __metaclass__(base.Serializable.__metaclass__):
        """ Registers every class with a non-empty ``command`` attribute in ``Command.commands`` """

        def __new__(mcs, name, bases, namespace):
            cls = super(mcs, mcs).__new__(mcs, name, bases, namespace)
            if cls.command:
                cls.command = str(cls.command)
                cls.commands[cls.command] = cls
            return cls

    SERIALIZABLE_ATTRS = (
        "id", "token", "command", "suspended_command",
        "liner", "task_id", "tasks_rid", "exec_type", "iteration",
        "arch", "kill_timeout", "killed_by_timeout", "timeout_extension",
        "timestamp_start", "args", "status", "platform", "executor_ctx"
    )

    # Reserve at least 10Mb of disk space for any kind of task execution space
    MINIMAL_SPACE_RESERVE = 10 << 20

    # Draws the next number from a process-wide counter starting at 1
    # (presumably on every attribute access -- depends on ``classproperty``)
    next_command_id = common_patterns.classproperty(lambda _, g=it.count(1): g.next())
    last_command_id = None
    task_id = None
    # Handle of the spawned executor subprocess (has ``poll``, ``terminate``, ``stdout``, ``stderr``)
    liner = None
    tasks_rid = None
    exec_type = None
    arch = None
    status = None
    status_message = None
    reject_type = None
    executor_ctx = {}
    command = None
    # Command name -> command class; filled by the metaclass
    commands = {}
    suspended_command = None
    platform = None
    agentr = None
    # Session token -> command instance; filled by ``start``
    registry = {}
    ssh_private_key = None
    common_arc_token = None
    docker_registry_token = None
    assigned_check = False

    def __new__(cls, command=None, args=None):
        """
        Instantiate the command class registered under the name ``command``.

        :param command: command name, a key of ``Command.commands``
        :param args: command arguments dict, stored as ``obj.args`` (an empty dict when not given)
        :return: new instance of the registered class, or None for an unknown command
            (``__init__`` is then not called)
        """
        cls = cls.commands.get(command)
        if cls is None:
            return
        obj = super(Command, Command).__new__(cls)
        obj.args = args or {}
        return obj

    def __init__(self, *_):
        """
        Initialize the command from ``self.args`` (set by ``__new__``).

        Pops ``id`` (the session token), ``arch`` and ``killed_by_timeout`` out of ``self.args``,
        i.e. the arguments dict passed to the constructor is modified in place.
        """
        self.id = self.next_command_id
        self.__status_lock = threading.RLock()
        self.token = self.args.pop("id", None)
        self.logger = logger
        if self.token:
            # Prefix every message of this command with the obfuscated session token
            self.logger = common_log.MessageAdapter(
                logger, fmt='{{{}}} %(message)s'.format(common_format.obfuscate_token(self.token))
            )
        self.logger.info("Processing command #%d %r", self.id, self.command.upper())
        Command.last_command_id = self.id
        arch = self.args.pop("arch", None)
        # A class-level ``arch`` takes precedence over the one passed in arguments
        self.arch = self.arch or arch
        self.kill_timeout = self.args.get("kill_timeout", None)
        self.killed_by_timeout = self.args.pop("killed_by_timeout", False)
        self.timeout_extension = False
        self.timestamp_start = None
        threading.Thread.__init__(self, name="Thread for {!r}".format(self))

    def __repr__(self):
        return "<{} :: {}>".format(type(self).__name__, self.platform and self.platform.name)

    def __nonzero__(self):
        """ The command is truthy while it holds an executor subprocess handle """
        return bool(self.liner)

    @property
    @contextlib.contextmanager
    def _status_lock(self):
        """
        Re-entrant critical section for status changes.

        When ``status`` differs after the block, the state is saved and the executor result file is removed.
        Nothing after the ``yield`` runs if the block raises.
        """
        with self.__status_lock:
            old_status = self.status
            yield self
            if old_status != self.status:
                self.save_state()
                del self.executor_result

    def __setstate__(self, state):
        """
        Restore the command from serialized ``state``.

        Re-creates the members that are not serialized (thread, lock, logger, AgentR session)
        and re-registers the container occupied by the task on the platform.
        """
        threading.Thread.__init__(self)

        # FIXME: SANDBOX-5068: Drop after cluster update
        tasks_res = state.pop("tasks_res", None)
        if isinstance(tasks_res, dict):
            state["tasks_rid"] = tasks_res["id"]
        if "tasks_rid" not in state:
            state["tasks_rid"] = None

        super(Command, self).__setstate__(state)
        self.__status_lock = threading.RLock()
        self.logger = logger
        if self.token:
            self.logger = common_log.MessageAdapter(
                logger, fmt='{{{}}} %(message)s'.format(common_format.obfuscate_token(self.token))
            )
            self.agentr = self._new_agentr_session()
        if self.platform is not None:
            self.platform._cmd = self
            self.platform.logger = self.logger

            from sandbox.client.platforms import (
                PortoPlatform,
                LXCPlatform,
            )
            if self.task_id:
                if isinstance(self.platform, PortoPlatform):
                    from sandbox.client.platforms.porto import PortoContainerRegistry

                    if self.platform._base_container_name:
                        self.logger.debug(
                            "Lock container %s for task #%s (job restore)",
                            self.platform._base_container_name, self.task_id
                        )
                        registry = PortoContainerRegistry()
                        registry.lock_base_container(self.platform._base_container_name)

                if isinstance(self.platform, LXCPlatform):
                    # Mark the container instance slot of this template as occupied by the task
                    c = self.platform._container
                    if c.instance is not None:
                        i = c.instance + 1
                        ii = LXCPlatform.instances.setdefault(c.template, [0] * i)
                        if len(ii) < i:
                            ii.extend([0] * (i - len(ii)))
                        ii[c.instance] = self.task_id

    def spawn(self):
        """
        Start the executor subprocess through the platform and store its handle in ``self.liner``.

        Also configures the ramdrive (if requested), sets ``expiration_time`` from ``kill_timeout``,
        sets up the AgentR session and starts background synchronization of the tasks code resource.
        """
        ramdrive = self.args.get("ramdrive")
        if ramdrive:
            self.platform.ramdrive = ramdrive["type"], ramdrive["size"]
            ramdrive["path"] = self.platform.ramdrive

        if self.kill_timeout is not None:
            self.args["expiration_time"] = time.time() + self.kill_timeout

        self.logger.debug("Waiting for task session initialization in AgentR")
        self.platform.setup_agentr(self.setup_agentr())

        self.liner, executor_args = self.platform.spawn({
            "tasks_dir": (
                os.path.realpath(common_config.Registry().client.tasks.code_dir)
                if system.local_mode() else
                os.path.join(system.SERVICE_USER.home, "tasks")
            ),
            "command": self.command,
            "args": self.args,
            "session_token": self.token,
            "master_host": common_config.Registry().this.id,
            "ssh_private_key": self.ssh_private_key,
            "common_arc_token": self.common_arc_token,
            "docker_registry_token": self.docker_registry_token,
            "executable": self.platform.executable,
            "executor_path": self.platform.executor_path
        })
        log_executor_args = copy.deepcopy(executor_args)

        # Hide sensitive data
        log_executor_args["session_token"] = common_format.obfuscate_token(executor_args["session_token"])
        log_executor_args["ssh_private_key"] = "***"
        log_executor_args["common_arc_token"] = "***"
        log_executor_args["docker_registry_token"] = "***"
        try:
            service_auth = log_executor_args.get("args", {})["service_auth"]
            log_executor_args["args"]["service_auth"] = common_format.obfuscate_token(service_auth)
        except KeyError:
            pass

        self.logger.debug(
            "Subprocess [%s :: %s] [%r] spawned with arguments %r.",
            type(self).__name__, self.platform.name, self.liner, log_executor_args
        )
        if self.tasks_rid:
            # Early start of tasks code resource synchronization
            common_threading.daemon(self.agentr.resource_sync, self.tasks_rid, fastbone=False)

    @property
    def result_filename(self):
        """ Path of the file the executor writes its result to """
        return "{}.{}".format(
            os.path.join(common_config.Registry().client.executor.dirs.run, EXECUTOR_RESULT_PREFIX),
            common_format.obfuscate_token(self.token)
        )

    @property
    def executor_args(self):
        """ Path of the file used to pass arguments to the executor """
        return "{}.{}".format(
            os.path.join(common_config.Registry().client.executor.dirs.run, EXECUTOR_ARGS_PREFIX),
            common_format.obfuscate_token(self.token)
        )

    @executor_args.setter
    def executor_args(self, args):
        """
        Write pickled ``args`` into the executor arguments file, replacing a stale one.

        ``args`` is updated in place with ``result_filename`` (converted by ``platform.native_path``).
        """
        args["result_filename"] = self.platform.native_path(self.result_filename)
        executor_args_path = self.executor_args
        if os.path.exists(executor_args_path):
            with system.UserPrivileges():
                self.logger.error("Drop previously created %r", self.executor_args)
                os.unlink(executor_args_path)
        with system.UserPrivileges.lock, open(executor_args_path, "wb") as _:
            _.write(cPickle.dumps(args))
        self.logger.debug("Created file %r", executor_args_path)

    @property
    def executor_result(self):
        """
        Executor result read from ``result_filename``.

        :return: decoded JSON contents, or ``{"status": <raw contents>}`` if the file is not valid JSON
        """
        filename = self.result_filename
        with open(filename) as _:
            contents = _.read()
            try:
                return json.loads(contents)
            except ValueError:
                # compatibility check: old executor result was a status string
                return {"status": contents}

    @executor_result.deleter
    def executor_result(self):
        """ Remove the executor result file, if any """
        filename = self.result_filename
        if os.path.exists(filename):
            os.unlink(filename)
            self.logger.debug("Removed file %r", filename)

    @property
    def session_state(self):
        """ Task session state corresponding to the current command status """
        if self.status == ctt.Status.SUSPENDED:
            return ctt.SessionState.SUSPENDED
        if self.status is errors.SessionExpired:
            return ctt.SessionState.EXPIRED
        if self.status == ctt.Status.STOPPED:
            return ctt.SessionState.ABORTED
        return ctt.SessionState.ACTIVE

    def poll(self):
        """
        Wait for task completion

        When the executor subprocess has finished, reads its status from the result file. If the result
        cannot be read or is invalid, the executor output is logged and the status becomes
        ``errors.InfraError`` (the host rebooted since the job started) or ``errors.ExecutorFailed``.

        :return: task status, or None while the job is suspended or still running
        """
        job_expired = self.status == errors.SessionExpired
        # Canceled (STOPPED) job has to finish on its own
        job_finished = self.status and self.command != ctc.Command.SUSPEND and self.status != ctt.Status.STOPPED

        if job_expired or job_finished:
            return self.status

        if (
            self.command == ctc.Command.SUSPEND and self.status != ctt.Status.STOPPED or
            self.liner and not self.liner.poll
        ):
            if self.command == ctc.Command.STOP and not dt.datetime.now().second % 10:
                self.logger.debug("There was a command to stop, but still no progress")
            return  # job is either suspended or still running

        liner, self.liner = self.liner, None
        self.logger.debug(
            "Subprocess %r finished with status %r%s.",
            liner,
            liner and liner.poll,
            " after graceful stop" if self.status == ctt.Status.STOPPED else "",
        )
        if self.status == ctt.Status.STOPPED:
            if liner:
                # liner is still running after graceful kill of executor
                liner.terminate()
            return self.status
        try:
            self.executor_ctx = self.executor_result
            self.status = self.executor_ctx.get("status")
            self.status_message = self.executor_ctx.get("status_message")
            self.logger.info("Executor reported status %r.", self.status)
            if self.command == ctc.Command.EXECUTE and self.agentr:
                self.agentr.monitoring_finish()
            if not self.status:
                self.status = errors.ExecutorFailed
            elif self.status == "SessionExpired":
                self.status = errors.SessionExpired
            elif self.status == "ContainerError":
                self.status = errors.ContainerError
            elif self.status not in ctt.Status:
                raise Exception("Wrong status received from executor: {!r}".format(self.status))
        except Exception:
            self.logger.exception("Error reading executor status from %s", self.result_filename)
            logdir = self.agentr.logdir if self.agentr else None

            def _read_std_file_content(filename):
                """ Contents of ``filename`` in the AgentR log directory, or an empty string """
                if logdir is not None:
                    path = os.path.join(logdir, filename)
                    if os.path.exists(path):
                        # Close the file right away instead of leaking the descriptor
                        with open(path, "r") as f:
                            return f.read()
                return ""

            if liner is not None:
                stdout = liner.stdout or _read_std_file_content(ar_types.STDOUT_FILENAME)
                stderr = liner.stderr or _read_std_file_content(ar_types.STDERR_FILENAME)
            else:
                self.logger.error("executor output is lost")
                stdout = stderr = ""
            if stdout or stderr:
                self.logger.error("executor stdout: %s", stdout)
                self.logger.error("executor stderr: %s", stderr)

            # Get last line of the traceback as a status message
            traceback = stderr.strip().splitlines() or ["<output is empty>"]
            self.status_message = "Executor failed: " + traceback[-1]

            # The job started before the last boot, i.e. the host has been rebooted meanwhile
            if self.timestamp_start < psutil.boot_time():
                self.status = errors.InfraError
            else:
                self.status = errors.ExecutorFailed

        return self.status

    def cancel(self, status=ctt.Status.STOPPED, reason=None):
        """
        Cancel execution of command

        For ``ctt.Status.STOPPED`` the executor process is asked to stop gracefully with SIGTERM when
        its PID is known; otherwise the subprocess is terminated. Errors while killing are logged, not raised.

        :param status: switch task to this status unless timeout has occurred
        :param reason: optional message stored as ``status_message``
        """
        self.logger.info("Task cancelling requested")
        with self._status_lock:
            self.status = ctt.Status.TIMEOUT if self.killed_by_timeout else status
            if reason:
                self.status_message = reason

        if self.status == ctt.Status.STOPPED and self.command == ctc.Command.STOP:
            self.logger.debug("Task cancelling requested but it's already stopping")
            return

        if self.command == ctc.Command.EXECUTE and self.agentr:
            self.agentr.monitoring_finish()
        self.logger.info("Terminating executor process %r", self.liner)
        try:
            if self.liner:
                with system.UserPrivileges():
                    self.platform.resume()
                terminate_liner = True
                if self.status == ctt.Status.STOPPED:
                    # Without an AgentR session the executor PID is unknown: fall back to terminating the liner
                    executor_pid = self.agentr.executor_pid if self.agentr is not None else None
                    if executor_pid:
                        terminate_liner = False
                        self.logger.info("Stopping process #%s gracefully", executor_pid)
                        with system.UserPrivileges():
                            self.platform.sigterm(executor_pid)
                if terminate_liner:
                    self.liner.terminate()
                    self.liner = None
                self.platform.cancel()
        except Exception:
            self.logger.exception("Cannot kill process %r", self.liner)

    def on_terminate(self):
        """ Called on task terminate """

    def get_status(self):
        """
        Get command status

        The ``killed_by_timeout`` flag is reset after being reported.

        :return: status
        :rtype: dict
        """
        try:
            return {
                "command_id": self.id,
                "command": self.command,
                "status": self.status,
                "killed_by_timeout": self.killed_by_timeout
            }
        finally:
            self.killed_by_timeout = False

    def check_timeout(self):
        """
        Enforce the execution timeouts of the command.

        Cancels the command when it runs longer than ``kill_timeout`` (once extended by the terminate
        timeout for a terminating subsequent task), or when the task is still ASSIGNED after
        ``client.assigned_task_timeout``. After cancelling, the check is replaced by a no-op.

        :return: True if cancelled because of the ASSIGNED timeout, None when no check applies,
            otherwise ``self.killed_by_timeout``
        """
        if self.kill_timeout is None:
            return
        current_time = time.time()
        self.logger.info(
            "Check task #%s with execution time: %d, timeout: %d",
            self.task_id,
            current_time - self.timestamp_start,
            self.kill_timeout
        )
        if current_time - self.timestamp_start > self.kill_timeout:
            # check if we're in single-executor mode and task is already terminating
            if (
                not self.timeout_extension and self.args.get("subsequent") and
                not self.agentr.monitoring_status()
            ):
                delay = common_config.Registry().common.task.execution.terminate_timeout
                self.logger.info("Task #%s is terminating, delaying timeout by %s", self.task_id, delay)
                self.kill_timeout += delay
                self.timeout_extension = True
                return

            self.killed_by_timeout = True
            self.cancel()
            self.logger.warning("Task #%s was killed by timeout %d", self.task_id, self.kill_timeout)
            self.check_timeout = lambda *_: None
        elif (
            not self.assigned_check and
            current_time - self.timestamp_start > common_config.Registry().client.assigned_task_timeout
        ):
            self.assigned_check = True
            if self.task_id:
                try:
                    rest = common_rest.Client(auth=self.token, component=ctm.Component.CLIENT)
                    if rest.task.current.read().get("status") == ctt.Status.ASSIGNED:
                        self.cancel(status=errors.InfraError, reason="Task expired in ASSIGNED status")
                        self.reject_type = ctc.RejectionReason.ASSIGNED_TIMEOUT
                        self.logger.warning(
                            "Task #%s was killed by timeout of ASSIGNED stage %d",
                            self.task_id, common_config.Registry().client.assigned_task_timeout
                        )
                        self.check_timeout = lambda *_: None
                        return True
                except Exception:
                    # Best effort: the check is not retried (assigned_check is already set)
                    self.logger.warning("Can't get information about task %s", self.task_id)

        return self.killed_by_timeout

    @property
    def shell_command(self):
        """ Platform shell command; available only while the command is suspended """
        return self.platform.shell_command if self.platform and self.command == ctc.Command.SUSPEND else None

    @property
    def ps_command(self):
        return self.platform.ps_command if self.platform else None

    @property
    def attach_command(self):
        return self.platform.attach_command if self.platform else None

    def reset_shell_command(self):
        """ Refresh shell/attach commands in the AgentR file server metadata """
        self.agentr.fileserver_meta = self.agentr.fileserver_meta._replace(
            shell_command=self.shell_command,
            attach_command=self.attach_command
        )

    def suspend(self):
        """ Suspend the platform, remember the current command and report SUSPENDED to the server """
        with system.UserPrivileges():
            self.platform.suspend()
        self.suspended_command, self.command = self.command, str(ctc.Command.SUSPEND)
        self.reset_shell_command()
        common_rest.Client(auth=self.token, component=ctm.Component.CLIENT).task.current.audit({
            "status": ctt.Status.SUSPENDED,
            "message": "Task has been suspended",
        })
        self.status = ctt.Status.SUSPENDED

    def resume(self):
        """ Resume the platform, restore the suspended command and report EXECUTING to the server """
        with system.UserPrivileges():
            self.platform.resume()
        self.suspended_command, self.command = None, self.suspended_command
        self.reset_shell_command()
        common_rest.Client(auth=self.token, component=ctm.Component.CLIENT).task.current.audit({
            "status": ctt.Status.EXECUTING,
            "message": "Task has been resumed",
        })
        self.status = None

    def update(self, options):
        """ Hook for ``ctt.SessionState.UPDATED``; does nothing in the base class """
        pass

    def check_session_state(self, state, options):
        """
        Bring the command in line with the task session ``state`` reported for it.

        :param state: ``ctt.SessionState`` value
        :param options: passed to ``update`` on ``ctt.SessionState.UPDATED``
        """
        with self._status_lock:
            if state == ctt.SessionState.UPDATED:
                self.update(options)
            if state == ctt.SessionState.SUSPENDED and self.status != ctt.Status.SUSPENDED:
                self.suspend()
            if state == ctt.SessionState.ACTIVE and self.status == ctt.Status.SUSPENDED:
                self.resume()
            if state == ctt.SessionState.EXPIRED and self.status is not errors.SessionExpired:
                self.cancel(errors.SessionExpired)
            if state == ctt.SessionState.ABORTED and self.status != ctt.Status.STOPPED:
                self.cancel(ctt.Status.STOPPED)

    def save_state(self):
        """ Store the encoded command as AgentR session state; ignored if the session is gone """
        if self.agentr:
            try:
                self.agentr.state = self.encode()
            except ar_errors.NoTaskSession:
                pass

    @property
    def reserved_space(self):
        """ Disk space to reserve for the task (``MINIMAL_SPACE_RESERVE`` unless given in arguments) """
        return self.args.get("reserved_space", self.MINIMAL_SPACE_RESERVE)

    def audit(self, status, message, force=True, admin=True, expected_status=None):
        """
        Post an audit record (status change) for the task.

        :param admin: use the pinger's REST client instead of one authorized with the session token
        """
        if admin:
            from sandbox.client.pinger import PingSandboxServerThread
            rest = PingSandboxServerThread().rest
        else:
            rest = common_rest.Client(auth=self.token, component=ctm.Component.CLIENT)
        rest.task[self.task_id].audit({
            "status": status, "force": force, "message": message, "expected_status": expected_status
        })

    def _new_agentr_session(self):
        """ Create a new local AgentR session bound to the command's token """
        return sandbox.agentr.client.Session(
            self.token, self.iteration + 1, self.reserved_space,
            self.logger, local=True
        )

    def setup_agentr(self):
        """
        Return the AgentR session, creating and checking it on first call.

        :raises ar_errors.UserDismissed: task owner is dismissed; the task is switched to DELETED
        """
        if self.agentr is not None:
            return self.agentr
        self.agentr = self._new_agentr_session()
        try:
            self.agentr(None)  # Ensure connection associated with the task session
            return self.agentr
        except ar_errors.UserDismissed as ex:
            self.logger.error("Dismissed user detected: %s", ex)
            # Use admin API since session token is invalid
            self.status = errors.SessionExpired
            self.audit(ctt.Status.DELETED, "Task is owned by dismissed user")
            raise

    def start(self):
        """ Register the command in ``Command.registry`` by its token and start the thread """
        Command.registry[self.token] = self
        return super(Command, self).start()

    def run(self):
        """
        Thread body: run the job until it finishes.

        Spawns the executor unless the command was restored with a subprocess or a status, publishes file
        server metadata, then polls the job and checks timeouts every tick. Exceptions are turned into
        a failure status. Posts ``Event.JOB_COMPLETED`` to the pinger, except when the pinger is stopping.
        """
        tick = common_config.Registry().client.min_tick
        from sandbox.client.pinger import PingSandboxServerThread
        pt = PingSandboxServerThread()
        stopping, events = pt.stopped, pt._events
        initial = self.liner is None and not self.status

        last_tm_check, last_status = time.time(), None
        self.logger.info(
            "Job %r thread %s.",
            common_format.obfuscate_token(self.token), ("initially started" if initial else "restarted")
        )
        try:
            if initial:
                # Early start task session registration
                common_threading.daemon(self._new_agentr_session)
                self.save_state()
                self.spawn()
                if not self.timestamp_start:
                    self.timestamp_start = time.time()
                self.save_state()

            if self.task_id:
                with system.UserPrivileges():
                    cgroup = self.platform.cgroup and self.platform.cgroup.name
                fs_meta = self.agentr.fileserver_meta
                self.agentr.fileserver_meta = ar_types.FileServerMeta(
                    pid=fs_meta and fs_meta.pid,
                    attach_command=self.attach_command,
                    shell_command=self.shell_command,  # updated on resume/suspend later
                    ps_command=self.ps_command,
                    cgroup=cgroup,
                )

            while True:
                with self._status_lock:
                    if self.poll():
                        break
                stopping.wait(tick)
                now = time.time()
                with self._status_lock:
                    status_changed = last_status != self.status
                    # NOTE(review): with kill_timeout None this relies on Python 2 ordering: min(x, None) is
                    # None and ``y > None`` is true, so the (then no-op) check runs on every tick
                    time_to_check = (
                        now - last_tm_check > min(common_config.Registry().client.idle_time, self.kill_timeout)
                    )

                if status_changed or time_to_check:
                    if self.check_timeout():
                        break
                    last_tm_check = now
                if stopping.is_set():
                    return
                last_status = self.status

        except BaseException as ex:
            self.logger.exception("Error on job %r execution", common_format.obfuscate_token(self.token))
            if isinstance(ex, OSError) and ex.errno == errno.ENOSPC:
                self.logger.error("No space left on the root partition. Shutting down.")
                from sandbox.client.commands import ShutdownClientCommand
                ShutdownClientCommand.emergency_shutdown()
            # TODO: review after SANDBOX-5901
            elif isinstance(ex, errors.CriticalInfraError):
                self.logger.error("Critical error occurred: %s. Rebooting.", ex)
                self.status = errors.CriticalInfraError
                self.send_to_current_client(ctc.ReloadCommand.REBOOT, comment=str(ex))
            elif isinstance(ex, errors.InfraError):
                self.status = errors.InfraError
            elif isinstance(ex, ar_errors.NoTaskSession):
                self.status = self.status or errors.SessionExpired
            elif isinstance(ex, errors.InvalidJob):
                self.status = errors.InvalidJob
                self.audit(ctt.Status.EXCEPTION, str(ex))
            self.status = self.status or errors.ExecutorFailed
            self.status_message = self.status_message or str(ex)
        self.logger.info("Job %r thread finished.", common_format.obfuscate_token(self.token))

        from sandbox.client.pinger import Event
        events.put((Event.JOB_COMPLETED, self))

    @staticmethod
    def send_to_current_client(cmd, comment=""):
        """
        Send a reload command to the current client

        Authorizes with the OAuth token read from ``client.auth.oauth_token`` if that path is configured.

        :type cmd: `ctc.ReloadCommand`
        :type comment: str
        """

        current_client_id = common_config.Registry().this.id
        oauth_token_path = common_config.Registry().client.auth.oauth_token
        token = common_fs.read_settings_value_from_file(oauth_token_path) if oauth_token_path else None

        common_rest.Client(auth=token, component=ctm.Component.CLIENT).batch.clients[str(cmd)].update({
            "comment": comment,
            "id": [current_client_id]
        })
