from __future__ import absolute_import
import os
import pwd
import sys
import stat
import time
import uuid
import errno
import fcntl
import select
import signal
import socket
import struct
import tempfile
import traceback
import threading
from SocketServer import BaseRequestHandler

import six
import msgpack
import paramiko

from ya.skynet.util.misc import daemonthr

from .authenticate import authenticate, format_key
from .slots.slot_sources import find_slot
from .slots.exceptions import AuthError, SlotLookupError, ConfigurationLookupError
from .authsyslog import AuthSyslog
from .connserver import ReverseTcpForwardServer, forward
from .filesystem import futime, safe_chdir
from .portotools import get_portoconn
from .shell import Context, StartupException, check_for_scp
from .heartbeat import schedule_report
from .utils import suppress
from . import logger

from infra.skylib.porto import get_container_user_and_group
from infra.skylib import safe_container_actions


if sys.version_info.major >= 3:
    def b(s):
        """Return *s* as ``bytes``: bytes pass through, everything else is ``str()``-ed and UTF-8 encoded."""
        if isinstance(s, bytes):
            return s
        return str(s).encode('utf-8')
else:
    def b(s):
        """Return *s* as a byte string: ``unicode`` is UTF-8 encoded, anything else goes through ``str()``."""
        if isinstance(s, unicode):  # noqa
            return s.encode('utf-8')
        return str(s)


stat_type = os.stat_result  # class of os.stat() results; rebuilds stat tuples received from the SFTP worker


class AgentLocalProxy(threading.Thread):
    """Relay bytes between one local agent-socket client and one forwarded agent channel.

    ``conn`` is the channel opened by ``Agent.connect`` and ``sock`` the accepted
    unix-socket connection. ``agent`` is the owning Agent: its ``_exit`` flag stops
    the relay, and this thread removes itself from ``agent.workers`` when it ends.
    """

    def __init__(self, agent, conn, sock, log):
        # run() is overridden, so no Thread target is needed.
        threading.Thread.__init__(self)
        self.log = log
        self._agent = agent
        self.conn = conn
        self.sock = sock

    def _conn_id(self):
        """Return the channel id for log messages, or None when ``conn`` has none."""
        get_id = getattr(self.conn, 'get_id', None)
        return get_id() if get_id is not None else None

    def run(self):
        try:
            if not isinstance(self._agent, int) and (self.conn is None or not hasattr(self.conn, 'fileno')):
                self.log.error("invalid SSH agent connection")
                raise paramiko.AuthenticationException("Unable to connect to SSH agent")
            self._communicate()
        except Exception as e:
            # Any failure only ends this relay; _conn_id() is safe even when conn is None.
            self.log.debug("connection (channel %r) finished with: %s", self._conn_id(), e)
        finally:
            self._close()

    def _communicate(self):
        """Pump data both ways until either side reaches EOF or the agent is closed."""
        oldflags = fcntl.fcntl(self.sock, fcntl.F_GETFL)
        fcntl.fcntl(self.sock, fcntl.F_SETFL, oldflags | os.O_NONBLOCK)
        while not self._agent._exit:
            readable = select.select([self.conn, self.sock], [], [], 1.)[0]
            for fd in readable:
                if fd is self.conn:
                    data = self.conn.recv(512)
                    self.log.debug("got %d bytes from channel %r", len(data), self.conn.get_id())
                    if not data:
                        # Channel EOF. A plain `break` only left the for-loop, so the
                        # relay spun on the EOF'd channel until the agent exited.
                        return
                    # sendall: send() may accept only part of the buffer.
                    self.sock.sendall(data)
                elif fd is self.sock:
                    data = self.sock.recv(512)
                    if not data:
                        return
                    self.log.debug("got %d bytes from socket", len(data))
                    self.conn.sendall(data)
            time.sleep(0.01)

    def _close(self):
        """Close both endpoints (best effort each) and deregister from the agent."""
        for endpoint in (self.sock, self.conn):
            if endpoint is None:
                continue
            try:
                endpoint.close()
            except Exception as e:
                self.log.debug("failed to close %r: %s", endpoint, e)
        self._agent.workers.discard(self)


class Agent(object):
    """
    Custom AgentServerProxy reimplementation.
    The standard one is broken in a couple of places

    Listens on a private unix socket (``<tempdir>/sshproxy.sock``). For every local
    client it opens a forwarded-agent channel on ``transport`` and relays between
    them with an AgentLocalProxy thread.
    """
    def __init__(self, transport, log):
        self.log = log
        self.transport = transport
        self.dirpath = None   # private temp directory holding the socket
        self.sockpath = None  # path of the listening unix socket
        self.user = None      # (uid, gid) owning dirpath/socket when start() got a user

        self.sock = None
        self._exit = False    # stop flag polled by the accept loop and the relay threads
        self.thread = threading.Thread(target=self.accept_connections, name='AgentMainLoop')
        self.thread.daemon = True
        self.workers = set()  # live AgentLocalProxy threads

    def __del__(self):
        self.close()

    def get_filename(self):
        """Return the listening socket path (None before start() / after close())."""
        return self.sockpath

    def connect(self):
        """Open a new forwarded-agent channel to the client.

        :raises paramiko.SSHException: if the client refuses the channel.
        """
        chan = self.transport.open_forward_agent_channel()
        if chan is None:
            raise paramiko.SSHException('lost ssh-agent')
        chan.set_name('auth-agent')
        return chan

    def start(self, user=None, basedir=None):
        """Create the private directory and listening socket, then start accepting.

        :param user: optional local user name; the directory and socket are chowned to it.
        :param basedir: parent for the temp directory (tempfile's default when None).

        Parts created by an earlier call (directory, socket) are not recreated.
        """
        if self.dirpath is None:
            self.dirpath = tempfile.mkdtemp(dir=basedir)
            if user:
                target_user = pwd.getpwnam(user)
                self.user = (target_user.pw_uid, target_user.pw_gid)
                os.chown(self.dirpath, target_user.pw_uid, target_user.pw_gid)

            os.chmod(self.dirpath, stat.S_IRWXU)

        if self.sockpath is None:
            self.sockpath = os.path.join(self.dirpath, "sshproxy.sock")
            self.sock = socket.socket(socket.AF_UNIX)
            self.log.debug("binding to %r", self.get_filename())
            sock_dir, sock_filename = os.path.split(self.get_filename())
            # bind/chown by relative name from inside the directory
            with safe_chdir(sock_dir):
                self.sock.bind(sock_filename)
                if self.user:
                    os.chown(sock_filename, *self.user)
            self.sock.listen(1)
            self.thread.start()

    def accept_connections(self):
        """Accept loop run on ``self.thread`` until close() sets ``_exit``."""
        while not self._exit:
            try:
                readable = select.select([self.sock], [], [], 5)[0]
                if self.sock not in readable:
                    continue
                r, _ = self.sock.accept()
            except (select.error, socket.error, ValueError):
                if self._exit:
                    # close() shut the listening socket under us: normal shutdown.
                    return
                raise

            conn = None
            try:
                conn = self.connect()
                self.log.debug("accepted connection and mapped to channel %r", conn.get_id())
                thread = AgentLocalProxy(self, conn, r, self.log.getChild('conn'))
                thread.daemon = True
                self.workers.add(thread)
                thread.start()
            except Exception:
                self.log.exception("agent initialization failed:")
                r.close()
                if conn is not None:
                    conn.close()

    def close(self):
        """Stop accepting, remove the socket and directory, and wait for the threads.

        Idempotent: safe to call repeatedly (``__del__`` calls it again after an
        explicit close), and it never joins the calling thread.
        """
        self._exit = True
        sockpath = self.sockpath
        dirpath = self.dirpath
        if self.sock is not None:
            self.sock.close()
        if sockpath and os.path.exists(sockpath):
            os.unlink(sockpath)
        self.sockpath = None

        if dirpath:
            os.rmdir(dirpath)
        self.dirpath = None

        current = threading.current_thread()
        thread, self.thread = self.thread, None
        if thread is not None and thread is not current and thread.is_alive():
            thread.join(1000)
        for worker in list(self.workers):
            if worker is not current and worker.is_alive():
                worker.join(1000)
        self.workers = set()


