import sandbox.sandboxsdk.errors as sdk_errors
import sandbox.sandboxsdk.process as sdk_process
import sandbox.sandboxsdk.channel as sdk_channel
from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp

import os
import sys
import signal
import logging
import tempfile
import subprocess


logger = logging.getLogger(__name__)


class KeyArc(object):
    """
    Context manager that starts a private ssh-agent and loads an ssh key into it.

    On enter ``ssh-agent`` is started, the ``NAME=value`` variables it prints
    (``SSH_AUTH_SOCK``, ``SSH_AGENT_PID``, ...) are exported into ``os.environ``,
    ``SSH_OPTIONS`` is extended to disable known_hosts checking, and the private
    key read from the vault is added to the agent with ``ssh-add``.
    On exit (or if entering fails half-way) the previous environment is restored
    and the agent is terminated with SIGTERM.

    Usage:
        with KeyArc(task_object, "<vault_item_owner>", "<vault_item_name>"):
            ...
    """

    # options appended to SSH_OPTIONS (comma-separated list) while the manager is active
    _EXTRA_SSH_OPTIONS = "UserKnownHostsFile=/dev/null,StrictHostKeyChecking=no"

    def __init__(self, task, key_owner, key_name):
        """
        :param task: task object that requires the key; an ``sdk2.Task`` selects the
            sdk2 code paths, anything else uses the legacy ``sandboxsdk`` helpers
        :param key_owner: owner of the vault item that contains the private key
        :param key_name: name of the vault item that contains the private key
        """
        self.__task = task
        self.__is_sdk2 = isinstance(task, sdk2.Task)
        self.__key = (key_owner, key_name)
        # variables exported from ssh-agent's output: name -> value
        self.__env = {}
        self.__agent_pid = None
        # original values of every variable we touch (None means "was unset")
        self.__env_backup = {}
        # on FreeBSD the key is handed to ssh-add through a temporary file, not stdin
        self.__is_freebsd = sys.platform.startswith("freebsd")

    def __enter__(self):
        stdout = self._run_agent()
        try:
            self._export_agent_env(stdout)
            key = self._fetch_key()
            self._extend_ssh_options()
            self._add_key(key)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the environment
            # changes and stop the agent here before propagating the error.
            self._cleanup()
            raise
        self._log_agent_keys()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()

    def _run_agent(self):
        """Start ssh-agent and return its stdout; raise if it wrote anything to stderr."""
        # "-s" forces Bourne-shell "NAME=value; export NAME;" output regardless of $SHELL,
        # which is the only format _export_agent_env can parse.
        if self.__is_sdk2:
            p = sp.Popen(["ssh-agent", "-s"], stdout=sp.PIPE, stderr=sp.PIPE)
            stdout, stderr = p.communicate()
        else:
            stdout, stderr = sdk_process.run_process(
                ["ssh-agent", "-s"], outs_to_pipe=True, check=False
            ).communicate()

        if stderr:
            raise sdk_errors.SandboxTaskUnknownError(stderr)
        return stdout

    def _export_agent_env(self, stdout):
        """Export the ``NAME=value;`` assignments from ssh-agent's output into os.environ."""
        self.__env_backup["SSH_AUTH_SOCK"] = os.environ.get("SSH_AUTH_SOCK")
        self.__env_backup["SSH_OPTIONS"] = os.environ.get("SSH_OPTIONS")

        for line in stdout.splitlines():
            name, sep, value = line.partition("=")
            if not sep:
                continue
            # keep only the value up to the first ';' (drops "export NAME;")
            value = value.split(";", 1)[0]
            # remember the pre-existing value so _cleanup can restore it
            self.__env_backup.setdefault(name, os.environ.get(name))
            self.__env[name] = value
            os.environ[name] = value
        self.__agent_pid = int(self.__env["SSH_AGENT_PID"])

    def _fetch_key(self):
        """Read the private key from the vault."""
        if self.__is_sdk2:
            # NOTE(review): only the item name is passed on the sdk2 path; key_owner is unused
            return sdk2.Vault.data(self.__key[1])
        if hasattr(self.__task, "get_vault_data"):
            return self.__task.get_vault_data(*self.__key)
        return sdk_channel.channel.task.get_vault_data(*self.__key)

    def _extend_ssh_options(self):
        """Append the host-key-checking overrides to SSH_OPTIONS, comma-separated."""
        current = os.environ.get("SSH_OPTIONS")
        if current:
            os.environ["SSH_OPTIONS"] = "{},{}".format(current, self._EXTRA_SSH_OPTIONS)
        else:
            os.environ["SSH_OPTIONS"] = self._EXTRA_SSH_OPTIONS

    def _add_key(self, key):
        """Load the private key into the running agent with ssh-add."""
        if self.__is_freebsd:
            self._add_key_from_file(key)
        elif self.__is_sdk2:
            # NOTE: the exit code of "ssh-add -" is not checked here; a failure only
            # shows up in the key listing logged by _log_agent_keys.
            with sdk2.helpers.ProcessLog(self.__task, "ssh-add") as process_log:
                p = sp.Popen(["ssh-add", "-"], stdin=sp.PIPE, stdout=process_log.stdout, stderr=process_log.stderr)
                p.communicate(input=key)
        else:
            p = sdk_process.run_process(
                ["ssh-add", "-"],
                stdin=subprocess.PIPE,
                outs_to_pipe=True,
                wait=False
            )
            p.communicate(input=key)

    def _add_key_from_file(self, key):
        """Write the key to a private temporary file, ssh-add it, and remove the file."""
        handle, filename = tempfile.mkstemp()
        try:
            try:
                os.write(handle, key)
            finally:
                os.close(handle)
            if self.__is_sdk2:
                with sdk2.helpers.ProcessLog(self.__task, "ssh-add") as process_log:
                    sp.check_call(["ssh-add", filename], stderr=process_log.stderr, stdout=process_log.stdout)
            else:
                sdk_process.run_process(["ssh-add", filename])
        finally:
            os.unlink(filename)

    def _log_agent_keys(self):
        """Log the keys held by the agent; failures here are logged, never raised."""
        try:
            # exit code is 1 if no keys in agent
            if self.__is_sdk2:
                p = sp.Popen(["ssh-add", "-l"], stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
                stdout, stderr = p.communicate()
            else:
                stdout, stderr = sdk_process.run_process(
                    ["ssh-add", "-l"], outs_to_pipe=True, check=False, wait=False
                ).communicate()

            if stdout:
                logger.info("Keys in ssh-agent: %s", stdout.strip())

            if stderr:
                logger.error("Stderr from ssh-add: %s", stderr)
        except Exception:
            logger.exception("Exception while list keys in ssh-agent")

    def _cleanup(self):
        """Restore the saved environment and terminate the agent; safe to call twice."""
        for name in self.__env:
            os.environ.pop(name, None)

        for name, value in self.__env_backup.items():
            if value is None:
                # the variable did not exist before (e.g. SSH_OPTIONS): remove what we set
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

        self.__env = {}
        self.__env_backup = {}

        pid, self.__agent_pid = self.__agent_pid, None
        if pid is not None:
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError as exc:
                # the agent is already gone; don't mask an exception from the with-body
                logger.warning("Failed to terminate ssh-agent (pid %s): %s", pid, exc)
