import os
import time
import errno
import fcntl
import struct
import signal
import select
import termios
import threading
import telnetlib

from collections import deque
from functools import partial

import six
import msgpack

from porto.exceptions import WaitContainerTimeout

from .sysutils import pipe

from ya.skynet.util.misc import daemonthr
from ya.skynet.util.errors import formatException
from ya.skynet.util.net.socketstream import SocketStream

__all__ = [
    'TelnetHandler',
    'SocketOutHandler',
    'WatchHandler',
    'InOutHandler',
]

# Default inactivity timeout for TelnetHandler sessions, in seconds (one hour).
DEFAULT_TIMEOUT = 60 * 60


class TelnetHandler(object):
    """A remote job to be run by skynet.  Implements a rudimentary telnet server.

    Relays bytes between a client socket speaking telnet and the pty whose fd
    is ``pty_fd``: pty output is sent to the client and client input is
    written to the pty.  Telnet IP is translated into Ctrl-C and NAWS into a
    pty window-size change.  A background thread waits for ``container`` and
    wakes the session loop through ``alive_pipes`` when the wait finishes.
    """
    def __init__(self, container, sock, pty_fd, timeout=DEFAULT_TIMEOUT):
        """
        :param container: porto container object (WaitContainer() and Kill() are used)
        :param sock: connected client socket
        :param pty_fd: pty file descriptor the session reads from and writes to
        :param timeout: inactivity timeout in seconds; values int() cannot
            convert, or that are not positive, fall back to DEFAULT_TIMEOUT
        """
        self.container = container
        try:
            timeout = int(timeout)
        except (TypeError, ValueError):
            timeout = DEFAULT_TIMEOUT

        if timeout <= 0:
            timeout = DEFAULT_TIMEOUT

        self.sock = sock
        self.pty_fd = pty_fd
        # (read end, write end): aliveness_thread() writes one byte into [1]
        # once the container wait finishes; telnet_session() selects on [0].
        self.alive_pipes = pipe()
        self.timeout = timeout
        self.alarm_timeout = self.timeout + 300  # just in case write() or send() syscall hangs.
        # Set by run(); destroy() checks it before closing alive_pipes.
        self._alive_thread = None

    def process_IAC(self, telnet, xsock, cmd, option):
        """telnetlib option-negotiation callback (bound to *telnet* via partial).

        Handles IP (forwarded to the pty as Ctrl-C) and the end of a NAWS
        subnegotiation (applied as the pty window size).  Other commands are
        ignored.  On failure the traceback is sent to the client and re-raised.
        """
        try:
            if cmd == telnetlib.IP:  # Ctrl-C
                os.write(self.pty_fd, chr(3))
            elif cmd == telnetlib.SE:
                sbdata = telnet.sbdataq
                # Slice instead of index: an empty subnegotiation (e.g. a stray
                # "IAC SE" from the client) would otherwise raise IndexError and,
                # via the re-raise below, terminate the whole session.
                if sbdata[:1] == telnetlib.NAWS:  # negotiate window size.
                    dataq = sbdata[1:]
                    if len(dataq) == 2:
                        cols, rows = struct.unpack('!BB', dataq)
                    elif len(dataq) == 4:
                        cols, rows = struct.unpack('!HH', dataq)
                    else:
                        cols, rows = 80, 24  # cant detect :(

                    self.set_winsize(rows, cols)
        except Exception:
            self.sock.sendall('Exception during process_IAC() on telnet handler:\n')
            self.sock.sendall(formatException())
            raise

    def create_telnet(self):
        """Build a telnetlib.Telnet object on top of the already-connected socket.

        Also sets a 300s socket timeout and arms the SIGALRM watchdog.  On
        failure the traceback is sent to the client and re-raised.
        """
        try:
            self.sock.settimeout(300)
            signal.alarm(self.alarm_timeout)

            telnet = telnetlib.Telnet()
            # NOTE(review): this patches a telnetlib module global and affects
            # every Telnet object in the process; presumably it makes telnetlib
            # pass NUL bytes through instead of dropping them -- confirm.
            telnetlib.theNULL = ''
            telnet.sock = self.sock
            telnet.set_option_negotiation_callback(partial(self.process_IAC, telnet))
        except Exception:
            self.sock.sendall('Exception during telnet initialization:\n')
            self.sock.sendall(formatException())
            raise

        return telnet

    def aliveness_thread(self):
        """Wait for the container and then write one byte into alive_pipes[1].

        WaitContainerTimeout just restarts the 60s wait; any other Exception
        is treated the same as a finished wait.
        """
        while True:
            try:
                self.container.WaitContainer(60)
            except WaitContainerTimeout:
                continue
            except Exception:
                # TODO log error
                pass
            break
        os.write(self.alive_pipes[1], '1')

    def telnet_session(self, telnet):
        """Pump data between the client and the pty until the session ends.

        The loop stops on inactivity (the container is then killed), on pty
        EOF/error, on a client read error/EOF, or when the aliveness pipe fires.
        Afterwards it blocks until the aliveness pipe becomes readable.
        """
        while True:
            watched = [self.sock.fileno(), self.pty_fd, self.alive_pipes[0]]
            r_fds, w_fds, x_fds = select.select(watched, [], watched, self.timeout)

            if not (r_fds or w_fds or x_fds):
                # Inactivity timeout.  Notify and kill independently: a broken
                # client socket must not prevent the container from being
                # killed, otherwise the final wait below could block.
                try:
                    self.sock.sendall('\r\n*** Disconnecting you for loss of activity ***\r\n')
                except Exception:
                    pass
                try:
                    self.container.Kill(signal.SIGKILL)
                except Exception:
                    pass
                break

            if self.pty_fd in r_fds:
                try:
                    buf = os.read(self.pty_fd, 4096)
                except OSError:
                    break

                if not buf:
                    break
                self.sock.sendall(buf)
                signal.alarm(self.alarm_timeout)

            if self.sock.fileno() in r_fds:
                try:
                    buf = telnet.read_eager()
                    os.write(self.pty_fd, buf)
                    signal.alarm(self.alarm_timeout)
                except Exception:
                    break

            if self.alive_pipes[0] in r_fds:
                # If shell dies, it can leave some subprocs which would
                # write to stdout/stderr, so we cannot wait for fd contents to finish.
                # So we had at most 4KB (pipe buffer size) of data in pty_fd, produced
                # by original process, which was read above in this loop. Now we
                # can safely ignore other pipe contents and just exit
                break

        # NOTE(review): when the loop ended because the client went away, nothing
        # here stops the container, so this wait lasts until the container exits
        # on its own or the signal.alarm() watchdog goes off -- confirm intended.
        while self.alive_pipes[0] not in r_fds:
            r_fds, _, _ = select.select([self.alive_pipes[0]], [], [], 60)

    def run(self):
        """Set up telnet, start the aliveness watcher thread and serve the session."""
        telnet = self.create_telnet()
        self._alive_thread = threading.Thread(target=self.aliveness_thread)
        self._alive_thread.daemon = True
        self._alive_thread.start()
        self.telnet_session(telnet)

    def destroy(self):
        """Best-effort release of the pty fd, the client socket and the aliveness pipe."""
        try:
            if self.pty_fd is not None:
                os.close(self.pty_fd)
        except Exception:
            pass

        try:
            if self.sock is not None:
                self.sock.close()
        except Exception:
            pass

        # The pipe fds were never closed before (a leak).  Close them only when
        # the aliveness thread can no longer write: closing earlier could make
        # its os.write() land on a reused descriptor.
        thread = getattr(self, '_alive_thread', None)
        if thread is None or not thread.is_alive():
            pipes, self.alive_pipes = self.alive_pipes, ()
            for fd in pipes:
                try:
                    os.close(fd)
                except Exception:
                    pass

    def set_winsize(self, rows, cols, xpixel=0, ypixel=0):
        """Set the pty window size via TIOCSWINSZ (struct winsize: rows, cols, xpixel, ypixel)."""
        s = struct.pack('HHHH', rows, cols, xpixel, ypixel)
        fcntl.ioctl(self.pty_fd, termios.TIOCSWINSZ, s)