def make_token():
    """Return a new login-token dict filled with defaults.

    Keys: ``acc_user`` (None), ``ssh_user`` (''), ``lookup_args`` (a fresh empty
    dict per call), ``watch_parent`` (False) and ``inactivity_timeout``
    (604800, i.e. one week assuming the unit is seconds).
    """
    one_week = 7 * 24 * 60 * 60
    return {
        'acc_user': None,
        'ssh_user': '',
        'lookup_args': {},
        'watch_parent': False,
        'inactivity_timeout': one_week,
    }


def wrap_oserror(fn):
    """Decorate *fn* so OS-level errors come back as SFTP status codes.

    An error raised by *fn* is translated with
    ``paramiko.sftp_server.SFTPServer.convert_errno(e.errno)`` and returned instead
    of propagating. Otherwise *fn*'s result is returned unchanged. The wrapper keeps
    *fn*'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except EnvironmentError as e:
            # EnvironmentError also covers IOError raised by file objects on
            # Python 2; on Python 3 it is simply an alias of OSError.
            return paramiko.sftp_server.SFTPServer.convert_errno(e.errno)
    return wrapped


class SFTPHandle(paramiko.SFTPHandle):
    """Worker-side SFTP handle wrapping an already-open file object.

    The one file object is exposed as both ``readfile`` and ``writefile``, the
    attributes paramiko's default handle implementation looks up. ``parent`` is the
    SFTPWorker that applies attribute changes for ``chattr``.
    """

    def __init__(self, parent, fileobj, flags=0):
        super(SFTPHandle, self).__init__(flags)
        self.parent = parent
        self.fileobj = fileobj

    def _get_fileobj(self):
        return self.fileobj

    readfile = property(_get_fileobj)
    writefile = property(_get_fileobj)

    @wrap_oserror
    def chattr(self, attr):
        """Apply *attr* to the open file through the parent's ``set_fd_attr``."""
        descriptor = self.fileobj.fileno()
        self.parent.set_fd_attr(descriptor, attr)
        return paramiko.SFTP_OK

    @wrap_oserror
    def stat(self):
        """Return the packed ``fstat`` attributes of the open file."""
        stat_result = os.fstat(self.fileobj.fileno())
        attributes = paramiko.SFTPAttributes.from_stat(stat_result)
        return _pack_attr(attributes)


class SFTPHandleProxy(paramiko.SFTPHandle):
    """Parent-side SFTP handle whose operations execute in the SFTP worker.

    ``worker`` is any object with an ``apply(command, *args)`` method
    (SFTPServerInterface passes itself). ``fileno`` is the worker-side descriptor
    returned by the worker's ``open`` and is sent as the first argument of every
    ``handle_*`` command.
    """

    def __init__(self, worker, fileno, flags=0):
        self.worker = worker
        self.fileno = fileno
        super(SFTPHandleProxy, self).__init__(flags=flags)

    def _forward(self, command, *args):
        """Run the worker *command* for this handle and return its raw reply."""
        return self.worker.apply(command, self.fileno, *args)

    def close(self):
        self._forward('handle_close')

    def read(self, offset, length):
        return self._forward('handle_read', offset, length)

    def write(self, offset, data):
        return self._forward('handle_write', offset, data)

    def stat(self):
        reply = self._forward('handle_stat')
        # An int reply is an SFTP status code; anything else is a packed attribute set.
        return reply if isinstance(reply, int) else _unpack_attr(reply)

    def chattr(self, attr):
        return self._forward('handle_chattr', _pack_attr(attr))


# Length prefix of every parent <-> SFTP-worker frame: native-endian signed int
# (both peers run on the same host), followed by a msgpack payload.
_SFTP_FRAME_HEADER = struct.Struct('i')


