import socket
from contextlib import contextmanager
from io import StringIO

import paramiko
from gevent.pool import Group

from sepelib.core.exceptions import LogicalError
from walle.credentials import get_ssh_credentials, Credentials
from walle.errors import RecoverableError

# Timeouts below are in seconds.
# Empirical data:
# * expected maximum connection time is about five seconds,
# * expected maximum command execution time is about one second.
SSH_CONNECTION_TIMEOUT = 10  # used as the `timeout` connection parameter
SSH_COMMAND_TIMEOUT = 10  # default time limit for a single remote command


class SshError(RecoverableError):
    """Base class for the SSH errors defined in this module."""

    pass


class SshConnectionFailedError(SshError):
    """SSH connection could not be established, or an SSH call failed for a non-timeout reason.

    `message` is embedded into a fixed "SSH connection failed" text;
    any extra positional arguments are forwarded to the base error as is.
    """

    def __init__(self, message, *args):
        text = "SSH connection failed: {}.".format(message)
        super().__init__(text, *args)


class SshAuthenticationError(SshError):
    """SSH credentials are missing or the authentication attempt failed.

    `message` is embedded into a fixed "SSH connection failed" text.
    """

    def __init__(self, message):
        text = "SSH connection failed: {}".format(message)
        super().__init__(text)


class SshCommandFailedError(SshError):
    """Remote command finished with a non-zero exit code.

    `code` is the exit status, `error_lines` is an iterable of stderr chunks
    which gets concatenated into a single string.
    Both are handed to the base error as arguments next to the message template.
    """

    def __init__(self, code, error_lines):
        stderr_text = "".join(error_lines)
        template = "SSH command failed with code {}: {}"
        super().__init__(template, code, stderr_text)


class SshCommandError(SshError):
    """Remote command produced an unusable result.

    `message` is embedded into a fixed "SSH command failed" text,
    `output` is an iterable of output chunks which gets concatenated into a single string
    and handed to the base error as an extra argument.
    NOTE(review): presumably the base error formats that argument into a placeholder
    left in `message` -- confirm against RecoverableError.
    """

    def __init__(self, message, output):
        joined_output = "".join(output)
        text = "SSH command failed: {}.".format(message)
        super().__init__(text, joined_output)


class SshCommandTimeoutError(SshError):
    """Remote command did not complete in time.

    `message` is embedded into a fixed "SSH command failed" text.
    """

    def __init__(self, message):
        text = "SSH command failed: {}.".format(message)
        super().__init__(text)


