from __future__ import absolute_import

import os
import signal
import logging
import tempfile

import six

from sandbox.common import patterns
from sandbox.common import errors as common_errors

from sandbox.sdk2 import legacy, paths
from sandbox.sdk2 import task as sdk2_task
from sandbox.sdk2.helpers import subprocess as sp


# Public API of this module (what ``from ... import *`` exposes).
__all__ = (
    "SshAgentNotAvailable",
    "SshAgent",
    "Key",
)

# Module-level logger, named after the module itself.
logger = logging.getLogger(__name__)


class SshAgentNotAvailable(Exception):
    """Raised when the ``ssh-agent`` binary cannot be located, so no agent can be started."""


class SshAgent(six.with_metaclass(patterns.SingletonMeta, object)):
    """
    Wrapper around a private ``ssh-agent`` process (a singleton via ``patterns.SingletonMeta``).

    On construction ``ssh-agent`` is started and the ``NAME=value`` pairs it prints are exported
    into ``os.environ``, so that child processes talk to this agent. Keys are reference-counted by
    their public part, so nested ``add``/``remove`` pairs for the same key are safe.
    """

    # Environment variables that start() may overwrite and kill() puts back.
    _MANAGED_ENV = ("SSH_AUTH_SOCK", "SSH_AGENT_PID", "SSH_OPTIONS")

    def __init__(self):
        self._env = {}  # variables parsed from ssh-agent's output and exported to os.environ
        self._env_backup = {}  # values of _MANAGED_ENV before start(); None means "was unset"
        self._keys = {}  # public key (bytes from ssh-keygen -y) -> number of outstanding add() calls
        self.start()

    @property
    def pid(self):
        """PID of the ssh-agent process, as reported in its own output."""
        return int(self._env["SSH_AGENT_PID"])

    def start(self):
        """
        Start ssh-agent and export its environment into ``os.environ``.

        Also appends ``UserKnownHostsFile=/dev/null,StrictHostKeyChecking=no`` to ``SSH_OPTIONS``,
        keeping any options that were already set.

        :raises SshAgentNotAvailable: if the ``ssh-agent`` binary cannot be found.
        :raises TaskError: if ssh-agent exits with a non-zero code.
        """
        if paths.which("ssh-agent") is None:
            raise SshAgentNotAvailable("ssh-agent binary is not available")

        # Remember the original values so kill() can restore the environment exactly,
        # including removing variables that were not set before.
        for name in self._MANAGED_ENV:
            self._env_backup[name] = os.environ.get(name)

        # Expected line shape: NAME=value; ... -- lines without "=" are ignored.
        for line in self._run(["ssh-agent"]).splitlines():
            name, sep, value = line.partition(b"=")
            if not sep:
                continue
            name = six.ensure_str(name)
            value = six.ensure_str(value.split(b";", 1)[0])
            self._env[name] = value
            os.environ[name] = value

        # The separator goes *after* pre-existing options: "<existing>,UserKnownHostsFile=...".
        existing_options = os.environ.get("SSH_OPTIONS")
        os.environ["SSH_OPTIONS"] = "{}UserKnownHostsFile=/dev/null,StrictHostKeyChecking=no".format(
            existing_options + "," if existing_options else ""
        )

    def add(self, key):
        """
        Add private key to ssh-agent.
        If key is already present, increment ref counter.

        :param key: Private key to be added to ssh-agent (str or bytes).
        :return: Public part of the key (bytes, as printed by ``ssh-keygen -y``).
        :raises TaskError: if the key cannot be parsed or added.
        """
        key_pub = self._key_pub(key)

        if key_pub in self._keys:
            self._keys[key_pub] += 1
        else:
            self._run(["ssh-add", "-"], stdin=six.ensure_binary(key))
            self._keys[key_pub] = 1

        return key_pub

    def remove(self, key_pub):
        """
        Remove corresponding private key from ssh-agent.
        If key was added more than once, decrement ref counter.

        :param key_pub: Public key returned by ``add``, used to identify the private key.
        :raises TaskError: if the key was not added through this agent, or ``ssh-add -d`` fails.
        """
        if key_pub not in self._keys:
            raise common_errors.TaskError("Private key not found, public part: {}".format(self._text(key_pub)))

        if self._keys[key_pub] > 1:
            self._keys[key_pub] -= 1
        else:
            # ssh-add -d takes a file name, so the public key goes through a temporary file.
            with tempfile.NamedTemporaryFile() as f:
                f.write(six.ensure_binary(key_pub))
                f.flush()
                self._run(["ssh-add", "-d", f.name])
            self._keys.pop(key_pub)

    def print_keys(self):
        """Log the keys currently loaded into the agent (output of ``ssh-add -l``)."""
        keys = self._run(["ssh-add", "-l"]).splitlines()
        if not keys:
            logger.info("ssh-agent (pid %d) is empty", self.pid)
            return
        logger.info("ssh-agent keys:")
        for key in keys:
            logger.info("%s", self._text(key))

    def kill(self):
        """
        Terminate the ssh-agent process and restore the environment saved by ``start``.

        Variables that were unset before ``start`` are removed again instead of being left
        pointing at a dead agent. Key reference counts are dropped since the keys die with the agent.
        """
        pid = self.pid

        for name in self._env:
            os.environ.pop(name, None)

        for name, value in self._env_backup.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

        self._keys.clear()
        os.kill(pid, signal.SIGTERM)

    def _key_pub(self, key):
        """Return the public part (bytes) of the private ``key`` via ``ssh-keygen -y``."""
        with tempfile.NamedTemporaryFile() as f:
            f.write(six.ensure_binary(key))
            f.flush()
            return self._run(["ssh-keygen", "-y", "-f", f.name])

    @staticmethod
    def _text(data):
        """Return ``data`` as a native, stripped string for logs and error messages."""
        if isinstance(data, (six.binary_type, six.text_type)):
            return six.ensure_str(data, errors="replace").strip()
        return data

    @staticmethod
    def _run(cmd, stdin=None):
        """
        Run ``cmd`` and return its stdout as bytes.

        :param cmd: argv list.
        :param stdin: optional bytes to feed to the process; when None, no pipe is attached.
        :raises TaskError: on a non-zero exit code, with stderr and stdout in the message.
        """
        # "is not None": an empty payload must still get a (closed) pipe rather than
        # letting the child inherit -- and possibly block on -- our own stdin.
        p = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE, stdin=sp.PIPE if stdin is not None else None)
        stdout, stderr = p.communicate(stdin)

        # Listing keys from empty ssh-agent results in exit code 1.
        # The output is bytes, so the marker must be a bytes literal too.
        if stdout.strip() == b"The agent has no identities.":
            return b""

        if p.returncode:
            message = b"\n".join(part for part in (stderr.strip(), stdout.strip()) if part)
            raise common_errors.TaskError(SshAgent._text(message))

        return stdout


