import logging
import socket
import tempfile
import time
from subprocess import Popen, PIPE

import requests


class SshClient(object):
    """Run an ``ssh`` subprocess that forwards a local port to port 8080 on the remote host.

    Usage: ``forward_jangles_port_to(local_port)`` starts the tunnel and
    ``terminate_connect()`` stops it and logs ssh's output.
    """

    def __init__(self, host, port, user, identity_file):
        """
        :param host: remote host to ssh into.
        :param port: remote sshd port.
        :param user: remote login name.
        :param identity_file: path to the private key passed to ``ssh -i``.
        """
        self.host = host
        self.port = port
        self.user = user
        self.identity_file = identity_file
        # Popen handle of the running ssh tunnel, or None when no tunnel is running.
        self.process = None

    def get_command(self, local_port):
        """Build (and log) the ssh argv forwarding ``local_port`` to remote localhost:8080."""
        cmd = [
            'ssh', '-v',
            '-i', self.identity_file,
            '-L', '{}:localhost:8080'.format(local_port),
            '-p', str(self.port),
            # Make ssh exit rather than keep running without the tunnel when
            # the forwarding cannot be set up (e.g. local port already bound).
            '-o', 'ExitOnForwardFailure=yes',
            '{}@{}'.format(self.user, self.host)
        ]

        logging.info('ssh command: %s', ' '.join(cmd))
        return cmd

    @staticmethod
    def is_port_open(host, port, retries=3, timeout=5, delay=5):
        """Return True if a TCP connection to ``(host, port)`` succeeds.

        Makes up to ``retries`` attempts, each with a ``timeout``-second
        connect timeout, sleeping ``delay`` seconds between attempts (but not
        after the last one). Returns False if every attempt fails.
        """
        for attempt in range(retries):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.settimeout(timeout)
            try:
                s.connect((host, port))
                s.shutdown(socket.SHUT_RDWR)
                return True
            except socket.error as exc:
                # Covers refused connections, timeouts and resolution errors.
                logging.info('Port not open: %s', exc)
            finally:
                s.close()
            if attempt + 1 < retries:
                time.sleep(delay)
        return False

    def forward_jangles_port_to(self, local_port):
        """Start ssh forwarding ``local_port`` to remote port 8080 and wait until it is usable.

        :raises Exception: if ssh has already exited, or the local port never
            became connectable. The ssh process is left in ``self.process``
            in both cases; call ``terminate_connect()`` to clean it up.
        """
        assert self.process is None

        # NOTE(review): stdout/stderr are pipes that are only drained in
        # terminate_connect(); with -v a very long-lived or busy tunnel could
        # fill the pipe buffer and stall ssh. Consider temp files if observed.
        # stdin=PIPE presumably keeps the remote session (no -N given) alive
        # until stdin is closed — confirm.
        self.process = Popen(self.get_command(local_port), stdout=PIPE, stderr=PIPE, stdin=PIPE)
        port_open = self.is_port_open('localhost', local_port)

        # The port can answer even though ssh died (e.g. ExitOnForwardFailure
        # because another program already listens there), so check ssh itself.
        exit_code = self.process.poll()
        if exit_code is not None:
            raise Exception('ssh exited with code {} while setting up forwarding'.format(exit_code))
        if not port_open:
            raise Exception('Local port not connectable')

    def terminate_connect(self):
        """Stop the ssh tunnel and log its exit code (and output when non-zero).

        Only logs a warning when no tunnel is running.
        """
        if self.process is None:
            logging.warning('you need to call forward_jangles_port_to first')
            return

        # Detach first so the client is reusable even if collecting output fails.
        process, self.process = self.process, None

        # Signal ssh *before* waiting: communicate() alone blocks for as long
        # as ssh keeps running (e.g. while a forwarded connection is still
        # open). Skip the signal if ssh has already exited on its own.
        if process.poll() is None:
            process.terminate()
        ssh_out, ssh_err = process.communicate()

        # Since ssh is normally stopped by the signal, this is usually non-zero
        # and the captured output gets logged.
        exit_code = process.returncode
        logging.info('ssh exit code: %s', exit_code)
        if exit_code != 0:
            logging.info('ssh out: %s', ssh_out)
            logging.info('ssh err: %s', ssh_err)


class LocalJanglesMixin(object):
    """Mixin that runs ``make_requests`` while jangles is reachable on ``localhost:LOCAL_PORT``.

    Subclasses override ``make_requests``; ``make_local_jangles_requests`` sets
    up an ssh tunnel (via ``SshClient``) around that call.
    """

    # Local end of the ssh tunnel that make_requests should talk to.
    LOCAL_PORT = 8080

    def make_requests(self):
        """Do the actual work against the tunnelled jangles; must be overridden."""
        raise NotImplementedError('need to override make_requests method to do someone useful work')

    @staticmethod
    def get_ssh_key(vault_params):
        """Return ``vault_params.yav_secret.data()[vault_params.yav_secret_key]`` (the private key)."""
        return vault_params.yav_secret.data()[vault_params.yav_secret_key]

    def write_key_to_file(self, filename, vault_params):
        """Write the ssh private key from ``vault_params`` to ``filename``.

        A trailing newline is appended when missing: OpenSSH may refuse to load
        a private key whose last line is not newline-terminated, and secret
        stores often strip it.
        """
        key = self.get_ssh_key(vault_params)
        if not key.endswith('\n'):
            key += '\n'
        # Closing the file (end of the with-block) flushes it.
        with open(filename, 'w') as f:
            f.write(key)

    def make_local_jangles_requests(self, params):
        """Tunnel jangles to ``localhost:LOCAL_PORT`` over ssh and run ``make_requests``.

        :param params: object with ``host``, ``port``, ``user`` for ssh and
            ``yav_secret``/``yav_secret_key`` locating the private key.
        :raises Exception: on a requests connection error or HTTP error status,
            or if the tunnel could not be established.
        """
        # The key lives in a temp file (user-only permissions, as created by
        # tempfile) that is deleted when the with-block exits.
        with tempfile.NamedTemporaryFile(delete=True) as ssh_key_file:
            self.write_key_to_file(ssh_key_file.name, params)
            ssh = SshClient(params.host, params.port, params.user, ssh_key_file.name)

            try:
                ssh.forward_jangles_port_to(self.LOCAL_PORT)
                self.make_requests()
            except requests.exceptions.ConnectionError:
                # todo: use /ping to answer the question
                raise Exception('localhost connection error, jangles is down?')
            except requests.exceptions.HTTPError:
                raise Exception('invalid jangles response code')
            finally:
                # Always tear the tunnel down, also when setup itself failed.
                ssh.terminate_connect()
