import os
import sys
import json
import struct
import socket
import logging
import traceback
import threading
from io import StringIO
from code import InteractiveConsole

from library.python import resource
from kernel.util.sys.getpeerid import getpeerid


class IPCLocalServer:
    """Length-prefixed text channel over a TCP socket.

    Every message is framed as a 4-byte little-endian unsigned length
    (``struct`` format ``'<L'``) followed by that many UTF-8 encoded bytes.
    ``listen`` + ``wait`` make this object the accepting side; subclasses may
    instead assign an already-connected socket to ``self.sock`` and only use
    ``send``/``recv``/``close``.
    """

    def __init__(self):
        super().__init__()
        self.sock = None         # connected peer socket (set by wait())
        self.server_sock = None  # listening socket (set by listen())
        self.hostname = None     # host to bind; replaced by the bound host after listen()
        self.port = None         # port to bind (falsy = OS-chosen); replaced by the bound port
        self.address = None      # peer address returned by accept() in wait()
        self._title = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, *args, **kwargs):
        self.close()

    def connect(self):
        """
        Listen on ``hostname``/``port`` and block until a peer connects back.
        """
        self.listen()
        self.wait()

    def listen(self):
        """Bind and listen on ``self.hostname``/``self.port``.

        Every address returned by ``getaddrinfo`` is tried in order and the
        first one that binds is kept.  On success ``hostname``/``port`` are
        updated to the actually bound address (so an unset port becomes the
        OS-assigned one).

        :raises Exception: if no candidate address could be bound; the last
            socket error is attached as the cause.
        """
        bound = None
        last_error = None
        for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
            self.hostname or 'localhost',
            self.port or None,
            socket.AF_UNSPEC,
            socket.SOCK_STREAM,
            0,
            0,
        ):
            try:
                candidate = socket.socket(af, socktype, proto)
            except OSError as exc:
                last_error = exc
                continue
            try:
                candidate.bind(sa)
                candidate.listen(1)
            except OSError as exc:
                candidate.close()
                last_error = exc
                continue
            bound = candidate
            break

        self.server_sock = bound
        if bound is None:
            raise Exception('failed to setup a local server socket') from last_error
        self.hostname, self.port = bound.getsockname()[:2]

    def wait(self):
        """Block until a peer connects, then adopt that connection.

        The accepted socket gets a 5 second timeout, so a later ``recv`` on a
        silent peer raises ``socket.timeout`` rather than blocking forever.
        """
        self.sock, self.address = self.server_sock.accept()
        self.sock.settimeout(5)

    def cmd(self, cmd):
        """
        Send ``cmd`` (plus a trailing newline) and return the peer's reply.
        """
        self.send(cmd + '\n')
        return self.recv()

    def send(self, data):
        """Send one framed message containing the text ``data``."""
        data = data.encode('utf-8')
        header = struct.pack('<L', len(data))
        self.sock.sendall(header + data)

    def recv(self):
        """Receive one framed message.

        :return: the decoded text, or ``None`` if the connection closed before
            a complete header and body arrived.
        """
        header = self.recv_bytes(4)
        if len(header) != 4:
            return None
        (msg_len,) = struct.unpack('<L', header)
        body = self.recv_bytes(msg_len)
        # Check the length *before* decoding: msg_len counts UTF-8 bytes, and
        # the decoded str of non-ASCII text has fewer characters than that.
        if len(body) != msg_len:
            return None
        return body.decode('utf-8')

    def recv_bytes(self, n):
        """Read exactly ``n`` bytes, or fewer only if the peer closes first."""
        buf = bytearray()
        while len(buf) < n:
            chunk = self.sock.recv(n - len(buf))
            if not chunk:
                break
            buf += chunk
        return bytes(buf)

    def close(self):
        """Close the peer and listening sockets if open; safe to call twice."""
        # getattr: a subclass whose MRO puts threading.Thread first may never
        # run this class's __init__, so the attributes might not exist.
        for attr in ('sock', 'server_sock'):
            s = getattr(self, attr, None)
            if s is not None:
                s.close()
                setattr(self, attr, None)


def _lookup(typename, **kwargs):
    """
    Lookup object with type `typename` in process objects.

    :param str typename: class name to lookup
    :param dict(str,object) kwargs: object attributes to filter by
    :rtype: None or object or list(object)
    :return: item or list of items matching the criteria

    """
    def _filter(obj):
        if type(obj).__name__ != typename:
            return False
        for k, v in list(kwargs.items()):
            if getattr(obj, k, None) != v:
                return False
        return True

    import gc
    items = list(filter(_filter, gc.get_objects()))
    if len(items) == 1:
        return items[0]
    elif not items:
        return None
    else:
        return items


