# coding: utf-8
from __future__ import print_function
import os
import six
import sys
import traceback
import socket
import subprocess
import time
import termios
import pty
import pwd
import telnetlib
import struct
import fcntl
import signal
import errno
import tempfile
import tarfile

from telnetlib import IAC, IP, SB, SE, NAWS

from kernel.util.config import Config
from kernel.util.functional import memoized


@memoized
def findShell():
    """Pick the shell binary configured in ``cfg.Shell``.

    ``cfg.Shell`` is either a single path, which is returned verbatim without
    any existence check, or a sequence of candidate paths tried in order; the
    first candidate that is an existing regular file wins. Falls back to
    ``/bin/sh`` when no candidate qualifies.
    """
    configured = cfg.Shell
    if isinstance(configured, six.string_types):
        return configured

    usable = (
        path for path in configured
        if os.path.exists(path) and os.path.isfile(path)
    )
    return next(usable, '/bin/sh')


@memoized
def userDefaultShell():
    """Return the current user's login shell if it is usable.

    Looks up the passwd entry for the current uid. If there is no entry, or
    the recorded shell is not an existing regular file, the choice is
    delegated to findShell().
    """
    try:
        entry = pwd.getpwuid(os.getuid())
    except KeyError:
        return findShell()

    login_shell = entry.pw_shell
    if os.path.exists(login_shell) and os.path.isfile(login_shell):
        return login_shell
    return findShell()


class RunShell:
    """Command object that starts an interactive shell and waits for it."""

    def run(self):
        """Run ``<user shell> -i`` in the foreground and block until it exits."""
        argv = [userDefaultShell(), '-i']
        child = subprocess.Popen(argv, shell=False, close_fds=True, bufsize=4096)
        child.wait()


class RunCommand:
    """Command object that runs a command line through the user's shell."""

    def __init__(self, command):
        # Sequence of words; run() joins them with single spaces (no quoting).
        self.command = command

    def run(self):
        """Execute the joined command via the user's default shell and wait for it."""
        login_shell = userDefaultShell()
        command_line = ' '.join(self.command)
        process = subprocess.Popen(
            command_line,
            executable=login_shell,
            shell=True,
            close_fds=True,
            bufsize=4096,
        )
        process.wait()