class Key(object):
    """
    Context manager that loads a private ssh key into the shared ssh-agent for the
    duration of a ``with`` block.
    Available initialisation with either sandbox vault, or a private ssh key by itself.

    Usage 1 (data from yav as example)::

        with ssh.Key(private_part=secret.data()["ssh-key"]):
            do_secure_operations()

    Usage 2 (from sandbox vault)::

        key = ssh.Key(key_owner="<vault_item_owner>", key_name="<vault_item_name>")
        with key:
            do_secure_operations()

    ``__enter__`` returns the ``Key`` instance itself, so ``with ssh.Key(...) as key:`` also works.

    :param task: task object. Should be omitted on initialisation. By default, it's taken from the current context
    :param key_owner: string sandbox vault key owner, if key_name not specified this value
        is being considered as the key name
    :param key_name: string sandbox vault key name
    :param private_part: string with private ssh key
    :raises ValueError: if neither, or both, of (key_owner/key_name) and private_part are given
    """

    def __init__(self, task=None, key_owner=None, key_name=None, private_part=None):
        if private_part and not key_owner and not key_name:
            self.key = private_part
        elif (key_owner or key_name) and not private_part:
            if task is None:
                task = sdk2_task.Task.current

            # Tasks without get_vault_data fall back to the legacy task API.
            self.key = (
                task.get_vault_data(key_owner, key_name)
                if hasattr(task, "get_vault_data") else
                legacy.current_task.get_vault_data(key_owner, key_name)
            )
        else:
            raise ValueError(
                "Key must be initialised with either key_owner and (optional) key_name, "
                "or explicitly with the private_part"
            )

        self._key_pub = None  # public part returned by SshAgent.add while the key is loaded
        self._ssh_agent = SshAgent()

    def __enter__(self):
        self._key_pub = self._ssh_agent.add(self.key)
        try:
            self._ssh_agent.print_keys()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the add() here
            # instead of leaking the key (and its ref count) in the agent.
            self._release()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._release()
        self._ssh_agent.print_keys()

    def _release(self):
        """Remove this key from the agent (decrementing its ref count) and forget its public part."""
        key_pub, self._key_pub = self._key_pub, None
        if key_pub is not None:
            self._ssh_agent.remove(key_pub)