def make_send_adapter(sockobj, api_mode):
    """Return the callable used to push one message to *sockobj*.

    Without *api_mode* this is simply ``sockobj.sendall``.  In *api_mode* the
    socket is wrapped in a SocketStream and its ``writeBEStr`` method is
    returned (presumably length-prefixed framing -- see SocketStream).
    """
    if not api_mode:
        return sockobj.sendall
    return SocketStream(sockobj).writeBEStr


class Stream(object):
    """Buffer output of one stream and forward it through a send callable.

    Modes (chosen at construction):

    * ``direct=True``: every non-empty chunk given to :meth:`feed` is sent
      immediately, without buffering.
    * ``streaming=True`` (default): complete lines are sent as soon as their
      newline arrives; an unterminated line is kept until it is completed, or
      partially flushed once ``MAX_BUFFERED`` bytes have accumulated.
    * ``streaming=False``: everything is kept until :meth:`finish`.

    With ``api_mode`` every message is wrapped into the msgpack map
    ``{'error': False, 'output': mark, 'data': ...}`` and complete lines are
    sent without their trailing newline (kept that way for compatibility
    with existing consumers).

    :param mark: label stored in the ``'output'`` field of api-mode messages
    :param sock: socket used to build the send callable when *sendfun* is omitted
    :param sendfun: callable taking one string; defaults to make_send_adapter(sock, api_mode)
    """
    __slots__ = ['stream', 'len', 'sock', 'streaming', 'api_mode', 'mark', 'direct', 'sendfun']

    # Once buffered newline-less data reaches this many bytes, the oldest
    # chunks are flushed until less than this remains.
    MAX_BUFFERED = 1 << 14
    # During such a flush, popped chunks are grouped into messages of at least
    # this many bytes (the last group may be smaller).
    MIN_FLUSH = 1 << 11

    def __init__(self, mark, sock=None, streaming=True, direct=False, api_mode=False, sendfun=None):
        self.mark = mark
        self.sock = sock  # declared in __slots__; previously left unset
        self.stream = deque()  # buffered chunks, oldest first
        self.len = 0  # total length of the chunks in self.stream
        self.streaming = streaming
        self.api_mode = api_mode
        self.direct = direct
        self.sendfun = sendfun or make_send_adapter(sock, api_mode)

    def _send(self, data):
        """Send *data* (a string, or a sequence of pieces that is joined first)."""
        if not isinstance(data, str):
            data = ''.join(data)

        if self.api_mode:
            data = msgpack.dumps({
                'error': False,
                'output': self.mark,
                'data': data,
            })

        self.sendfun(data)

    def feed(self, data):
        """Accept a new chunk of output and send whatever the mode allows."""
        if not data:
            return

        if self.direct:
            self._send(data)
            return

        self.stream.append(data)
        self.len += len(data)

        if not self.streaming:
            return

        if '\n' in data:
            self._flush_lines()
        else:
            self._flush_overflow()

    def _flush_lines(self):
        """Send every complete line and keep only the trailing partial line."""
        # In streaming mode older chunks never contain a newline (it would have
        # been flushed by an earlier feed()), so only the newest chunk is split.
        # Splitting once avoids re-slicing the remainder for every line.
        parts = self.stream.pop().split('\n')
        first_line = list(self.stream)
        first_line.append(parts[0])
        rest = parts[-1]

        # Update the buffer before sending so it stays consistent even if the
        # send callable raises.  No empty remainder is kept, so an empty deque
        # really means "nothing pending".
        self.stream.clear()
        if rest:
            self.stream.append(rest)
        self.len = len(rest)

        # if we'd append in api_mode we'd have to change cqudp code,
        # so keep backward compatibility
        eol = '' if self.api_mode else '\n'
        first_line.append(eol)
        self._send(first_line)
        for line in parts[1:-1]:
            self._send(line + eol)

    def _flush_overflow(self):
        """Flush the oldest chunks while an unterminated line is too long."""
        batch = []
        batch_len = 0
        while self.len >= self.MAX_BUFFERED:
            chunk = self.stream.popleft()
            self.len -= len(chunk)
            batch.append(chunk)
            batch_len += len(chunk)
            if batch_len >= self.MIN_FLUSH:
                self._send(batch)
                batch = []
                batch_len = 0
        if batch:
            self._send(batch)

    def finish(self):
        """Send everything still buffered as one message (if non-empty)."""
        msg, self.stream = ''.join(self.stream), deque()
        self.len = 0
        if msg:
            self._send(msg)