class ShellRunner(object):
    """Job object that serves one interactive shell session over telnet.

    The constructor packs the user's dotfiles (``cfg.Deploy``) into in-memory
    tar archives kept in ``tar_files``. :meth:`run` (also available as
    ``__call__``) opens a listening socket and forks. The child serves a
    single connection through :class:`TelnetHandler`, which unpacks those
    archives into a temporary home directory.
    """
    # NOTE(review): not referenced in this file; presumably read by the
    # framework that dispatches this job — confirm.
    modules = ['library']

    class OverflowError(RuntimeError):
        """Raised when the archived files exceed the size budget."""

    class Exclude(object):
        """Size-budget callback for ``TarFile.add(..., exclude=...)``.

        Every call adds the size of *name* to ``totalSize``. Paths that cannot
        be stat()ed count as 0. Once the running total exceeds ``limit``, the
        call raises ``ShellRunner.OverflowError``. It always returns None, so it
        never excludes a path by itself.
        """

        def __init__(self, limit):
            self.totalSize = 0
            self.limit = limit

        def __call__(self, name):
            try:
                self.totalSize += os.stat(name).st_size
            except EnvironmentError:
                pass
            if self.totalSize > self.limit:
                raise ShellRunner.OverflowError()

    def __init__(self, term_type, command, iterate=False):
        self.term_type = term_type
        self.command = command
        self.tar_data = None  # legacy single-archive payload (TelnetHandler still accepts it)
        self.tar_files = self.create_tar_data()
        self.iterate = iterate

    def create_tar_data(self):
        """Pack each ``cfg.Deploy`` entry into its own in-memory tar archive.

        Paths are ``~``-expanded. Entries under the home directory are stored
        relative to it. A missing entry that has a built-in text in
        ``cfg.DefaultFiles`` is replaced by that text. An entry that would push
        the total archived size past 200 MiB is skipped with a warning on
        stderr. If an OS-level error occurs, everything collected so far is
        dropped and a single archive holding all default files is returned.

        Returns:
            list: raw tar archive contents, one item per archived entry.
        """
        copy_to_remote_host = cfg.Deploy
        default_files = cfg.DefaultFiles

        tar_files = []
        home = os.path.expanduser('~')
        # Match $HOME on a whole path component so that a sibling such as
        # '/home/bobby' is not mistaken for something inside '/home/bob'.
        home_prefix = home.rstrip(os.sep) + os.sep

        totalSize = 0
        sizeLimit = 200 * 1024 * 1024
        try:
            for name in copy_to_remote_host:
                string_io = six.moves.cStringIO()
                tar_file = tarfile.TarFile(fileobj=string_io, mode='w')
                # Hand the callback only the budget that is still left, so an
                # oversized entry is abandoned as soon as it crosses the limit
                # rather than after being archived in full.
                size_guard = ShellRunner.Exclude(sizeLimit - totalSize)
                try:
                    name = os.path.expanduser(name)
                    if name.startswith(home_prefix):
                        basename = name[len(home_prefix):]
                    else:
                        basename = name
                    if os.path.exists(name):
                        # NOTE: the ``exclude=`` keyword was removed from
                        # TarFile.add in Python 3.7 (use ``filter=`` there).
                        tar_file.add(
                            os.path.realpath(name), basename,
                            exclude=size_guard
                        )
                    elif basename in default_files:
                        self._put_default_file(basename, tar_file)
                    if totalSize + size_guard.totalSize > sizeLimit:
                        raise ShellRunner.OverflowError()
                    totalSize += size_guard.totalSize
                except ShellRunner.OverflowError:
                    print("`{}` will not be copied to remote host: size limit {} exceeded".format(name, sizeLimit), file=sys.stderr)
                else:
                    tar_file.close()
                    tar_files.append(string_io.getvalue())
        except EnvironmentError as e:
            print("Only default configs will be copied because of: {}".format(e), file=sys.stderr)
            string_io = six.moves.cStringIO()
            tar_file = tarfile.TarFile(fileobj=string_io, mode='w')
            for basename in default_files:
                self._put_default_file(basename, tar_file)
            tar_file.close()
            tar_files = [string_io.getvalue()]

        return tar_files

    def _put_default_file(self, basename, tar_file):
        """Add the built-in text ``cfg.DefaultFiles[basename]`` to *tar_file*.

        Each line of the text is stripped of surrounding whitespace and a
        trailing newline is appended. The text goes through a temporary file
        because ``TarFile.add`` archives paths.
        """
        default_files = cfg.DefaultFiles
        content = '\n'.join(s.strip() for s in default_files[basename].strip().split('\n')) + '\n'
        # mkstemp creates the file atomically (O_EXCL, mode 0600), unlike the
        # race-prone tempfile.mktemp(); the archived member keeps mode 0600.
        fd, temp_file_name = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'w') as temp_file:
                temp_file.write(content)
            tar_file.add(temp_file_name, basename)
        finally:
            os.remove(temp_file_name)

    def create_socket(self):
        """Open a listening TCP socket on an ephemeral port.

        Prefers a dual-stack IPv6 socket (IPV6_V6ONLY off) bound to ``::``. If
        that fails, it falls back to IPv4 ``0.0.0.0``. The socket is stored in
        ``self.server_sock`` and its port in ``self.telnet_port``.
        """
        sock = None
        try:
            sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind(('::', 0))
        except socket.error:  # if unavailable, fallback to ipv4
            if sock is not None:
                sock.close()  # don't leak the half-configured IPv6 socket
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(('0.0.0.0', 0))
        self.server_sock = sock
        self.server_sock.listen(1)
        self.telnet_port = self.server_sock.getsockname()[1]

    def accept_socket(self):
        """Accept exactly one connection, then close the listening socket."""
        sock = self.server_sock.accept()[0]
        self.server_sock.close()
        return sock

    def yield_and_wait(self, pid):
        """Generator: yield the telnet port, then wait for child *pid* to exit."""
        yield self.telnet_port
        os.waitpid(pid, 0)

    def run(self):
        """Fork a child that serves one telnet session.

        In the parent, the method waits until the child signals readiness over
        a pipe. It then returns ``telnet_port``, or with ``iterate=True`` a
        generator from :meth:`yield_and_wait`.

        In the child, it accepts one connection and runs a TelnetHandler on
        it. It returns None on success or the formatted exception text on
        failure.

        SIGALRM is armed for 60 s while waiting for the child to start and for
        the client to connect.
        """
        signal.alarm(60)

        self.create_socket()

        read_fd, write_fd = os.pipe()

        pid = os.fork()

        if pid:
            # Close our write end first so read() returns EOF instead of
            # blocking if the child dies before signalling readiness.
            os.close(write_fd)
            try:
                os.read(read_fd, 1)
            finally:
                os.close(read_fd)
            signal.alarm(0)
            # Only the child accepts; it holds its own descriptor for the
            # listening socket, so the parent's copy can go.
            self.server_sock.close()
            if self.iterate:
                return self.yield_and_wait(pid)
            else:
                return self.telnet_port

        # TODO rewrite on generator or Queue to remove this
        os.setpgrp()
        os.write(write_fd, b'\x00')

        os.close(read_fd)
        os.close(write_fd)

        sock = self.accept_socket()
        signal.alarm(0)

        handler = None
        try:
            handler = TelnetHandler(sock, self.term_type, self.command, tar_data=self.tar_data, tar_files=getattr(self, 'tar_files', None))
            handler.run()
        except Exception:
            from kernel.util.errors import formatException
            return formatException()
        finally:
            if handler is not None:
                handler.destroy()
            else:
                # The handler was never built, so close the socket ourselves.
                sock.close()

    __call__ = run


