import sys
import six
import socket
import traceback
import threading
from code import InteractiveConsole

import struct

# In-memory text buffer class used by DistantInteractiveConsole to capture
# stdout/stderr; six.moves maps this to io.StringIO on Python 3 and
# cStringIO.StringIO on Python 2.
StringIO = six.moves.cStringIO


class PyrasiteIPC(object):
    """Pyrasite Inter-Python Communication.

    This object is used in communicating to or from another Python process.

    It can perform a variety of tasks:

    - Accepting a connection from the :class:`pyrasite.ReversePythonConnection`
      payload via :meth:`PyrasiteIPC.connect()`, which listens on a local port
      and waits for the process to connect back to it. The connection with
      the process is then available via `self.sock`. (This base implementation
      only listens and waits; getting the payload into the process is left to
      the caller or a subclass.)

    - Python code can then be executed in the process using
      :meth:`PyrasiteIPC.cmd`. Both stdout and stderr are returned.

    - Low-level communication with the process, either reliably (via a length
      header) or unreliably (raw data, ideal for use with netcat) with a
      :class:`pyrasite.ReversePythonConnection` payload, via
      :meth:`PyrasiteIPC.send(data)` and :meth:`PyrasiteIPC.recv()`.

    The :class:`PyrasiteIPC` is subclassed by
    :class:`pyrasite.tools.gui.Process` as well as
    :class:`pyrasite.reverse.ReverseConnection`.

    """
    # Allow subclasses to disable this and just send/receive raw data, as
    # opposed to prepending a length header, to ensure reliability. The reason
    # to enable 'unreliable' connections is so we can still use our reverse
    # shell payloads with netcat.
    reliable = True

    # Length prefix used in reliable mode: a 4-byte little-endian unsigned
    # int holding the number of UTF-8 *bytes* in the message that follows.
    _HEADER = struct.Struct('<L')

    def __init__(self):
        super(PyrasiteIPC, self).__init__()
        self.sock = None
        self.server_sock = None
        self.hostname = None
        self.port = None
        self.address = None
        self._title = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, *args, **kwargs):
        self.close()

    def connect(self):
        """
        Set up a communication socket with the process: listen on a local
        port, then block until the injected payload connects back to us.
        """
        self.listen()
        self.wait()

    def listen(self):
        """Listen on a random port.

        Tries each address that ``getaddrinfo`` returns for ``self.hostname``
        (default ``'localhost'``) and ``self.port`` (default: any free port),
        keeping the first socket that binds and listens successfully. On
        success ``self.hostname``/``self.port`` are set to the bound address.

        Raises:
            Exception: if no candidate address could be bound.
        """
        candidates = socket.getaddrinfo(
            self.hostname or 'localhost',
            self.port or None,
            socket.AF_UNSPEC,
            socket.SOCK_STREAM,
            0,
            0
        )
        for af, socktype, proto, _canonname, sa in candidates:
            try:
                self.server_sock = socket.socket(af, socktype, proto)
            except socket.error:
                self.server_sock = None
                continue
            try:
                self.server_sock.bind(sa)
                self.server_sock.listen(1)
            except socket.error:
                # Don't leak the half-configured socket before the next try.
                self.server_sock.close()
                self.server_sock = None
                continue
            break

        if not self.server_sock:
            raise Exception('pyrasite was unable to setup a '
                            'local server socket')
        # getsockname() is a 4-tuple for IPv6; keep only (host, port).
        self.hostname, self.port = self.server_sock.getsockname()[0:2]

    def wait(self):
        """Wait for the injected payload to connect back to us.

        Stores the accepted connection in ``self.sock`` (with a 5 second
        timeout on subsequent socket operations) and the peer address in
        ``self.address``.
        """
        clientsocket, address = self.server_sock.accept()
        self.sock = clientsocket
        self.sock.settimeout(5)
        self.address = address

    def cmd(self, cmd):
        """
        Send a python command to exec in the process and return the output
        """
        self.send(cmd + '\n')
        return self.recv()

    def send(self, data):
        """Send the text ``data`` to the process via ``self.sock``.

        The text is UTF-8 encoded; in reliable mode it is prefixed with its
        encoded byte length.
        """
        payload = data.encode('utf-8')
        header = self._HEADER.pack(len(payload)) if self.reliable else b''
        self.sock.sendall(header + payload)

    def recv(self):
        """Receive one message from the process and return it as text.

        In reliable mode, reads one length-prefixed message and returns it
        decoded, or ``None`` if the connection ended before a complete
        header or body arrived. In unreliable mode, returns whatever a single
        ``recv(4096)`` yields (``''`` once the peer has closed).
        """
        if not self.reliable:
            return self.sock.recv(4096).decode('utf-8')

        header_data = self.recv_bytes(self._HEADER.size)
        if len(header_data) != self._HEADER.size:
            return None
        (msg_len,) = self._HEADER.unpack(header_data)
        payload = self.recv_bytes(msg_len)
        # Compare *byte* counts before decoding: msg_len counts UTF-8 bytes,
        # and the decoded text is shorter whenever it contains non-ASCII
        # characters.
        if len(payload) != msg_len:
            return None
        return payload.decode('utf-8')

    def recv_bytes(self, n):
        """Receive up to ``n`` bytes from ``self.sock``.

        Loops until ``n`` bytes have arrived or the peer closes the
        connection, so the result may be shorter than ``n``.
        """
        buf = bytearray()
        while len(buf) < n:
            chunk = self.sock.recv(n - len(buf))
            if not chunk:
                break
            buf.extend(chunk)
        return bytes(buf)

    def close(self):
        """Close the connection and the listening socket, if either exists."""
        # getattr: subclasses whose MRO skips __init__ (e.g. together with
        # threading.Thread) may never have set these attributes.
        sock = getattr(self, 'sock', None)
        if sock:
            sock.close()
        server_sock = getattr(self, 'server_sock', None)
        if server_sock:
            server_sock.close()