class SocketOutHandler(object):
    """Just take stdout and stderr from porto and send to sock

    The container's stdout/stderr descriptors are read and forwarded to
    ``sock`` through two Stream buffers marked ``'stdout'`` and ``'stderr'``.
    Anything the peer sends on ``sock`` is read and discarded; an empty recv
    (peer closed) stops the handler.
    """
    def __init__(self, sock, container, log, stdout, stderr, api_mode=False, streaming=True):
        self.sock = sock
        self.job = container
        self.log = log
        self.timeout = 3600  # SIGALRM watchdog, re-armed around every select()
        self.outfd = stdout
        self.errfd = stderr
        # A missing descriptor counts as already closed (same convention as
        # InOutHandler); otherwise None would be passed to select() and raise.
        self.out_closed = stdout is None
        self.err_closed = stderr is None
        self.out_stream = Stream(sock=self.sock, mark='stdout', streaming=streaming, api_mode=api_mode)
        self.err_stream = Stream(sock=self.sock, mark='stderr', streaming=streaming, api_mode=api_mode)

    @property
    def pending_out(self):
        """True if stdout data is buffered and not yet sent."""
        # Look at chunk contents, not just the deque: Stream may keep an empty
        # remainder after flushing a complete line.
        return any(self.out_stream.stream)

    @property
    def pending_err(self):
        """True if stderr data is buffered and not yet sent."""
        return any(self.err_stream.stream)

    def _check_outputs(self):
        """Wait up to 1s for readable descriptors and process whatever is ready."""
        signal.alarm(self.timeout)

        rlist = []
        if not self.out_closed:
            rlist.append(self.outfd)
        if not self.err_closed:
            rlist.append(self.errfd)
        rlist.append(self.sock)
        ready, _, _ = select.select(rlist, [], [], 1.0)

        signal.alarm(self.timeout)

        if self.sock in ready:
            # Input from the peer is drained and ignored; only EOF matters.
            try:
                data = self.sock.recv(1024)
                if not data:
                    # socket is closed, so die
                    self.sock = None
                    return
            except EnvironmentError as e:
                if e.errno not in (errno.EINTR, errno.EAGAIN):
                    raise

        for fd, name, stream in ((self.outfd, 'out', self.out_stream), (self.errfd, 'err', self.err_stream)):
            if fd in ready:
                data = os.read(fd, 16384)
                if not data:
                    self.log.debug('got 0 bytes from %r, closing', name)
                    setattr(self, name + '_closed', True)
                    setattr(self, name + 'fd', None)
                else:
                    self.log.debug('got %d bytes from %r', len(data), name)
                    stream.feed(data)

        signal.alarm(self.timeout)

    def run(self):
        """Forward output until the peer closes, or the job is dead and both pipes hit EOF."""
        signal.alarm(self.timeout)
        while self.sock is not None and (
            self.job.GetData("state") != "dead"
            or not self.out_closed
            or not self.err_closed
        ):
            self._check_outputs()

        self.err_stream.finish()
        self.out_stream.finish()