class TelnetHandler(object):
    """A remote job to be run by skynet.  Implements a rudimentary telnet server.

    The handler serves a single client connection. It creates a throw-away
    home directory and forks the command on a pseudo-terminal. It then
    relays bytes between the pty and the socket until either side closes or
    the session is idle for ``timeout`` seconds.
    """
    def __init__(self, sock, term_type, command, tar_data=None, tar_files=None):
        self.sock = sock  # connected client socket
        self.shell_pid = None  # pid of the pty child, set by fork_shell()
        self.pty_fd = None  # master side of the pty, set by fork_shell()
        self.term_type = term_type  # exported as $TERM
        self.home = None  # temporary $HOME, created by setup_environment()
        self.command = command  # object whose run() executes in the pty child
        self.timeout = 3600  # idle timeout for the session, in seconds
        self.alarm_timeout = self.timeout + 300  # just in case write() or send() syscall hangs.
        self.tar_data = tar_data  # single tar archive (older payload format)
        self.tar_files = tar_files  # list of tar archives

    def setup_environment(self):
        """Create a temporary home, chdir into it and unpack the dotfile archives.

        This changes the process-wide cwd and the HOME, TERM and SHELL
        environment variables.
        """
        self.home = os.path.realpath(tempfile.mkdtemp())
        os.chdir(self.home)
        os.environ['HOME'] = self.home
        os.environ['TERM'] = self.term_type
        os.environ['SHELL'] = userDefaultShell()

        # TODO: customizability. This could be done by allowing user to create
        # a python file ~/.skyshellrc[.py] with hooks functions/classes.
        if self.tar_data:  # backward compatibility
            tarfile.TarFile(fileobj=six.moves.cStringIO(self.tar_data)).extractall()
        if self.tar_files:
            for tar_piece in self.tar_files:
                tarfile.TarFile(fileobj=six.moves.cStringIO(tar_piece)).extractall()

    def fork_shell(self):
        """Fork ``self.command.run()`` onto a new pty; the child always exits with status 0."""
        self.shell_pid, self.pty_fd = pty.fork()
        if self.shell_pid == 0:
            try:
                self.command.run()
            finally:
                os._exit(0)

        # TODO: setsid() and friends

    def create_telnet(self, sock=None):
        """Wrap the client socket in a telnetlib.Telnet that handles IAC commands.

        IAC IP (interrupt process) is forwarded to the shell as Ctrl-C. NAWS
        subnegotiations resize the pty. The method also sets a 300 s socket
        timeout and arms the SIGALRM watchdog.

        Args:
            sock: optional replacement for ``self.sock``.
        Returns:
            The configured ``telnetlib.Telnet`` instance.
        """
        def process_IAC(xsock, cmd, option):
            try:
                if cmd == IP:  # Ctrl-C
                    os.write(self.pty_fd, b'\x03')
                elif cmd == SE:
                    # Slicing instead of indexing tolerates an empty
                    # subnegotiation buffer and compares like with like for
                    # both str and bytes buffers.
                    subneg = telnet.sbdataq
                    if subneg[:1] == NAWS:  # negotiate window size.
                        dataq = subneg[1:]
                        if len(dataq) == 2:
                            cols, rows = struct.unpack('!BB', dataq)
                        elif len(dataq) == 4:
                            cols, rows = struct.unpack('!HH', dataq)
                        else:
                            cols, rows = 80, 24  # cant detect :(

                        s = struct.pack('HHHH', rows, cols, 0, 0)
                        fcntl.ioctl(self.pty_fd, termios.TIOCSWINSZ, s)
            except:
                from kernel.util.errors import formatException
                self.sock.sendall('Exception during process_IAC() on telnet handler:\n')
                self.sock.sendall(formatException())
                raise

        if sock:
            self.sock = sock

        try:
            self.sock.settimeout(300)
            signal.alarm(self.alarm_timeout)

            telnet = telnetlib.Telnet()
            # NOTE(review): module-global tweak, presumably so telnetlib stops
            # discarding NUL bytes; it affects every Telnet in this process.
            telnetlib.theNULL = ''
            telnet.sock = self.sock
            telnet.set_option_negotiation_callback(process_IAC)
        except:
            from kernel.util.errors import formatException
            self.sock.sendall('Exception during telnet initialization:\n')
            self.sock.sendall(formatException())
            raise

        return telnet

    def telnet_session(self, telnet):
        """Relay data between the pty and the client until the session ends.

        The session ends when:
          * the pty hits EOF (or EIO),
          * reading from or writing to the client fails, or
          * nothing happens for ``self.timeout`` seconds.

        Each successful transfer re-arms the SIGALRM watchdog.
        """
        import select
        sock_fd = self.sock.fileno()
        watched = [sock_fd, self.pty_fd]
        while True:
            r_fds, w_fds, x_fds = select.select(watched, [], watched, self.timeout)

            if not (r_fds or w_fds or x_fds):
                try:
                    self.sock.sendall(b'\r\n*** Disconnecting you for loss of activity ***\r\n')
                except Exception:
                    pass
                break

            if self.pty_fd in r_fds:
                try:
                    buf = os.read(self.pty_fd, 4096)
                except OSError as err:
                    # On Linux, once the last holder of the pty slave (the
                    # shell) exits, reading the master fails with EIO instead
                    # of returning EOF: that is simply the end of the session.
                    if err.errno == errno.EIO:
                        break
                    raise
                if not buf:
                    break
                self.sock.sendall(buf)
                signal.alarm(self.alarm_timeout)

            if sock_fd in r_fds:
                try:
                    # read_eager() also runs process_IAC for telnet commands;
                    # any failure (EOF, socket error, callback error) ends the session.
                    buf = telnet.read_eager()
                    os.write(self.pty_fd, buf)
                    signal.alarm(self.alarm_timeout)
                except Exception:
                    break

    def run(self):
        """Set up the environment, start the shell and serve the session."""
        self.setup_environment()
        self.fork_shell()
        _telnet = self.create_telnet()
        self.telnet_session(_telnet)

    def destroy(self):
        """Best-effort teardown that never raises.

        Kills and reaps the shell, closes the pty, removes the temporary home
        and closes the client socket.
        """
        import shutil

        try:
            if self.shell_pid is not None:
                os.kill(self.shell_pid, signal.SIGTERM)
                time.sleep(0.1)
                os.kill(self.shell_pid, signal.SIGKILL)
                # SIGKILL was delivered, so the child is certain to exit; reap
                # it so it does not linger as a zombie.
                os.waitpid(self.shell_pid, 0)
        except Exception:
            pass

        try:
            if self.pty_fd is not None:
                os.close(self.pty_fd)
        except Exception:
            pass

        try:
            if self.home is not None:
                os.chdir('/')
                # No shell involved, so paths containing quotes or other shell
                # metacharacters are removed correctly.
                shutil.rmtree(self.home, ignore_errors=True)
        except Exception:
            pass

        try:
            if self.sock is not None:
                self.sock.close()
        except Exception:
            pass