class DistantInteractiveConsole(InteractiveConsole):
    """Interactive console whose input/output travels over an IPC connection.

    While code runs, sys.stdout and sys.stderr are redirected into an
    in-memory buffer. Whenever a new input line is needed, the buffered
    output is sent to the peer together with the prompt, and the next line
    is read back from the peer.
    """

    def __init__(self, ipc):
        InteractiveConsole.__init__(self, globals())

        self.ipc = ipc
        self.set_buffer()

    def set_buffer(self):
        """Redirect sys.stdout and sys.stderr into a fresh StringIO buffer.

        The streams active before redirection are remembered so that
        :meth:`unset_buffer` can put them back.
        """
        self._saved_streams = (sys.stdout, sys.stderr)
        self.out_buffer = StringIO()
        sys.stdout = sys.stderr = self.out_buffer

    def unset_buffer(self):
        """Restore the pre-redirection streams and return the captured text."""
        # Restore whatever the host had installed rather than
        # sys.__stdout__/__stderr__, so host-side redirection survives while
        # we wait for remote input.
        sys.stdout, sys.stderr = getattr(
            self, '_saved_streams', (sys.__stdout__, sys.__stderr__))
        value = self.out_buffer.getvalue()
        self.out_buffer.close()

        return value

    def raw_input(self, prompt=""):
        """Send the prompt plus pending output to the peer; return its reply.

        Raises:
            EOFError: if the IPC layer returns ``None`` (PyrasiteIPC.recv()
                does so in reliable mode when the connection ends early),
                which makes ``interact()`` finish cleanly.
        """
        output = self.unset_buffer()
        # payload format: 'prompt' '\n' 'output'
        self.ipc.send('\n'.join((prompt, output)))

        cmd = self.ipc.recv()

        # Re-capture before possibly raising so interact()'s exit message
        # lands in the buffer, not on the host's console.
        self.set_buffer()

        if cmd is None:
            # Passing None on to push() would crash in "\n".join(); EOFError
            # is how InteractiveConsole signals end of input.
            raise EOFError
        return cmd


class ReversePythonShell(threading.Thread, PyrasiteIPC):
    """A reverse Python shell that behaves like Python interactive interpreter.

    When run as a thread, it connects out to ``host:port`` and serves a
    :class:`DistantInteractiveConsole` over that connection using the
    length-prefixed ("reliable") framing.
    """

    host = 'localhost'
    port = 9001
    reliable = True
    # Shadows Thread.daemon so instances run as daemon threads and do not
    # hold up interpreter shutdown.
    daemon = True

    def __init__(self, host=None, port=None):
        super(ReversePythonShell, self).__init__()
        # threading.Thread.__init__ does not chain to super().__init__(), so
        # PyrasiteIPC.__init__ is never reached through the MRO. Initialise
        # the socket attributes close() relies on explicitly (calling
        # PyrasiteIPC.__init__ instead would reset self.port to None).
        self.sock = None
        self.server_sock = None
        if host is not None:
            self.host = host
        if port is not None:
            self.port = port

    def run(self):
        """Connect back to ``host:port`` and serve an interactive console.

        Errors other than SystemExit are printed to the stderr that was in
        effect when the thread started. stdout/stderr are always restored
        and the connection closed on the way out.
        """
        stdout, stderr = sys.stdout, sys.stderr
        try:
            self.sock = self._connect_back()
            DistantInteractiveConsole(self).interact()
        except SystemExit:
            # exit()/quit() in the remote console propagates SystemExit out
            # of interact(); end the session quietly.
            pass
        except Exception:
            traceback.print_exc(file=stderr)
        finally:
            sys.stdout, sys.stderr = stdout, stderr
            self.close()

    def _connect_back(self):
        """Return a socket connected to the first reachable host:port address."""
        candidates = socket.getaddrinfo(self.host, self.port,
                                        socket.AF_UNSPEC, socket.SOCK_STREAM)
        for af, socktype, proto, _canonname, sa in candidates:
            try:
                sock = socket.socket(af, socktype, proto)
            except socket.error:
                continue
            try:
                sock.connect(sa)
            except socket.error:
                sock.close()
                continue
            return sock
        # %s rather than %d: a string port must not turn this into a TypeError.
        raise Exception('pyrasite cannot establish reverse '
                        'connection to %s:%s' % (self.host, self.port))