class WatchHandler(object):
    """Poll a container's stdout/stderr through porto and relay new output to a channel.

    Output is fetched with ``GetData`` by absolute offset, so each poll only
    forwards data that appeared since the previous one.  Polling stops when
    the container is dead or the channel's transport is no longer alive.
    """
    def __init__(self, channel, container, log):
        self.channel, self.job, self.log = channel, container, log
        self.timeout, self.sock_timeout = 3600, 300
        self.should_stop = False
        # Direct streams: whatever is fed is sent at once.
        self.stdout = Stream(mark='stdout', direct=True, sendfun=channel.sendall)
        self.stderr = Stream(mark='stderr', direct=True, sendfun=channel.sendall_stderr)
        # Offset just past the last byte forwarded; -1 means nothing read yet.
        self.stdout_offset = self.stderr_offset = -1

    def _watch_sock(self):
        """Poll once a second until the channel's transport dies, then request a stop."""
        while True:
            if not self.channel.get_transport().is_alive():
                break
            time.sleep(1)
        self.should_stop = True

    def _check_output(self, outtype):
        """Forward *outtype* ('stdout' or 'stderr') output produced since the last poll."""
        offset_key = outtype + '_offset'
        sink = getattr(self, outtype)

        # Presumably the offset of the oldest byte porto still retains -- a
        # value past our position means some output was dropped unseen.
        oldest = int(self.job.GetData(offset_key))
        position = getattr(self, offset_key)

        if position >= 0 and oldest > position:
            self.log.info("%s new offset %s is greater than old end %s, at least %s bytes lost", outtype, oldest, position, oldest - position)
        position = max(position, oldest)

        chunk = self.job.GetData("%s[%d]" % (outtype, position))
        setattr(self, offset_key, position + len(chunk))

        sink.feed(chunk)

    def run(self):
        """Relay output until the container dies or the channel goes away, then poll once more."""
        signal.alarm(self.timeout)
        daemonthr(self._watch_sock)
        while self.job.GetData("state") != "dead" and not self.should_stop:
            for outtype in ('stderr', 'stdout'):
                signal.alarm(self.timeout)
                self._check_output(outtype)

            time.sleep(1)
            signal.alarm(self.timeout)

        for outtype in ('stderr', 'stdout'):
            self._check_output(outtype)
        for stream in (self.stderr, self.stdout):
            stream.finish()