class SshClient:
    """Wrapper around `paramiko.SSHClient` for running a handful of maintenance commands on a host.

    Can be used as a context manager: the connection is opened on enter and the underlying
    paramiko client is closed on exit. Errors coming from paramiko calls are converted
    into `SshError` subclasses.
    """

    def __init__(self, hostname):
        """Remember the target host and create a paramiko client. No connection is made here."""
        self.hostname = hostname
        self._client = paramiko.SSHClient()

    def connect(self):
        """Open SSH connection to `self.hostname`.

        Raise SshAuthenticationError if credentials are missing or paramiko reports an authentication failure.
        Raise SshConnectionFailedError on any other error during connection.
        """
        client = self._client
        hostname = self.hostname

        # I'm sorry to say that, but we can eventually redeploy server and this must not be a problem.
        # Known host keys are dropped and unknown ones are accepted automatically,
        # i.e. the host key is not verified.
        client.get_host_keys().clear()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            client.connect(**self._get_connection_params(hostname))
        except SshAuthenticationError:
            # Already the proper error (missing credentials), don't let the generic handler rewrap it.
            raise
        except paramiko.AuthenticationException as e:
            raise SshAuthenticationError(e) from e
        except Exception as e:
            raise SshConnectionFailedError(e) from e

    def issue_reboot_command(self):
        """Issue a reboot command on the remote server.

        As any other ACPI commands, this does not actually mean that server will be rebooted.
        You may need to check uptime or perform some other verification."""

        # won't use --force here as it does not close connection properly which causes errors.
        self.execute("sudo -n shutdown -r +1")

    def issue_kexec_reboot_command(self):
        """Issue a server reboot via the kexec helper.

        As any other ACPI commands, this does not actually mean that server will be rebooted.
        You may need to check uptime or perform some other verification."""

        # The helper is started in background with all its standard streams redirected to /dev/null.
        self.execute("sudo -n /usr/sbin/wall-e.kexec-helper --reboot </dev/null &>/dev/null &")

    def issue_poweroff_command(self):
        """Issue a poweroff command on the remote server.

        As any other ACPI commands, this does not actually mean that server will be powered off.
        You may need to perform some kind of verification to ensure its actual state."""

        # won't use --force here as it does not close connection properly which causes errors.
        self.execute("sudo -n shutdown -h +1")

    def get_boot_id(self):
        """Get boot id of a host. Return the first line of the kernel's boot_id file, stripped.

        This feature first appears in linux kernel 2.3.16, hope we don't have older kernels.

        Raise SshCommandError if the command printed nothing or the first line is blank.
        """
        result = self.execute("cat /proc/sys/kernel/random/boot_id")

        # Validate the *stripped* value: a whitespace-only line must not be returned as an empty boot id.
        boot_id = result[0].strip() if result else ""
        if not boot_id:
            raise SshCommandError("Got invalid boot_id from host: {}", result)

        return boot_id

    def execute(self, command):
        """
        Execute given command on the remote host. Return list of stdout lines.
        Raise SshCommandFailedError if the command exited with a non-zero status.
        """
        out, err, code = self._execute_command(command)
        if code != 0:
            # The error joins stderr lines itself, no need to pre-join them here.
            raise SshCommandFailedError(code, err)

        return out

    def _get_connection_params(self, host):
        """Return keyword arguments for `paramiko.SSHClient.connect`: host, timeout and credentials."""
        connect = {"hostname": host, "timeout": SSH_CONNECTION_TIMEOUT}
        connect.update(self._get_credentials())

        return connect

    def _get_credentials(self):
        """Return `username` and `pkey` connection parameters built from the stored ssh credentials.

        Raise SshAuthenticationError if the credentials do not exist.
        """
        try:
            crd = get_ssh_credentials()
        except Credentials.DoesNotExist:
            raise SshAuthenticationError("Can't find credentials for ssh connection.")

        # The public part is used as the user name, the private part is parsed as an RSA private key.
        credentials = {"username": crd.public, "pkey": paramiko.RSAKey(file_obj=StringIO(crd.private))}
        return credentials

    def _execute_command(self, command, bufsize=-1, timeout=None):
        """
        Execute given command on the remote server.
        Return list of stdout output lines, stderr output lines and exit status as a number.

        `bufsize` is passed to paramiko as is. `timeout` defaults to SSH_COMMAND_TIMEOUT and is used
        both for paramiko's exec_command and as the overall limit for reading output and
        waiting for the exit status.

        Raise SshCommandTimeoutError if the exit status is not available after `timeout`.
        Raise LogicalError if the three streams do not share a single channel.
        Errors raised by the reader greenlets are re-raised from here.
        """

        if timeout is None:
            timeout = SSH_COMMAND_TIMEOUT

        with self.paramiko_errors():
            stdin, stdout, stderr = self._client.exec_command(command, bufsize, timeout, get_pty=False)

        # Everything below works with the channel taken from stdin,
        # so make sure stdout and stderr really belong to the same one.
        if not (stdin.channel is stdout.channel and stdin.channel is stderr.channel):
            raise LogicalError()

        out = []
        err = []

        # We never write anything to the command's stdin.
        stdin.close()
        greenlets = Group()
        try:
            with stdin.channel as chan:
                # Read both streams and wait for the exit status concurrently,
                # all of it limited by a single timeout.
                greenlets.spawn(self._read_stream, out, stdout)
                greenlets.spawn(self._read_stream, err, stderr)
                greenlets.spawn(chan.recv_exit_status)
                greenlets.join(raise_error=True, timeout=timeout)

                if not chan.exit_status_ready():
                    # in case of any other error .join() would raise an exception.
                    raise SshCommandTimeoutError("command execution timeout")
        finally:
            # Previous join might leave greenlets staled.
            # They all finish on channel close, just join them.
            greenlets.join()

        # We get here only if the exit status is ready, so this does not block.
        status = chan.recv_exit_status()
        return out, err, status

    def _read_stream(self, buf, stream):
        """Read all lines of the stream into `buf`.

        Wrap the stream, close it automatically, convert timeout error into proper SshClient exception."""
        with stream:
            with self.paramiko_errors():
                buf.extend(stream)

    @contextmanager
    def paramiko_errors(self):
        """Context manager which converts errors raised inside it into SshError subclasses.

        `socket.timeout` becomes SshCommandTimeoutError,
        any other exception becomes SshConnectionFailedError with the original error text.
        """
        try:
            yield
        except socket.timeout as e:
            raise SshCommandTimeoutError("command execution timeout") from e
        except Exception as e:
            message = str(e)
            raise SshConnectionFailedError(message) from e

    def __enter__(self):
        """Connect to the host and return the client itself."""
        self.connect()
        return self

    # noinspection PyUnusedLocal
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying paramiko client. Exceptions from the `with` body are not suppressed."""
        self._client.close()


def get_client(hostname):
    """Create an SshClient for the given host. The client is returned as constructed, `connect()` is not called."""
    client = SshClient(hostname)
    return client
