from __future__ import print_function
import logging
import StringIO

from sandbox.common.errors import TaskFailure
from sandbox import sdk2

from sandbox.projects.yabs.base_bin_task import BaseBinTask


class InstallDebianPackage(BaseBinTask):
    '''Install a Debian package on a remote host over SSH.

    Connects to ``hostname:port`` as ``username`` with the RSA private key stored
    in the Sandbox vault entry ``ssh_key_vault``, then runs ``sudo apt-get update``
    followed by ``sudo apt-get install -y <package>``. When ``version`` is set the
    exact version is pinned as ``package=version``; otherwise no version is given
    to apt-get.
    '''

    # https://www.debian.org/doc/debian-policy/ch-controlfields.html#s-f-source
    # Lower-case ASCII letters, digits, '+', '-' and '.'; at least two characters
    # and starting with an alphanumeric. The leading-alphanumeric rule also stops
    # a value such as '--reinstall' from being parsed as an apt-get option.
    _PACKAGE_NAME_RE = r'[a-z0-9][a-z0-9+.-]+\Z'
    # https://www.debian.org/doc/debian-policy/ch-controlfields.html#version
    # ASCII alphanumerics plus '.', '+', '-', '~' and ':' (epoch separator).
    _VERSION_RE = r'[A-Za-z0-9.+~:-]+\Z'

    class Parameters(BaseBinTask.Parameters):
        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = sdk2.parameters.Dict('Filter resource by', default={'name': 'InstallDebianPackage'})

        with sdk2.parameters.Group('Installation') as ch_params:
            package = sdk2.parameters.String('Package name', required=True)
            # Empty means "let apt-get pick the version".
            version = sdk2.parameters.String('Package version, latest by default')
            hostname = sdk2.parameters.String('Hostname', required=True)
            # SSH port on the target host.
            port = sdk2.parameters.Integer('Port', required=True, default=22)
            username = sdk2.parameters.String('Username for ssh connection', required=True)
            # Name of the vault entry holding the RSA private key.
            ssh_key_vault = sdk2.parameters.String('Vault with ssh private key', required=True)

    @classmethod
    def _build_package_spec(cls, package, version):
        '''Validate the package name/version and build the apt-get argument.

        :param package: Debian package name.
        :param version: version to pin, or empty/None to leave it unpinned.
        :return: ``package`` or ``package=version``. The result is safe to
            interpolate into the remote shell command because only the characters
            allowed by ``_PACKAGE_NAME_RE`` / ``_VERSION_RE`` are accepted.
        :raises TaskFailure: if the name or version is outside the allowed set.
        '''
        import re

        # ``re.match`` anchors at the start; ``\Z`` in the patterns anchors the
        # true end (``$`` would also accept a trailing newline).
        if not re.match(cls._PACKAGE_NAME_RE, package or ''):
            raise TaskFailure('Invalid package name: {!r}'.format(package))
        if not version:
            return package
        if not re.match(cls._VERSION_RE, version):
            raise TaskFailure('Invalid version: {!r}'.format(version))
        return '{}={}'.format(package, version)

    def ssh_exec(self, ssh_client, command):
        '''Run ``command`` on the remote host and log its output.

        :param ssh_client: a connected ``paramiko.SSHClient``.
        :param command: shell command line executed on the remote host.
        :raises TaskFailure: if the command's exit status is non-zero (paramiko
            reports -1 when the server sends no exit status, which also fails).
        '''
        logging.info('Executing command: %s', command)

        _, stdout, stderr = ssh_client.exec_command(command)

        # Drain both streams *before* waiting for the exit status: paramiko's
        # recv_exit_status() can block forever once the remote output exceeds the
        # channel window and nobody is reading it, and apt-get is very chatty.
        stdout_lines = stdout.readlines()
        stderr_lines = stderr.readlines()
        exit_code = stdout.channel.recv_exit_status()

        logging.info("------------\nExit code: %s", exit_code)
        logging.info("\n\n------------\nSTDOUT:")
        for line in stdout_lines:
            logging.info('%s', line)

        logging.info("\n\n------------\nSTDERR:")
        for line in stderr_lines:
            logging.info('%s', line)

        if exit_code:
            raise TaskFailure('Failed to execute command on remote host, see logs')

    def on_execute(self):
        '''Validate parameters, connect to the host over SSH and install the package.'''
        import paramiko

        # Validate before touching the vault or the network: the spec ends up in
        # a remote shell command, so bad input must be rejected up front.
        package = self._build_package_spec(self.Parameters.package, self.Parameters.version)

        ssh_key = sdk2.Vault.data(self.Parameters.ssh_key_vault)

        with paramiko.SSHClient() as ssh_client:
            # NOTE: unknown host keys are accepted without verification.
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            pkey = paramiko.RSAKey.from_private_key(StringIO.StringIO(ssh_key))
            ssh_client.connect(
                hostname=self.Parameters.hostname,
                port=self.Parameters.port,
                username=self.Parameters.username,
                pkey=pkey,
                timeout=300,
            )

            self.ssh_exec(ssh_client, 'sudo apt-get update')
            self.ssh_exec(ssh_client, 'sudo apt-get install -y {}'.format(package))