class InOutHandler(object):
    """Proxy out/err from process to SSH channel and stdin from channel to process

    Data moves through three queues: ``pending_in`` (channel -> stdin) and
    ``pending_out`` / ``pending_err`` (process -> channel).  A descriptor passed
    as None is treated as already closed.
    """
    def __init__(self, log, container, channel, stdin, stdout, stderr):
        self.log = log
        self.job = container
        self.channel = channel
        self.timeout = 3600  # SIGALRM watchdog, re-armed throughout each round
        self.infd = stdin
        self.outfd = stdout
        self.errfd = stderr

        self.pending_in = deque()
        self.pending_out = deque()
        self.pending_err = deque()

        self.in_closed = False if stdin is not None else True
        self.in_channel_closed = False if stdin is not None else True
        self.out_closed = False if stdout is not None else True
        self.err_closed = False if stderr is not None else True

    def _check_outputs(self):
        """Run one select() round (1s timeout) and move at most one chunk per queue."""
        signal.alarm(self.timeout)

        # Back-pressure: a source is only polled once its queue is drained.
        r = (
            [self.outfd] if not self.pending_out and not self.out_closed else []) + (
            [self.errfd] if not self.pending_err and not self.err_closed else []) + (
            [self.channel] if not self.pending_in and not self.in_channel_closed else []
            )
        w = [self.infd] if self.pending_in and not self.in_closed else []
        r, w, _ = select.select(r, w, [], 1.0)

        signal.alarm(self.timeout)

        if self.channel in r:
            data = self.channel.recv(16384)

            if not data:
                self.log.debug('got 0 bytes from channel, closing')
                self.in_channel_closed = True
            else:
                self.pending_in.append(data)
                self.log.debug('got %d bytes from channel', len(data))

        for fd, name, queue in ((self.outfd, 'out', self.pending_out), (self.errfd, 'err', self.pending_err)):
            if fd in r:
                data = os.read(fd, 16384)
                if not data:
                    self.log.debug('got 0 bytes from %r, closing', name)
                    setattr(self, name + '_closed', True)
                else:
                    self.log.debug('got %d bytes from %r', len(data), name)
                    queue.append(data)

        signal.alarm(self.timeout)

        if self.infd in w:
            try:
                self.log.debug('sending %d bytes to stdin', len(self.pending_in[0]))
                data = self.pending_in.popleft()
                written = 0
                while True:
                    n = os.write(self.infd, data[written:])
                    written += n
                    if written == len(data):
                        break
            except EnvironmentError as e:
                if e.errno == errno.EPIPE:
                    # The process closed its stdin: queued input can never be
                    # delivered.  Close our end too (it used to leak) and drop
                    # the queue.
                    self.log.debug('broken pipe, closing stdin')
                    os.close(self.infd)
                    self.in_closed = True
                    self.in_channel_closed = True
                    self.pending_in.clear()
                elif e.errno == errno.EAGAIN:
                    # Requeue the unwritten tail in front of later chunks.
                    self.pending_in.appendleft(data[written:])
                else:
                    raise

        if self.pending_err:
            self.log.debug('sending %d bytes to channel stderr', len(self.pending_err[0]))
            self.channel.sendall_stderr(self.pending_err.popleft())

        if self.pending_out:
            self.log.debug('sending %d bytes to channel stdout', len(self.pending_out[0]))
            self.channel.sendall(self.pending_out.popleft())

        if self.in_channel_closed and not self.pending_in and not self.in_closed:
            self.log.debug('no pending stdin data, closing stdin')
            os.close(self.infd)
            self.in_closed = True
        signal.alarm(self.timeout)

    def run(self):
        """Proxy until the channel closes, or the job is dead and every stream is drained and closed."""
        signal.alarm(self.timeout)
        dead = False
        while not self.channel.closed and (
            not dead
            or self.pending_out or not self.out_closed
            or self.pending_err or not self.err_closed
            or not self.in_closed
        ):
            dead = dead or (self.job.GetData("state") == "dead")
            if dead:
                # Nobody will read stdin any more: stop accepting input and let
                # _check_outputs() close the pipe.  Keep pending_in a deque.
                self.in_channel_closed = True
                self.pending_in.clear()

            self._check_outputs()