def _sftp_recv_exactly(sock, size):
    """Read exactly *size* bytes from stream socket *sock*.

    ``recv`` may return fewer bytes than requested, so keep reading until the whole
    chunk has arrived.

    :raises EOFError: if the peer closes the connection first.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise EOFError("remote side disconnected")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _sftp_send_frame(sock, item):
    """Send *item* msgpack-encoded and length-prefixed over *sock*."""
    data = msgpack.dumps(item)
    sock.sendall(_SFTP_FRAME_HEADER.pack(len(data)) + data)


def _sftp_recv_frame(sock):
    """Receive and decode one frame written by ``_sftp_send_frame``."""
    header = _sftp_recv_exactly(sock, _SFTP_FRAME_HEADER.size)
    length = _SFTP_FRAME_HEADER.unpack(header)[0]
    return msgpack.loads(_sftp_recv_exactly(sock, length))


class SFTPServerInterface(paramiko.SFTPServerInterface):
    """paramiko SFTP server interface that runs every file operation in a worker process.

    The worker is spawned through ``safe_container_actions.run_process`` as the user
    resolved for the slot's container and the token's ``ssh_user``.
    NOTE(review): presumably it runs inside the container given by its ``root_pid``;
    confirm against run_process.
    Each SFTP call is sent to it as a length-prefixed msgpack ``(command, args)`` frame
    and answered with exactly one reply frame. An int reply is an SFTP status code.
    """

    def __init__(self, server, ns_container, *args, **kwargs):
        self.sock = None
        self.pid = None
        self.ns_container = ns_container  # stored only; not used by this class
        self.server = server
        container = server.slot.container

        portoconn = get_portoconn(False)
        target_user, _ = get_container_user_and_group(portoconn, container, server.token['ssh_user'])
        root_pid = int(portoconn.Find(container).GetData('root_pid'))

        self.worker = SFTPServerInterface.SFTPWorker()
        self.sock, self.pid = safe_container_actions.run_process(self.worker.run, target_user, root_pid)
        server.log.info("spawned sftp worker with pid %r", self.pid)

        super(SFTPServerInterface, self).__init__(server, *args, **kwargs)

    def session_ended(self):
        """Ask the worker to exit, close the channel to it and SIGKILL it.

        The kill happens even if notifying the worker fails. Socket and pid are
        forgotten first, so repeated calls never signal a (possibly recycled) pid twice.
        """
        sock, self.sock = self.sock, None
        pid, self.pid = self.pid, None
        try:
            if sock is not None:
                try:
                    # Framed like every other request; the former unframed message
                    # was misparsed by the worker's reader.
                    _sftp_send_frame(sock, ('shutdown', ()))
                except socket.error as e:
                    self.server.log.info("worker disconnected: %s", e)
                finally:
                    sock.close()
        finally:
            if pid is not None:
                try:
                    os.kill(pid, signal.SIGKILL)
                except OSError as e:
                    if e.errno != errno.ESRCH:
                        raise

    def apply(self, fn, *args):
        """Run worker command *fn* with *args* and return its decoded reply.

        :returns: the worker's reply, or ``paramiko.SFTP_CONNECTION_LOST`` if the
            worker is gone.
        """
        sock = self.sock
        if sock is None:
            self.server.log.info("worker disconnected")
            return paramiko.SFTP_CONNECTION_LOST

        try:
            _sftp_send_frame(sock, (fn, args))
        except socket.error as e:
            self.server.log.info("worker disconnected: %s", e)
            return paramiko.SFTP_CONNECTION_LOST

        try:
            return _sftp_recv_frame(sock)
        except (socket.error, EOFError) as e:
            self.server.log.info("worker disconnected: %s", e)
            return paramiko.SFTP_CONNECTION_LOST

    class SFTPWorker(object):
        """Request loop executed inside the spawned worker process.

        Only methods decorated with ``@cmd`` can be invoked remotely. Files opened via
        ``open`` are kept in ``self.fds`` keyed by their OS descriptor.
        """

        def __init__(self):
            self.fds = {}
            self.sock = None

        def _send(self, item):
            _sftp_send_frame(self.sock, item)

        def _recv(self):
            return _sftp_recv_frame(self.sock)

        def run(self, sock):
            """Serve requests from *sock* until a ``shutdown`` request arrives."""
            self.sock = sock
            os.chdir('/')

            while True:
                try:
                    cmd, args = self._recv()
                except socket.timeout:
                    continue

                if cmd == 'shutdown':
                    sock.close()
                    os.kill(os.getpid(), signal.SIGKILL)

                fn = getattr(self, cmd, None)
                if not fn or not callable(fn) or not getattr(fn, 'cmd', False):
                    self._send(paramiko.SFTP_FAILURE)
                    # Exactly one reply per request: previously execution fell
                    # through to fn(*args) and a second reply desynced the stream.
                    continue

                try:
                    result = fn(*args)
                except Exception:
                    traceback.print_exc(file=sys.stderr)
                    self._send(paramiko.SFTP_FAILURE)
                else:
                    self._send(result)

        def cmd(fn):
            """Class-body decorator marking *fn* as remotely invocable."""
            fn.cmd = True
            return fn

        def set_fd_attr(self, fd, attr):
            """Apply mode/owner/times/size from *attr* (per its flags) to descriptor *fd*."""
            if sys.platform != 'win32':
                # mode operations are meaningless on win32
                if attr._flags & attr.FLAG_PERMISSIONS:
                    os.fchmod(fd, attr.st_mode)
                if attr._flags & attr.FLAG_UIDGID:
                    os.fchown(fd, attr.st_uid, attr.st_gid)
            if attr._flags & attr.FLAG_AMTIME:
                futime(fd, (attr.st_atime, attr.st_mtime))
            if attr._flags & attr.FLAG_SIZE:
                os.ftruncate(fd, attr.st_size)

        @cmd
        def handle_close(self, fileno):
            self.fds[fileno].close()
            del self.fds[fileno]

        @cmd
        def handle_read(self, fileno, offset, length):
            return self.fds[fileno].read(offset, length)

        @cmd
        def handle_write(self, fileno, offset, data):
            return self.fds[fileno].write(offset, data)

        @cmd
        def handle_stat(self, fileno):
            return self.fds[fileno].stat()

        @cmd
        def handle_chattr(self, fileno, attr):
            return self.fds[fileno].chattr(_unpack_attr(attr))

        @cmd
        @wrap_oserror
        def open(self, path, flags, attr):
            """Open *path* and register it; reply is ``(fd,)`` or an SFTP status code."""
            attr = _unpack_attr(attr)
            flags |= getattr(os, 'O_BINARY', 0)
            mode = getattr(attr, 'st_mode', None)
            fd = os.open(path, flags, mode if mode is not None else 0o666)
            try:
                if (flags & os.O_CREAT) and attr is not None:
                    # permissions were already applied by os.open's mode argument
                    attr._flags &= ~attr.FLAG_PERMISSIONS
                    self.set_fd_attr(fd, attr)

                if flags & os.O_WRONLY:
                    fstr = 'ab' if flags & os.O_APPEND else 'wb'
                elif flags & os.O_RDWR:
                    fstr = 'a+b' if flags & os.O_APPEND else 'r+b'
                else:
                    fstr = 'rb'

                f = os.fdopen(fd, fstr)
            except Exception:
                # don't leak the descriptor when attribute setup or fdopen fails
                os.close(fd)
                raise
            self.fds[fd] = SFTPHandle(self, f, flags)
            return (fd,)

        @cmd
        @wrap_oserror
        def list_folder(self, path):
            return [
                (tuple(os.lstat(os.path.join(path, fname))), fname)
                for fname in os.listdir(path)
            ]

        @cmd
        @wrap_oserror
        def stat(self, path):
            return (tuple(os.stat(path)), os.path.basename(path))

        @cmd
        @wrap_oserror
        def lstat(self, path):
            return (tuple(os.lstat(path)), os.path.basename(path))

        @cmd
        @wrap_oserror
        def remove(self, path):
            os.remove(path)
            return paramiko.SFTP_OK

        @cmd
        @wrap_oserror
        def rename(self, oldpath, newpath):
            # refuse to overwrite an existing target
            if os.path.exists(newpath):
                return paramiko.SFTP_FAILURE
            os.rename(oldpath, newpath)
            return paramiko.SFTP_OK

        @cmd
        @wrap_oserror
        def mkdir(self, path, attr):
            os.mkdir(path)
            if attr is not None:
                paramiko.sftp_server.SFTPServer.set_file_attr(path, _unpack_attr(attr))
            return paramiko.SFTP_OK

        @cmd
        @wrap_oserror
        def rmdir(self, path):
            os.rmdir(path)
            return paramiko.SFTP_OK

        @cmd
        @wrap_oserror
        def chattr(self, path, attr):
            paramiko.sftp_server.SFTPServer.set_file_attr(path, _unpack_attr(attr))
            return paramiko.SFTP_OK

        @cmd
        @wrap_oserror
        def readlink(self, path):
            return os.readlink(path)

        @cmd
        @wrap_oserror
        def symlink(self, target_path, path):
            os.symlink(target_path, path)
            return paramiko.SFTP_OK

    def open(self, path, flags, attr):
        result = self.apply('open', path, flags, _pack_attr(attr))
        if isinstance(result, int):
            return result
        return SFTPHandleProxy(self, *result)

    def list_folder(self, path):
        result = self.apply('list_folder', path)
        if isinstance(result, int):
            return result
        return [
            paramiko.SFTPAttributes.from_stat(stat_type(st), name)
            for st, name in result
        ]

    def stat(self, path):
        result = self.apply('stat', path)
        if isinstance(result, int):
            return result
        return paramiko.SFTPAttributes.from_stat(stat_type(result[0]), result[1])

    def lstat(self, path):
        result = self.apply('lstat', path)
        if isinstance(result, int):
            return result
        return paramiko.SFTPAttributes.from_stat(stat_type(result[0]), result[1])

    def remove(self, path):
        return self.apply('remove', path)

    def rename(self, oldpath, newpath):
        return self.apply('rename', oldpath, newpath)

    def mkdir(self, path, attr):
        return self.apply('mkdir', path, _pack_attr(attr))

    def rmdir(self, path):
        return self.apply('rmdir', path)

    def chattr(self, path, attr):
        return self.apply('chattr', path, _pack_attr(attr))

    def readlink(self, path):
        return self.apply('readlink', path)

    def symlink(self, target_path, path):
        return self.apply('symlink', target_path, path)

    # NOTE: canonicalize() is intentionally not overridden; paramiko's default is used.


class SSHHandler(BaseRequestHandler, paramiko.ServerInterface):
    def send_banner(self, msg, *args):
        """Send ``msg % args`` plus a newline to the client as an SSH userauth banner.

        Uses paramiko's private ``Transport._send_user_message`` to inject the
        SSH_MSG_USERAUTH_BANNER packet directly.
        """
        text = msg % args + '\n'
        banner = paramiko.Message()
        banner.add_byte(paramiko.common.cMSG_USERAUTH_BANNER)
        banner.add_string(text)
        banner.add_string('')  # language tag
        self.transport._send_user_message(banner)

    def send_debug(self, msg, *args):
        """Send ``msg % args`` plus a newline to the client as an SSH_MSG_DEBUG packet.

        NOTE(review): an earlier comment described the flag as "always display", but
        the byte sent is 0, i.e. ``always_display = FALSE`` (RFC 4253, section 11.3).
        Confirm which was intended.
        """
        text = msg % args + '\n'
        debug_msg = paramiko.Message()
        debug_msg.add_byte(paramiko.common.cMSG_DEBUG)
        debug_msg.add_boolean(False)  # always_display: the same single 0x00 byte as before
        debug_msg.add_string(text)
        debug_msg.add_string('')  # language tag
        self.transport._send_user_message(debug_msg)

    def _send_auth_error(self, msg, *args):
        """Show ``'Error: ' + msg`` (formatted with *args*) as a banner, once per handler.

        Clients may retry authentication several times whatever the failure reason,
        so only the first error is sent.
        """
        if not self.auth_error_sent:
            self.send_banner('Error: ' + msg, *args)
            self.auth_error_sent = True

    def check_auth_publickey(self, username, key):
        """paramiko hook: authenticate *key* for the encoded login name *username*.

        *username* is parsed by ``split_ssh_user`` into a token. The target slot is
        looked up with the token's ``lookup_args`` and checked by
        ``verify_slot_filter``. When ``self.server.check_auth`` is set, the key is
        verified by ``authenticate``. Successes and AuthError failures go to the auth
        syslog. On success, token, slot and key are stored on the handler and the
        session context is prepared. A repeated ``(username, fingerprint)`` pair that
        already succeeded is accepted without re-checking.

        :returns: ``paramiko.AUTH_SUCCESSFUL`` or ``paramiko.AUTH_FAILED``.
        """
        username = b(username)
        fp = key.get_fingerprint()
        self.log.debug("check_auth_publickey: (%r, key %r)", username, fp.encode('hex'))

        if self.authenticated_publickey == (username, fp):
            self.log.debug("publickey already successfully checked")
            return paramiko.AUTH_SUCCESSFUL

        try:
            token = self.split_ssh_user(username)
            if token is None:
                self._send_auth_error(
                    "malformed username: see https://wiki.yandex-team.ru/Skynet/Services/portoshell/#how-to-use-ssh"
                )
                return paramiko.AUTH_FAILED

        except Exception as e:
            self.log.exception("cannot extract params from login name")
            self._send_auth_error("cannot extract params from login name: %s", e)
            return paramiko.AUTH_FAILED

        # Bound before the lookup so the AuthError handler below can always report
        # them; previously an AuthError raised before `slot` was assigned turned
        # into an UnboundLocalError inside the handler.
        login_user = token['ssh_user'] or token['acc_user']
        slot = None

        try:
            slot = find_slot(iss_enabled=self.server.iss, **token['lookup_args'])
            # A rejection here raises AuthError, handled (syslog + warning) below.
            self.verify_slot_filter(slot, token)

            if self.server.check_auth:
                key = authenticate(self.log.getChild('auth'),
                                   user=login_user,
                                   slot_info=slot,
                                   fingerprint=fp,
                                   keys_storage=self.server.keys_storage,
                                   ca_storage=self.server.ca_storage,
                                   maybe_certificate=key,
                                   )
                token['acc_user'] = next(iter(key.userNames), None)

            self.ctx.sessionleader_session_id = str(uuid.uuid4())

            with AuthSyslog('portoshell') as syslog:
                syslog.info(
                    "%s@%s authenticated as %s by %s into %s, session_id=%s" % (
                        token['acc_user'],
                        self.client_address,
                        login_user,
                        format_key(key),
                        slot.as_auth_info(),
                        self.ctx.sessionleader_session_id,
                    ))
            for env_key, env_val in slot.get_env_vars():
                self.env.setdefault(env_key, env_val)

            self.ctx.sessionleader_user = token['acc_user']

            if type(slot).__name__ == 'YpSlot' and token['ssh_user'] not in ('root', 'nobody'):
                self.ctx.enable_shellwrapper = True
                self.send_banner("   Welcome to YP. You are logged as user %r.", token['ssh_user'])
                self.send_banner("   To login as root,   add '-l root' option to your SSH command.")
                self.send_banner("   To login as nobody, add '-l nobody' option to your SSH command.")
                self.send_banner("")
            elif type(slot).__name__ == 'YpSlot' and token['ssh_user'] == 'nobody':
                self.send_banner("   NOTE: you are logging as a shared 'nobody' user. That means, that")
                self.send_banner("         other developers of this service may be able to access")
                self.send_banner("         forwarded ssh-agent (if any) or any private tokens you use.")
                self.send_banner("         Please, be careful.")
                self.send_banner("")
        except AuthError as e:
            slot_info = slot.as_auth_info() if slot is not None else None
            with AuthSyslog('portoshell') as syslog:
                syslog.info(
                    "%s@%s failed authentication as %s to %s: %s" % (
                        token['acc_user'],
                        self.client_address,
                        login_user,
                        slot_info,
                        e))
            self.log.warning('auth failure: %s', e)
        except (SlotLookupError, ConfigurationLookupError) as e:
            self.log.warning('not found slot %r for %s@%s',
                             token['lookup_args'], token['acc_user'], self.client_address)
            self._send_auth_error('%s', e)
        except Exception as e:
            self.log.exception('unexpected error during authentication: %s', e)
            self._send_auth_error('unexpected error during authentication: %s', e)
        else:
            self.token = token
            self.slot = slot
            self.authenticated_publickey = (username, fp)
            self.authentication_key = key
            self.ctx.user = token['ssh_user']
            self.ctx.telnet_timeout = token['inactivity_timeout']
            self.ctx.setup_motd(slot, self.server.hostname)
            return paramiko.AUTH_SUCCESSFUL

        return paramiko.AUTH_FAILED

    def check_channel_direct_tcpip_request(self, chanid, origin, destination):
        """Refuse every direct-tcpip channel; the request is only logged."""
        self.log.debug("check_channel_direct_tcpip_request: (%r, %r, %r)", chanid, origin, destination)
        refusal = paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
        return refusal

    def check_channel_env_request(self, channel, name, value):
        """Accept a client environment variable unless it is SHELL or a SUDO_* identity variable.

        Accepted variables are handed to ``self.ctx.set_env`` as byte strings.
        """
        name, value = b(name), b(value)
        self.log.debug("check_channel_env_request: (channel %r, %r, %r)", channel.get_id(), name, value)
        blocked = (b'SHELL', b'SUDO_USER', b'SUDO_UID', b'SUDO_GID')
        if name not in blocked:
            self.ctx.set_env(channel, name, value)
            return True
        return False

    def _run_server(self, channel, server, session):
        """Drive *server* for *channel* until it finishes, then report and tear down.

        Connection-reset/broken-pipe socket errors are treated as a normal client
        disconnect. Afterwards the container's ``exit_code`` (255 if unreadable) is
        sent as the exit status, the container is destroyed unless the login used
        ``watch_parent``, and the channel and its session are closed.
        """
        try:
            server.run()
        except socket.error as e:
            if e.errno not in (errno.ECONNRESET, errno.EPIPE, errno.ECONNABORTED):
                raise
        finally:
            try:
                # ssh doesn't understand full exit status, so we return just exit code
                exit_status = int(session.container.GetData('exit_code'))
                self.log.debug("sending exit code: %s", exit_status)
            except Exception:
                # was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
                self.log.warning("failed to get container exit_status, sending 255")
                exit_status = 255

            suppress(channel.send_exit_status, exit_status)
            if not self.token.get('watch_parent'):
                suppress(session.container.Destroy)

            try:
                channel.shutdown(2)
            except EnvironmentError as e:
                if e.errno != errno.ENOTCONN:
                    self.log.warning("failed to shutdown connection to client: %s", e)
            try:
                channel.close()
            except Exception as e:
                self.log.warning("failed to close channel %s: %s", channel.get_id(), e)

            self.ctx.close_session(channel)

    def check_channel_forward_agent_request(self, channel):
        """Enable ssh-agent forwarding on *channel* via ``self.ctx.make_agent(channel, Agent)``.

        Refused on unauthenticated transports, for ``watch_parent`` logins, and when
        creating the agent raises.
        """
        channel_id = channel.get_id()
        self.log.debug("check_channel_forward_agent_request: (channel %r)", channel_id)

        authenticated = channel.get_transport().is_authenticated()
        if not authenticated:
            self.log.warning("agent requested without authentication on channel %s", channel_id)
            return False

        if self.token.get('watch_parent'):
            self.log.debug('denying forward agent because of watch parent request')
            return False

        try:
            self.ctx.make_agent(channel, Agent)
        except Exception as e:
            self.log.debug('forward agent start failed: %s', e)
            return False
        else:
            return True

    def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, modes):
        """Record the client's terminal type and window size for *channel*.

        Refused on unauthenticated transports and for ``watch_parent`` logins.
        Terminal *modes* are ignored.
        """
        term = b(term)
        self.log.debug("check_channel_pty_request: (channel %r, %r, %r, %r, %r, %r, {modes ignored})",
                       channel.get_id(), term, width, height, pixelwidth, pixelheight)
        if not channel.get_transport().is_authenticated():
            self.log.warning("shell requested without authentication on channel %s", channel.get_id())
            return False

        if self.token.get('watch_parent'):
            self.log.debug('denying pty because of watch parent request')
            return False

        # modes are ignored
        self.env[b'TERM'] = term
        # xpixel/ypixel are the pixel width/height; they used to be passed swapped,
        # unlike the window-change handler, which maps the pixel width to xpixel.
        self.ctx.set_winsize(channel=channel, rows=height, cols=width, xpixel=pixelwidth, ypixel=pixelheight)
        return True

    def check_channel_request(self, kind, chanid):
        """Allow only ``session`` channels; every other channel kind is refused."""
        kind = b(kind)
        self.log.debug("check_channel_request: (%r, %r)", kind, chanid)
        # Compare bytes with bytes, as check_channel_env_request does: b() yields
        # bytes, so a str literal would never match on Python 3.
        if kind == b'session':
            return paramiko.OPEN_SUCCEEDED
        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_channel_exec_request(self, channel, command):
        """Run *command* in a new child container of the authenticated slot.

        The child container is named ``<slot container>/ps-<pid>-<channel id>``, and
        ``check_for_scp`` may rewrite the command first. A usage report is scheduled
        on a background thread. On success, ``_run_server`` drives the session on
        another daemon thread.

        :returns: True if the command was started, False otherwise (the failure
            reason is written to the channel's stderr).
        """
        command = b(command)
        self.log.debug("check_channel_exec_request: (channel %r, %r)", channel.get_id(), command)

        if not channel.get_transport().is_authenticated():
            self.log.warning("shell requested without authentication on channel %s", channel.get_id())
            return False

        if self.token.get('watch_parent'):
            self.log.debug('denying exec because of watch parent request')
            return False

        try:
            # 60s watchdog around startup (the SIGALRM handler is installed elsewhere).
            # Disarmed in `finally`: a failed start used to leave the alarm pending.
            signal.alarm(60)
            try:
                container_name = self.slot.container
                command = check_for_scp(command, container_name)

                session = self.ctx.make_session(channel)
                session.tag = '%s@%s' % (self.token['acc_user'], self.client_address[0])
                session.container_name = "{}/ps-{}-{}".format(container_name, os.getpid(), channel.get_id())
                session.extra_env[:0] = self.env.items()
                session.cwd = self.slot.instance_dir
                session.cmd = command
                daemonthr(
                    schedule_report,
                    self.log,
                    transport='ssh',
                    user=self.token['ssh_user'],
                    acc_user=self.token['acc_user'],
                    acc_host=str(self.client_address[0]),
                    timeout=self.token['inactivity_timeout'],
                    auth_type=type(self.authentication_key).__name__,
                    auth_bits=self.authentication_key and self.authentication_key.get_bits(),
                    slot_type=type(self.slot).__name__,
                    slot_info=self.slot.as_auth_info(),
                    command=command,
                    width=session.pty_params and session.pty_params['width'],
                    height=session.pty_params and session.pty_params['height'],
                    forward_agent=session.agent is not None,
                    reverse_forwardings=bool(self.reverse_forwardings),
                    direct_forwardings=bool(self.direct_forwardings),
                )
                server = session.start_ssh(channel=channel)
            finally:
                signal.alarm(0)
        except StartupException as e:
            self.log.exception('shell start failed: %s', e.msg)
            self.ctx.close_session(channel)
            channel.sendall_stderr("Shell start failed: %s" % (e.msg,))
        except Exception as e:
            self.log.exception('shell start failed: %s', e)
            self.ctx.close_session(channel)
            channel.sendall_stderr("Shell start failed: %s" % (e,))
        else:
            self.log.debug('started ssh command %r', command)
            daemonthr(self._run_server, channel, server, session)
            return True

        return False

    def check_channel_shell_request(self, channel):
        """Start an interactive session on *channel*.

        A normal login gets a fresh child container
        ``<slot container>/ps-<pid>-<channel id>`` started with ``session.start_ssh``.
        A ``watch_parent`` login is passed to ``session.watch_ssh`` with the slot
        container name unchanged. A usage report is scheduled on a background thread.
        On success, ``_run_server`` drives the session on a daemon thread.

        :returns: True when the session was started, False otherwise (the failure
            reason is written to the channel).
        """
        self.log.debug("check_channel_shell_request: (channel %r)", channel.get_id())
        if not channel.get_transport().is_authenticated():
            self.log.warning("shell requested without authentication on channel %s", channel.get_id())
            return False

        try:
            # 60s watchdog around startup (the SIGALRM handler is installed elsewhere).
            # Disarmed in `finally`: a failed start used to leave the alarm pending.
            signal.alarm(60)
            try:
                session = self.ctx.make_session(channel)
                session.container_name = self.slot.container

                daemonthr(
                    schedule_report,
                    self.log,
                    transport='ssh',
                    user=self.token['ssh_user'],
                    acc_user=self.token['acc_user'],
                    acc_host=str(self.client_address[0]),
                    timeout=self.token['inactivity_timeout'],
                    auth_type=type(self.authentication_key).__name__,
                    auth_bits=self.authentication_key and self.authentication_key.get_bits(),
                    slot_type=type(self.slot).__name__,
                    slot_info=self.slot.as_auth_info(),
                    watch_parent=self.token.get('watch_parent'),
                    width=session.pty_params and session.pty_params['width'],
                    height=session.pty_params and session.pty_params['height'],
                    forward_agent=session.agent is not None,
                    reverse_forwardings=bool(self.reverse_forwardings),
                    direct_forwardings=bool(self.direct_forwardings),
                )

                if self.token.get('watch_parent'):
                    server = session.watch_ssh(channel=channel)
                else:
                    session.container_name = "{}/ps-{}-{}".format(
                        session.container_name, os.getpid(), channel.get_id())
                    session.extra_env[:0] = self.env.items()
                    session.cwd = self.slot.instance_dir
                    session.tag = '%s@%s' % (self.token['acc_user'], self.client_address[0])
                    server = session.start_ssh(channel=channel)
            finally:
                signal.alarm(0)
        except StartupException as e:
            self.log.exception('shell start failed: %s', e.msg)
            self.ctx.close_session(channel)
            channel.sendall("Shell start failed: %s" % (e.msg,))
        except Exception as e:
            self.log.exception('shell start failed: %s', e)
            self.ctx.close_session(channel)
            channel.sendall("Shell start failed: %s" % (e,))
        else:
            daemonthr(self._run_server, channel, server, session)
            return True

        return False

    def check_channel_subsystem_request(self, channel, name):
        """Refuse every subsystem request; the base handler serves none.

        :param channel: channel on which the subsystem was requested.
        :param name: requested subsystem name (only logged).
        :return: always ``False``.
        """
        chan_id = channel.get_id()
        self.log.debug("check_channel_subsystem_request: (channel %r, %r)", chan_id, name)
        return False

    def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight):
        """Propagate a client terminal resize to the session bound to ``channel``.

        The parameter order and names match paramiko's
        ``ServerInterface.check_channel_window_change_request``, which is
        called as ``(channel, width, height, pixelwidth, pixelheight)``.
        Positional and keyword callers therefore both map the pixel width
        to ``xpixel`` and the pixel height to ``ypixel``.

        :param channel: the channel whose terminal was resized.
        :param width: new terminal width, passed on as ``cols``.
        :param height: new terminal height, passed on as ``rows``.
        :param pixelwidth: new width in pixels, passed on as ``xpixel``.
        :param pixelheight: new height in pixels, passed on as ``ypixel``.
        :return: the result of ``self.ctx.set_winsize``, handed back to paramiko
            as the request outcome.
        """
        self.log.debug("check_channel_window_change_request: (channel %r, %r, %r, %r, %r)",
                       channel.get_id(), width, height, pixelwidth, pixelheight)
        return self.ctx.set_winsize(channel=channel, rows=height, cols=width, xpixel=pixelwidth, ypixel=pixelheight)

    def check_global_request(self, kind, msg):
        """Log and refuse any global request delivered to this hook.

        :param kind: the global request type string.
        :param msg: the raw request message (only logged).
        :return: always ``False``.
        """
        self.log.debug("check_global_request: (%r, %r)", kind, msg)
        accepted = False
        return accepted

    def get_allowed_auths(self, username):
        """Advertise the authentication methods offered to ``username``.

        Only public-key authentication is offered, whatever the user name.

        :return: the comma-separated method list expected by paramiko.
        """
        self.log.debug("get_allowed_auths: (%r)", username)
        methods = "publickey"
        return methods

    @staticmethod
    def split_ssh_user(username):
        """Decode a ``//key:value//key:value`` login name into an auth token.

        Recognised keys (short aliases in parentheses):

        * ``user`` (``u``) -> ``token['acc_user']``
        * ``login_as`` (``l``) -> ``token['ssh_user']``
        * ``timeout`` (``t``) -> ``token['inactivity_timeout']`` (kept as a string)
        * ``slot`` (``s``) -> ``token['lookup_args']['slot_name']``
        * ``configuration_id`` (``C``) -> ``token['lookup_args']['configuration_id']``
        * ``watch_parent`` -> boolean; true for ``true``/``yes``/``1`` in
          lower, Title or UPPER case
        * any other key -> ``token['lookup_args'][key]`` verbatim

        Segments without a ``:`` are ignored. Only the first ``:`` splits a
        segment, so values may themselves contain colons.

        :param username: login name sent by the SSH client.
        :return: a token from ``make_token()`` filled in from ``username``, or
            ``None`` when ``username`` does not start with ``//``.
        """
        token = make_token()
        if not username.startswith('//'):
            return None

        # key alias -> top-level token field
        token_fields = {
            'user': 'acc_user', 'u': 'acc_user',
            'login_as': 'ssh_user', 'l': 'ssh_user',
            'timeout': 'inactivity_timeout', 't': 'inactivity_timeout',
        }
        # key alias -> lookup_args field; unlisted keys are passed through as-is
        lookup_fields = {
            'slot': 'slot_name', 's': 'slot_name',
            'configuration_id': 'configuration_id', 'C': 'configuration_id',
        }
        truthy = ('true', 'True', 'TRUE', 'yes', 'Yes', 'YES', '1')

        for segment in username.split('//'):
            key, sep, value = segment.partition(':')
            if not sep:
                # Empty segments and segments lacking the delimiter.
                continue
            if key in token_fields:
                token[token_fields[key]] = value
            elif key == 'watch_parent':
                token['watch_parent'] = value in truthy
            else:
                token['lookup_args'][lookup_fields.get(key, key)] = value

        return token

    def verify_slot_filter(self, slot, token):
        """Hook for subclasses to veto the slot resolved for ``token``.

        The base implementation imposes no restriction and accepts every slot.

        :param slot: the resolved slot object (unused here).
        :param token: the auth token being checked (unused here).
        :return: always ``True``.
        """
        return True

    def init_transport(self):
        """Create the server-side paramiko transport and load host keys.

        Wraps ``self.request`` in a ``paramiko.Transport`` stored as
        ``self.transport``. It then registers each RSA/DSA/ECDSA host key
        found as ``ssh_host_<type>_key`` in ``self.server.host_keys_dir``.
        Key files that do not exist are skipped.
        """
        import binascii

        # NOTE: this assigns a class attribute, so it affects every
        # paramiko.Transport in the process, not only this connection's.
        paramiko.Transport._CLIENT_ID = "portoshell"
        self.transport = paramiko.Transport(sock=self.request)
        for key_class, key_name in (
            (paramiko.RSAKey, 'rsa'),
            (paramiko.DSSKey, 'dsa'),
            (paramiko.ECDSAKey, 'ecdsa'),
        ):
            path = os.path.join(self.server.host_keys_dir, 'ssh_host_%s_key' % (key_name,))
            if not os.path.exists(path):
                continue
            key = key_class.from_private_key_file(path)
            # hexlify accepts both Python 2 str and Python 3 bytes, unlike the
            # Python-2-only str.encode('hex') codec.
            fingerprint = binascii.hexlify(key.get_fingerprint()).decode('ascii')
            self.log.debug('added host key %s [%s]', key_class.__name__, fingerprint)
            self.transport.add_server_key(key)

    def handle(self):
        """Serve one SSH connection from setup to teardown.

        Initialises per-connection state and starts the paramiko server
        transport with ``self`` as the server interface. It then blocks in
        ``self.idle_loop()`` until that returns. Afterwards it finalizes the
        session context, shuts down every reverse-forwarding server and
        closes the transport.
        """
        # Arm a 60 s SIGALRM for the setup phase; it is cancelled by
        # signal.alarm(0) once start_server() returns.
        # NOTE(review): the SIGALRM handler is installed elsewhere --
        # presumably it aborts a connection stuck in setup; confirm.
        signal.alarm(60)
        self.log = logger.logging.MessageAdapter(
            self.server.logger.getChild('sshconn'),
            fmt='[%(pid)s] %(message)s',
            data={'pid': os.getpid()},
        )

        # Per-connection state read by the server-interface callbacks of this
        # class (e.g. token, reverse_forwardings, direct_forwardings).
        self.auth_error_sent = False
        self.authentication_key = None
        self.authenticated_publickey = None
        self.ctx = Context(self.log)
        # Context warnings go through self.send_debug (defined outside this view).
        self.ctx.send_warning = self.send_debug
        self.ctx.tools_tarball = self.server.tools_tarball
        self.env = {}
        self.token = None
        self.pty_params = None
        self.command = None
        self.shell = False
        self.reverse_forwardings = {}  # (host, port) => TCPServer
        self.direct_forwardings = {}  # chanid => socket

        # NOTE(review): these two calls run outside the try/finally below. If
        # either raises, ctx.finalize() and transport.close() are skipped and
        # the alarm stays armed.
        self.init_transport()
        self.transport.start_server(server=self)
        signal.alarm(0)
        try:
            self.idle_loop()
        finally:
            self.ctx.finalize()
            # NOTE(review): a shutdown() that raises skips transport.close().
            # On Python 3, values() is a live view, so a concurrent pop from
            # reverse_forwardings would break this loop.
            for s in self.reverse_forwardings.values():
                s.shutdown()
            self.transport.close()
            self.log.debug("main thread exiting")

    def idle_loop(self):
        """Block the handler until ``self.transport`` is no longer active.

        The transport is polled every two seconds. Subclasses override this
        to do work in the loop.
        """
        poll_seconds = 2.0
        while True:
            if not self.transport.is_active():
                break
            time.sleep(poll_seconds)


class InContainerSSHHandler(SSHHandler):
    """SSH handler whose sessions are pinned to a container by local address.

    ``interfaces`` maps a local IP address to the container lookup entry for
    that address (see ``select_container``). ``netns_container`` is the
    container handed to ``self.server.create_ns_socket`` for TCP forwarding
    sockets. When it is set, it also enables the ``sftp`` subsystem.
    """

    def __init__(self, *args, **kwargs):
        # Assigned before the base constructor runs. NOTE(review): presumably
        # SSHHandler derives from SocketServer.BaseRequestHandler, whose
        # __init__ calls handle() immediately, so these must exist first.
        self.interfaces = kwargs.pop('interfaces')
        self.netns_container = kwargs.pop('netns_container')
        super(InContainerSSHHandler, self).__init__(*args, **kwargs)

    def split_ssh_user(self, username):
        """Build an auth token for ``username`` pinned to this address's container.

        Two login formats are accepted:

        * ``//key:value//...``: parsed by the base class. The second mapping
          of the ``select_container`` entry is then merged over the
          user-supplied lookup args (interface values win). The inactivity
          timeout defaults to one week.
        * anything else: the whole string is the account user. Lookup args
          are a copy of the first mapping of the ``select_container`` entry.

        :param username: login name sent by the SSH client.
        :return: the token dict, or ``None`` if neither ``acc_user`` nor
            ``ssh_user`` ends up set.
        :raises Exception: if the local address is not mapped to a container.
        """
        one_week = 60 * 60 * 24 * 7

        def fix_token(token):
            # Fill in whichever of acc_user / ssh_user is missing.
            if not token['acc_user'] and not token['ssh_user']:
                return None
            elif token['acc_user'] and token['ssh_user']:
                return token

            if not token['ssh_user']:
                # Log in as the account user only for pod lookups or the
                # root/nobody accounts; otherwise leave ssh_user empty.
                token['ssh_user'] = (
                    ''
                    if 'pod' not in token['lookup_args'] and token['acc_user'] not in ('root', 'nobody')
                    else token['acc_user']
                )
            else:
                token['acc_user'] = token['ssh_user']

            return token

        if username.startswith('//'):
            token = super(InContainerSSHHandler, self).split_ssh_user(username)
            token['lookup_args'].update(self.select_container(self.request.getsockname()[0])[1])
            token['inactivity_timeout'] = token['inactivity_timeout'] or one_week
            return fix_token(token)

        token = make_token()
        # Copy, so later changes to the token never leak into the shared
        # self.interfaces configuration.
        token['lookup_args'] = dict(self.select_container(self.request.getsockname()[0])[0])
        token['acc_user'] = username
        # FIXME dirty hack, use more generalized code
        token['watch_parent'] = False
        token['inactivity_timeout'] = one_week
        return fix_token(token)

    def verify_slot_filter(self, slot, token):
        """Ensure ``slot`` matches every lookup argument pinned in ``token``.

        :param slot: the resolved slot. Each lookup key is read from it as an
            attribute (missing attributes read as ``None``).
        :param token: auth token whose ``lookup_args`` mapping is enforced.
        :return: ``True`` when every lookup argument matches, as in the base
            implementation.
        :raises AuthError: on the first mismatching attribute.
        """
        for k, v in six.iteritems(token['lookup_args']):
            slot_val = getattr(slot, k, None)
            if slot_val != v:
                self.log.error("mtn property %r=%r doesn't match interface slot value %r", k, slot_val, v)
                raise AuthError(
                    "possible breach attempt: mtn property %r=%r doesn't match interface slot value %r" % (
                        k,
                        slot_val,
                        v,
                    )
                )
        return True

    def select_container(self, ip):
        """Return the container lookup entry configured for local address ``ip``.

        ``split_ssh_user`` indexes the entry: ``[0]`` for plain logins and
        ``[1]`` for ``//``-style logins.

        :raises Exception: if ``ip`` is not a key of ``self.interfaces``.
        """
        if ip not in self.interfaces:
            raise Exception("ip is not mapped to any container")

        return self.interfaces[ip]

    def cancel_port_forward_request(self, address, port):
        """Stop the reverse-forwarding server registered for ``(address, port)``.

        This is the method name paramiko's ``ServerInterface`` invokes for a
        ``cancel-tcpip-forward`` global request. Unknown pairs are ignored.
        """
        srv = self.reverse_forwardings.pop((address, port), None)
        if srv is not None:
            srv.shutdown()

    # Historical name, kept for any direct callers; paramiko never invokes it.
    cancel_port_forwarding_request = cancel_port_forward_request

    def check_channel_direct_tcpip_request(self, chanid, origin, destination):
        """Open the outgoing TCP connection for a ``direct-tcpip`` channel.

        ``destination`` (host, port) is resolved, and each candidate address
        is tried with a socket from ``self.server.create_ns_socket`` for
        ``self.netns_container``. The first connected socket is stored in
        ``self.direct_forwardings[chanid]``. ``idle_loop`` pairs it with the
        channel once the channel is accepted.

        :return: ``paramiko.OPEN_SUCCEEDED`` on success, otherwise
            ``paramiko.OPEN_FAILED_CONNECT_FAILED``.
        """
        self.log.debug("check_channel_direct_tcpip_request: (chanid %r, origin %r, destination %r)",
                       chanid, origin, destination)
        host = destination[0]
        port = destination[1]

        try:
            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        except socket.gaierror as e:
            self.log.warning("check_channel_direct_tcpip_request: failed to resolve address: %s", e)
            return paramiko.OPEN_FAILED_CONNECT_FAILED

        if not addrs:
            self.log.warning("check_channel_direct_tcpip_request: no matching addrs found")
            return paramiko.OPEN_FAILED_CONNECT_FAILED

        for addr in addrs:
            family = addr[0]
            h = addr[4][0]
            p = addr[4][1]

            # create_ns_socket yields None, or a sequence whose [1] is the socket.
            sock = self.server.create_ns_socket(self.netns_container, family)
            if sock is None:
                continue
            sock = sock[1]

            try:
                sock.connect((h, p))
            except EnvironmentError as e:
                self.log.warning("check_channel_direct_tcpip_request: connect to (%s, %s) failed: %s", h, p, e)
                sock.close()
                continue

            self.direct_forwardings[chanid] = sock
            return paramiko.OPEN_SUCCEEDED

        return paramiko.OPEN_FAILED_CONNECT_FAILED

    def check_channel_subsystem_request(self, channel, name):
        """Start the subsystem handler registered on the transport for ``name``.

        Unauthenticated channels and names with no registered handler are
        refused. On success, a usage report is scheduled via ``daemonthr``
        and the handler thread is started. Any pending SIGALRM is then
        cancelled.
        """
        self.log.debug("check_channel_subsystem_request: (channel %r, %r)", channel.get_id(), name)

        if not channel.get_transport().is_authenticated():
            self.log.warning("subsystem requested without authentication on channel %s", channel.get_id())
            return False

        # NOTE: _get_subsystem_handler is private paramiko API; it is what
        # set_subsystem_handler (see init_transport) registers into.
        handler_class, larg, kwarg = channel.get_transport()._get_subsystem_handler(name)
        if handler_class is None:
            return False

        daemonthr(
            schedule_report,
            self.log,
            transport='ssh',
            user=self.token['ssh_user'],
            acc_user=self.token['acc_user'],
            acc_host=str(self.client_address[0]),
            timeout=self.token['inactivity_timeout'],
            auth_type=type(self.authentication_key).__name__,
            auth_bits=self.authentication_key and self.authentication_key.get_bits(),
            slot_type=type(self.slot).__name__,
            slot_info=self.slot.as_auth_info(),
            subsystem=name,
        )

        handler = handler_class(channel, name, self, *larg, **kwarg)
        handler.start()

        signal.alarm(0)
        return True

    def check_port_forward_request(self, host, listen_port):
        """Handle a ``tcpip-forward`` request by listening via the netns container.

        ``host`` is normalised before resolution. An empty value is passed to
        getaddrinfo as ``None``, ``'*'``/``'0.0.0.0'`` become ``'::'`` and
        ``'localhost'`` becomes ``'::1'``. Each resolved IPv4/IPv6 address is
        tried in turn. The first that binds gets a ``ReverseTcpForwardServer``,
        served via ``daemonthr`` and registered in
        ``self.reverse_forwardings`` under ``(host, bound_port)``.

        :param host: address requested by the client.
        :param listen_port: port requested by the client. The actually bound
            port is read back with ``getsockname()``.
        :return: the bound port number on success, ``False`` otherwise.
        """
        address = host

        self.log.debug("check_port_forward_request: (address %r, port %r)", address, listen_port)

        if len(address) > socket.NI_MAXHOST:
            self.log.debug("check_port_forward_request: forward host name too long")
            return False

        if not address:
            address = None
        elif address in ('*', '0.0.0.0'):
            address = '::'
        elif address == 'localhost':
            address = '::1'

        try:
            addrs = socket.getaddrinfo(address, listen_port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        except socket.gaierror as e:
            self.log.warning("check_port_forward_request: failed to resolve address: %s", e)
            return False

        # A list rather than filter(): on Python 3 filter() returns a lazy
        # iterator that is always truthy, which would defeat the check below.
        addrs = [a for a in addrs if a[0] in (socket.AF_INET, socket.AF_INET6)]
        if not addrs:
            self.log.warning("check_port_forward_request: no matching addrs found")
            return False

        for addr in addrs:
            family = addr[0]
            address = addr[4][0]
            port = addr[4][1]

            try:
                socket.getnameinfo((address, port), socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
            except socket.gaierror as e:
                self.log.warning("check_port_forward_request: getnameinfo() failed: %s", e)
                continue

            # create_ns_socket yields None, or a sequence whose [1] is the socket.
            sock = self.server.create_ns_socket(self.netns_container, family)
            if sock is None:
                continue

            sock = sock[1]

            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind((address, port))
                sock.listen(1)
            except EnvironmentError as e:
                self.log.warning("check_port_forward_request: bind((%r, %r)) failed: %s", address, port, e)
                sock.close()
                continue

            # Report the real port (relevant when the client asked for port 0).
            port = sock.getsockname()[1]

            self.reverse_forwardings[(host, port)] = srv = ReverseTcpForwardServer(
                self.log.getChild('rforward'),
                sock,
                self.transport,
                sock.getsockname(),
                ReverseTcpForwardServer.ReverseTcpForwardHandler,
            )
            daemonthr(srv.serve_forever)
            return port

        return False

    def process_forward_channel(self, channel, sock):
        """Relay data between ``channel`` and ``sock``; always close ``sock`` afterwards."""
        try:
            forward(channel, sock)
        finally:
            sock.close()

    def idle_loop(self):
        """Accept channels until the transport dies, wiring up direct-tcpip ones.

        Each accepted channel that has a pending socket in
        ``self.direct_forwardings`` gets a relay started via ``daemonthr``.
        Other accepted channels are not touched here.
        """
        while self.transport.is_active():
            chan = self.transport.accept(2.0)
            if chan is None:
                continue

            sock = self.direct_forwardings.pop(chan.get_id(), None)
            if sock is not None:
                daemonthr(self.process_forward_channel, chan, sock)

    def init_transport(self):
        """Set up the base transport, plus SFTP when ``netns_container`` is set."""
        super(InContainerSSHHandler, self).init_transport()
        if self.netns_container:
            self.transport.set_subsystem_handler(
                'sftp',
                paramiko.sftp_server.SFTPServer,
                SFTPServerInterface,
                self.netns_container)


def _unpack_attr(attr):
    if attr is None:
        return attr
    ret = paramiko.SFTPAttributes()
    ret.__dict__.update(attr)
    return ret


def _pack_attr(attr):
    return attr.__dict__ if attr is not None else None