class TelnetClient(object):
    """A rudimentary telnet client.

    Switches the local terminal to non-canonical, no-echo mode and relays
    bytes between stdin/stdout and the socket. Pressing Ctrl-] on its own
    opens a small local prompt from which the session can be quit.
    """

    # TODO: escape telnet control chars, preferably by using telnetlib to manage the connection

    MAGIC_KEY = 0x1D  # Ctrl-]

    def __init__(self, sock=None):
        # Terminal attributes captured at construction; reset_raw_mode()
        # restores exactly these.
        self.original_termios_attr = termios.tcgetattr(sys.stdin.fileno())
        self.sock = sock

    def connect(self, host, port):
        """Open a TCP connection to *host*:*port* and use it as the session socket."""
        self.sock = socket.create_connection((host, port))

    def set_raw_mode(self):
        """Turn off canonical input and echo on fd 0, leaving other flags as captured."""
        attr = list(self.original_termios_attr)
        attr[3] = attr[3] & ~termios.ICANON & ~termios.ECHO  # index 3 is lflag
        termios.tcsetattr(0, termios.TCSANOW, attr)

    def reset_raw_mode(self):
        """Restore the terminal attributes captured in __init__."""
        termios.tcsetattr(0, termios.TCSANOW, self.original_termios_attr)

    def send_terminal_size(self, *args):
        """Send the local window size as a telnet NAWS subnegotiation.

        Also installed as the SIGWINCH handler, hence ``*args``.
        """
        s = struct.pack('HHHH', 0, 0, 0, 0)
        result = fcntl.ioctl(sys.stdin.fileno(), termios.TIOCGWINSZ, s)
        rows, cols = struct.unpack('HHHH', result)[0:2]

        negotiation = SB + NAWS
        negotiation += struct.pack('!HH', cols, rows)

        # telnet spec force us to double any 255 value
        negotiation = negotiation.replace(b'\xff', b'\xff\xff')

        self.sock.sendall(IAC + negotiation + IAC + SE)

    def do_magic(self):
        """Show the local prompt; return 'quit' if the user typed a line starting with 'q'."""
        sys.stderr.write('\r\nType "q" to quit.\nskyshell> ')
        command = sys.stdin.readline()
        if command.startswith('q'):
            return 'quit'

    def ctrl_c(self, *args):
        """Forward Ctrl-C as telnet IAC IP (interrupt process)."""
        self.sock.sendall(IAC + IP)

    def ctrl_z(self, *args):
        """Forward Ctrl-Z as the raw SUB character."""
        self.sock.sendall(b'\x1a')  # Ctrl-Z sends ASCII code 26

    def interact_internal(self):
        """Relay data between the terminal and the socket until either side closes.

        For the duration of the session this installs SIGINT, SIGTSTP and
        SIGWINCH handlers that forward Ctrl-C, Ctrl-Z and window-size changes
        to the remote side. The previous handlers are restored on the way out,
        so a later signal does not hit a closed socket.
        """
        import gevent.select as select

        forwarded = (
            (signal.SIGINT, self.ctrl_c),
            (signal.SIGTSTP, self.ctrl_z),
            (signal.SIGWINCH, self.send_terminal_size),
        )
        previous = [(signum, signal.signal(signum, handler)) for signum, handler in forwarded]
        try:
            self.set_raw_mode()
            self.send_terminal_size()
            self._relay_loop(select)
        finally:
            for signum, handler in previous:
                # signal.signal() reports None for handlers not installed from
                # Python; such a handler cannot be reinstated from here.
                if handler is not None:
                    signal.signal(signum, handler)

    def _relay_loop(self, select):
        """Pump bytes in both directions; return when the session is over."""
        while True:
            # NOTE(review): raw mode is re-asserted on every pass, presumably
            # in case job control (Ctrl-Z / fg) reset the terminal — confirm.
            self.set_raw_mode()

            try:
                r_fds, w_fds, x_fds = select.select([self.sock.fileno(), sys.stdin.fileno()], [], [])

                if self.sock.fileno() in r_fds:
                    buf = self.sock.recv(4096)
                    if not buf:
                        sys.stderr.write('\r\nConnection closed by remote host.\n')
                        break
                    os.write(sys.stdout.fileno(), buf)

                if sys.stdin.fileno() in r_fds:
                    buf = os.read(sys.stdin.fileno(), 4096)
                    # bytearray() yields integer codes for both str and bytes input.
                    if len(buf) == 1 and bytearray(buf)[0] == self.MAGIC_KEY:
                        self.reset_raw_mode()
                        if self.do_magic() == 'quit':
                            break
                        self.set_raw_mode()
                    elif buf:
                        self.sock.sendall(buf)
                    else:
                        break

            except (select.error, socket.error) as err:
                # One handler for both: on some Python versions these are the
                # same class, which made separate except clauses unreachable.
                code = err.args[0] if err.args else None
                if code == errno.EINTR:  # ignore interrupted syscalls
                    continue
                if code in (errno.ECONNRESET, errno.EPIPE):
                    # don't raise - just print error and break
                    sys.stderr.write('\r\nConnection closed by remote host.\n')
                    break
                raise
            except KeyboardInterrupt:
                # gevent.select.select() doesn't work well with signals for some reason
                self.ctrl_c()

    def interact(self):
        """Run the session and print any error to stderr.

        The terminal is always restored and the socket closed afterwards.
        """
        try:
            self.interact_internal()
        except Exception:
            sys.stderr.write(traceback.format_exc())
        finally:
            self.reset_raw_mode()
            try:
                self.sock.close()
            except Exception:
                pass