class DistantInteractiveConsole(InteractiveConsole):
    """Interactive console whose "terminal" is an IPC channel.

    Each prompt, together with everything written to ``sys.stdout`` /
    ``sys.stderr`` since the previous prompt, is sent through ``ipc``; the
    next input line is read back from it.  Between prompts the process-wide
    ``sys.stdout``/``sys.stderr`` are redirected into an in-memory buffer.
    """

    def __init__(self, ipc):
        """
        :param ipc: channel with ``send(str)`` and ``recv()``; ``recv`` is
            expected to return ``None`` once the peer is gone (as
            ``IPCLocalServer.recv`` does).
        """
        super().__init__({
            '__name__': '__console__',
            '__doc__': None,
            'l': _lookup,
            '__builtins__': __builtins__
        })

        self.ipc = ipc
        self.set_buffer()

    def runsource(self, source, filename='<input>', symbol="single"):
        """Handle the ``+p NAME`` shortcut, otherwise defer to the base class.

        ``+p NAME`` executes the bundled snippet
        ``/ipc-snippets/print_NAME.py`` (looked up via ``resource.find``) in
        the console namespace.  Returns False, i.e. "no more input needed".
        """
        if source.startswith("+p "):
            filename = '/ipc-snippets/print_' + source.split(' ', 1)[1] + '.py'
            script = resource.find(filename)
            if not script:
                self.write(f"Invalid command: {filename}")
                return False
            try:
                code_object = compile(script, filename, 'exec')
                exec(code_object, self.locals)
            except SystemExit:
                # Let exit requests end the session, like the base runcode().
                raise
            except BaseException:
                self.showtraceback()
            return False
        return InteractiveConsole.runsource(self, source, filename, symbol)

    def set_buffer(self):
        """Redirect sys.stdout and sys.stderr into a fresh in-memory buffer."""
        self.out_buffer = StringIO()
        sys.stdout = sys.stderr = self.out_buffer

    def unset_buffer(self):
        """Restore the interpreter's original streams; return captured text."""
        sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
        value = self.out_buffer.getvalue()
        self.out_buffer.close()

        return value

    def raw_input(self, prompt=""):
        """Send ``prompt`` and pending output, then wait for the next line.

        :raises EOFError: when the channel yields no line (peer closed), so
            ``interact()`` ends the session cleanly instead of failing later
            on ``push(None)``.
        """
        output = self.unset_buffer()
        # payload format: 'prompt' '\n' 'output' -- the peer splits on the
        # first newline.
        self.ipc.send('\n'.join((prompt, output)))

        cmd = self.ipc.recv()

        # Re-capture before returning or raising, so whatever interact()
        # writes next (command output, or its exit message) goes to the buffer.
        self.set_buffer()

        if cmd is None:
            raise EOFError('IPC peer closed the connection')
        return cmd


class ReversePythonShell(threading.Thread, IPCLocalServer):
    """
    A reverse Python shell that behaves like Python interactive interpreter.

    Runs as a daemon thread: it connects out to ``host:port`` and serves a
    ``DistantInteractiveConsole`` over that connection until the peer
    disconnects or the session raises ``SystemExit``.
    """

    host = 'localhost'
    port = 9001

    def __init__(self, host=None, port=None):
        # super() resolves to threading.Thread.__init__, which does not chain
        # on to IPCLocalServer.__init__, so the socket attributes that close()
        # and run() rely on are initialised here.  daemon is passed to Thread
        # instead of shadowing its ``daemon`` property with a class attribute.
        super().__init__(daemon=True)
        self.sock = None
        self.server_sock = None
        if host is not None:
            self.host = host
        if port is not None:
            self.port = port

    def run(self):
        """Connect back to ``host:port`` and serve the console until it ends.

        Errors are printed to the stderr that was active when the thread
        started; the streams are always restored and the socket closed.
        """
        stdout, stderr = sys.stdout, sys.stderr
        try:
            self.sock = self._connect()

            DistantInteractiveConsole(self).interact(
                "Type '+p stack' to see some useful dump.\n"
                "Use `l(typename, **kwargs)` to find in-memory instances."
            )

        except SystemExit:
            # A remote exit request ends this session only, not the process.
            pass
        except BaseException:
            traceback.print_exc(file=stderr)
        finally:
            sys.stdout, sys.stderr = stdout, stderr
            self.close()

    def _connect(self):
        """Return a socket connected to the first reachable address of host:port.

        :raises Exception: if no address accepted the connection; the last
            socket error is attached as the cause.
        """
        last_error = None
        for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
            self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM
        ):
            try:
                candidate = socket.socket(af, socktype, proto)
            except OSError as exc:
                last_error = exc
                continue
            try:
                candidate.connect(sa)
            except OSError as exc:
                candidate.close()
                last_error = exc
                continue
            return candidate
        raise Exception(
            f'cannot establish reverse connection to {self.host}:{self.port}'
        ) from last_error


def ensure_peer_allowed(sock):
    """Reject the connected peer unless it runs as user id 0 (root).

    :param sock: connected socket whose peer credentials are checked with
        ``getpeerid``
    :raises Exception: if the peer's user id is not 0
    """
    # NOTE(review): assumes getpeerid(sock, True) returns a sequence whose
    # first element is the peer uid -- confirm against getpeerid.
    uid = getpeerid(sock, True)[0]
    if uid == 0:
        return
    raise Exception(f"User id {uid} is not allowed")


class ReplServer(threading.Thread):
    """Daemon thread that starts reverse Python shells on request.

    Listens on the UNIX-domain socket ``sock_path``.  Each client must pass
    ``ensure_peer_allowed`` and send one JSON object ``{"port": ...}``; the
    server then starts a ``ReversePythonShell`` connecting back to
    ``localhost:<port>`` and replies ``{"success": true}`` or
    ``{"success": false, "message": <error text>}``.
    """

    def __init__(self, sock_path, *args, logger=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.daemon = True
        self.__logger = logger or logging.getLogger('repl')
        self.__sock_path = sock_path

    def run(self):
        # ``with`` also closes the listening socket if bind/listen fail.
        with socket.socket(socket.AF_UNIX) as sock:
            sock.bind(self.__sock_path)
            self.__logger.info("bound to %r", self.__sock_path)
            sock.listen(1)

            while True:
                peer, _ = sock.accept()
                with peer:
                    self._serve_peer(peer)

    def _serve_peer(self, peer):
        """Answer one control request.

        Errors while handling the request or sending the reply are logged
        rather than propagated, so one bad client cannot end the accept loop.
        """
        peer.settimeout(60)
        try:
            ensure_peer_allowed(peer)
            data = peer.recv(1024).decode('utf-8')
            msg = json.loads(data)
            ReversePythonShell("localhost", msg['port']).start()
        except Exception as e:
            self.__logger.warning("repl request rejected: %s", e)
            response = {"success": False, "message": str(e)}
        else:
            response = {"success": True}

        try:
            peer.sendall(json.dumps(response).encode('utf-8'))
        except OSError as e:
            # The requester went away or timed out; keep serving others.
            self.__logger.warning("failed to reply to repl client: %s", e)

    def stop_repl(self):
        """Remove the socket file (if present).

        This only unlinks the path; it does not interrupt the accept loop.
        """
        try:
            os.unlink(self.__sock_path)
        except EnvironmentError:
            pass


def repl_client(sock_path, port=9091):
    """Attach the local terminal to a process that runs ``ReplServer``.

    Listens for the reverse shell on ``port``, asks the server behind the
    UNIX socket ``sock_path`` to start a shell that connects back, then
    relays lines between the terminal and the remote console until either
    side ends the session.

    :param str sock_path: path of the server's UNIX-domain control socket
    :param int port: local TCP port to listen on (0 lets the OS choose)
    :raises Exception: if the server refuses the request, or the shell
        disconnects before sending its first prompt
    :return: 0
    """
    ipc = IPCLocalServer()
    # NOTE(review): '::' binds every interface (typically dual-stack), not
    # only loopback, although the shell always connects to "localhost".
    ipc.hostname = '::'
    ipc.port = port
    try:
        ipc.listen()
        print(f"bound to [{ipc.hostname}]::{ipc.port}")

        with socket.socket(socket.AF_UNIX) as sock:
            sock.settimeout(120)
            sock.connect(sock_path)
            # Announce the port actually bound (differs from ``port`` if 0).
            sock.sendall(json.dumps({"port": ipc.port}).encode('utf-8'))
            response = json.loads(sock.recv(4096).decode('utf-8'))
        if not response['success']:
            raise Exception(response['message'])

        ipc.wait()

        first = ipc.recv()
        if first is None:
            raise Exception('reverse shell disconnected before sending a prompt')
        prompt, payload = first.split('\n', 1)

        print(payload)

        try:
            import readline  # noqa -- enables line editing/history for input()
        except ImportError:
            pass

        try:
            while True:
                try:
                    # Send the raw line: it is executed remotely, never here.
                    input_line = input(prompt)
                except EOFError:
                    # ``raise SystemExit`` rather than ``exit()``: the exit
                    # builtin also closes the remote process's sys.stdin and
                    # is missing when ``site`` is not loaded.
                    input_line = 'raise SystemExit'
                    print('')
                except KeyboardInterrupt:
                    input_line = 'None'
                    print('')

                ipc.send(input_line)
                payload = ipc.recv()
                if payload is None:
                    break
                prompt, payload = payload.split('\n', 1)
                if payload != '':
                    print(payload)
        except BaseException:
            # Leave the terminal on a fresh line before propagating.
            print('')
            raise
    finally:
        ipc.close()

    return 0