# Module configuration, registered with kernel.util.config.Config under the
# name 'ya.skynet.shell'. The string below supplies the defaults.
# NOTE(review): presumably Config parses it as YAML and allows per-key
# overrides — confirm against kernel.util.config.
#   Shell        -- shells to try, in decreasing order of priority; a plain
#                   string is used as-is (see findShell()).
#   Deploy       -- files and directories copied to the remote machine
#                   (see ShellRunner.create_tar_data()).
#   DefaultFiles -- built-in contents used for Deploy entries that do not
#                   exist on the local machine
#                   (see ShellRunner._put_default_file()).
# The Russian comments inside the string are part of the config text; they
# say the same as the descriptions above.
# NOTE(review): 'ls -G' enables colour only with BSD/macOS ls; GNU ls reads
# it as --no-group. Deploy lists both '~/profile' and '~/.profile'; confirm
# that the undotted one is intended.
cfg = Config('ya.skynet.shell', """
# Список пробуемых для запуска шеллов, по убыванию приоритета
Shell : ['/bin/bash', '/usr/bin/bash', '/usr/local/bin/bash', '/bin/sh']
# Список файлов и каталогов, которые необходимо скопировать на удалённую машину
Deploy : ['~/.bashrc', '~/.profile', '~/.vimrc', '~/.inputrc', '~/profile', '~/.vim']
# Стандартные файлы, которые нужно использовать, если на текущей машине их нет
DefaultFiles:
  .bashrc : |
    export PATH=$PATH:/skynet/tools:/Berkanavt/bin
    export PS1="\[\\033[34;01m\]\\u@\h:\w\[\\033[m\]\\\\$ "
    export HISTFILE=
    export HISTCONTROL=ignoreboth
    export LESSHISTFILE=-
    export LANG=en_US.UTF-8
    alias ls="ls -G"
    alias l="ls -l"
    alias ll="ls -la"

  .inputrc : |
    set input-meta on
    set output-meta on
    $if mode=emacs
    "\e[1~": beginning-of-line
    "\e[4~": end-of-line
    "\e[3~": delete-char
    "\e[2~": quoted-insert
    "\e[5~": beginning-of-history
    "\e[6~": end-of-history
    "\e[1;5C": forward-word
    "\e[1;5D": backward-word
    "\e[5C": forward-word
    "\e[5D": backward-word
    "\e\e[C": forward-word
    "\e\e[D": backward-word
    $if term=rxvt
    "\e[8~": end-of-line
    "\eOc": forward-word
    "\eOd": backward-word
    $endif
    $endif

  .vimrc: |
    set nocompatible nobackup noswapfile noerrorbells backspace=2 autoread autowrite nowrap
    set incsearch ignorecase smartcase hlsearch tabstop=4 shiftwidth=4 expandtab autoindent
    set ruler showcmd nowrap wildmenu formatoptions+=r display+=lastline,uhex
    set laststatus=2 statusline=%<%f%h%m%r%=%b=0x%B\ \ %l,%c%V\ %P
    set fileencodings=ucs-bom,utf-8,cp1251,default,latin1
    set viminfo=""
    syntax on
""")
