import os
import time
import errno
import select
import socket
import logging
import threading
from itertools import chain
from collections import deque
from SocketServer import ThreadingTCPServer, BaseRequestHandler, BaseServer

from ya.skynet.util.misc import daemonthr

from .portotools import get_portoconn
from . import logger

from library.python.nstools.nstools import move_to_ns, Network


def _eintr_retry(func, *args):
    """restart a system call interrupted by EINTR"""
    while True:
        try:
            return func(*args)
        except (OSError, select.error) as e:
            if e.args[0] != errno.EINTR:
                raise


class Server(object):
    """Pre-forking connection dispatcher for raw (telnet port) and SSH sessions.

    Listens on up to two host sockets (``host_telnet_socket`` served by
    ``RawHandler``, ``host_ssh_socket`` served by ``SSHHandler``) plus every
    socket registered in ``mtn_sockets``. ``mtn_sockets`` maps a listening
    socket to a dict with at least ``netns_container`` and ``interface_map``
    keys (served by ``InContainerSSHHandler``). It starts empty and is filled
    in by code outside this class. Each accepted connection is handled in a
    forked child process. ``collect_children`` has to run in some thread to
    reap those children.
    """

    # Backlog passed to listen().
    request_queue_size = 5
    # PIDs of forked workers that have not been reaped yet. This is a
    # class-level set, so it is shared by all instances and by the
    # ``collect_children`` classmethod.
    active_children = set()
    # NOTE(review): not referenced inside this class; presumably read by the
    # handlers through the server instance they receive. Confirm.
    timeout = 60

    @classmethod
    def collect_children(cls, log):
        """Reap exited child processes forever. Never returns.

        Call it once, in a dedicated thread. ``os.waitpid(-1, 0)`` reaps *any*
        child of this process, so it conflicts with ``subprocess`` and with
        anything else that waits for its own children. When there are no
        children at all it sleeps one second between attempts instead of
        spinning.

        :param log: logger; a ``'taras-bulba'`` child logger is used.
        """
        log = log.getChild('taras-bulba')
        while True:
            try:
                pid, status = os.waitpid(-1, 0)
            except EnvironmentError as e:
                if e.errno == errno.ECHILD:
                    # We have no children, so don't waste CPU.
                    time.sleep(1.0)
                continue

            if pid in cls.active_children:
                # ``status`` is the raw wait status, not the exit code. Decode it.
                if os.WIFSIGNALED(status):
                    log.debug('pid %d collected (killed by signal %d)', pid, os.WTERMSIG(status))
                else:
                    log.debug('pid %d collected (exit-code %d)', pid, os.WEXITSTATUS(status))
                cls.active_children.discard(pid)

    def __init__(self,
                 log,
                 telnet_port,
                 ssh_port,
                 check_auth=True,
                 tools_tarball=None,
                 iss=True,
                 host_keys_dir='/etc/ssh',
                 keys_storage=None,
                 ca_storage=None,
                 hostname=None,
                 ):
        """Create the wake-up pipe, the logger and the host listening sockets.

        :param log: parent logger (must support ``getChild``).
        :param telnet_port: port for the raw handler. A falsy value means no
            socket, ``-1`` (used by tests) binds an ephemeral port, and any
            other value binds ``[::]:telnet_port``.
        :param ssh_port: same convention, for the SSH handler.

        The remaining arguments are only stored as attributes. They are not
        used inside this class; presumably the handlers consume them through
        the server instance they receive.
        """
        self.check_auth = check_auth
        self.tools_tarball = tools_tarball
        self.iss = iss
        self.inode = None
        self.host_keys_dir = host_keys_dir
        self.keys_storage = keys_storage
        self.ca_storage = ca_storage
        self.hostname = hostname

        self.__is_shut_down = threading.Event()
        # Writing to this pipe wakes up select() in _poll_once (see shutdown).
        self.__pipe = os.pipe()

        self.logger = log
        self._set_log()

        self.host_telnet_socket = None
        self.host_ssh_socket = None
        self.mtn_sockets = {}

        # -1 is truthy, so it also lands here; it maps to port 0 (ephemeral).
        if telnet_port:
            port = 0 if telnet_port == -1 else telnet_port
            _, self.host_telnet_socket = self.bind_socket(('::', port))

        if ssh_port:
            port = 0 if ssh_port == -1 else ssh_port
            _, self.host_ssh_socket = self.bind_socket(('::', port))

    def bind_socket(self, address, netns_container=None):
        """Create a TCP socket bound to ``address`` and put it into listening state.

        Without ``netns_container``, an AF_INET6 socket is created in the
        current network namespace and the returned inode is ``None``.
        Otherwise the socket is created by ``create_ns_socket`` inside that
        porto container's network namespace.

        :returns: ``(netns_inode, listening_socket)``, or ``None`` when the
            namespaced socket could not be created.
        :raises socket.error: when setting options, binding or listening
            fails. The socket is closed before the error propagates.
        """
        if netns_container is None:
            inode = None
            sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        else:
            result = self.create_ns_socket(netns_container)
            if not result:
                return None
            inode, sock = result

        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(address)
            sock.listen(self.request_queue_size)
        except Exception:
            # Don't leak the descriptor when the address is busy or invalid.
            sock.close()
            raise
        return inode, sock

    def _set_log(self):
        """(Re)build ``self.log`` as the ``'server'`` child logger with a ``[pid]`` prefix.

        ``process_request`` calls this again in the forked child so that
        messages carry the child's pid.
        """
        self.log = logger.logging.MessageAdapter(
            self.logger.getChild('server'),
            fmt='[%(pid)s] %(message)s',
            data={'pid': os.getpid()},
        )

    def create_ns_socket(self, netns_container=None, family=None):
        """Create an unbound stream socket inside a porto container's network namespace.

        The namespace switch runs in a helper thread started via ``daemonthr``
        and joined right away. Presumably this is so that only that thread
        changes namespace and the caller is unaffected. Confirm against
        ``move_to_ns``.

        :param netns_container: porto container name. Its ``root_pid`` is used
            to open ``/proc/<root_pid>/ns/net``.
        :param family: socket address family, ``AF_INET6`` by default.
        :returns: ``(netns_inode, socket)``, or ``None`` in two cases: the
            namespace switch failed (logged), or the container's network
            namespace is the same as this process's (logged as a warning).
        """
        sock_storage = []
        family = family or socket.AF_INET6

        def ns_socket_thread():
            with open('/proc/self/ns/net') as f:
                own_inode = os.fstat(f.fileno()).st_ino

            ns = None
            try:
                root_pid = get_portoconn(True).Find(netns_container).GetData('root_pid')
                ns = open('/proc/%s/ns/net' % (root_pid,))
                move_to_ns(ns, Network)
            except Exception:
                self.log.exception("failed to set network namespace of %s, skipping socket creation",
                                   netns_container)
            else:
                inode = os.fstat(ns.fileno()).st_ino
                if inode == own_inode:
                    self.log.warning("failed to create socket in %s, network namespace matches own",
                                     netns_container)
                    return

                sock = socket.socket(family, socket.SOCK_STREAM)
                sock_storage.append((inode, sock))
            finally:
                # Always release the namespace file descriptor.
                if ns is not None:
                    ns.close()

        daemonthr(ns_socket_thread).join()
        if sock_storage:
            return sock_storage[0]
        return None

    def shutdown_request(self, request):
        """Half-close (SHUT_WR) and then close ``request``.

        A ``socket.error`` from the shutdown is ignored, for example when the
        socket is already closed or the peer is gone.
        """
        try:
            request.shutdown(socket.SHUT_WR)
        except socket.error:
            pass
        request.close()

    def _handle_request(self, sock):
        """Accept one connection on ``sock`` and dispatch it to the matching handler.

        An accept failure is ignored. Connections on unknown sockets are
        logged and dropped. Any ``Exception`` raised while dispatching is
        logged.
        """
        try:
            request, client_address = sock.accept()
        except socket.error:
            return

        try:
            from .ssh import SSHHandler, InContainerSSHHandler
            from .raw import RawHandler

            kwargs = {}
            info = self.mtn_sockets.get(sock)
            if info:
                handler = InContainerSSHHandler
                kwargs['netns_container'] = info['netns_container']
                kwargs['interfaces'] = info['interface_map']
            elif sock is self.host_telnet_socket:
                handler = RawHandler
            elif sock is self.host_ssh_socket:
                handler = SSHHandler
            else:
                self.log.warning("peer to unknown socket from %s, closing connection", client_address)
                return

            self.process_request(request, client_address, handler, kwargs)
        except Exception:
            self.log.exception("Exception happened during processing of request from %s", client_address)
        finally:
            # After a successful fork the parent has already closed
            # ``request``. The SHUT_WR then fails (and is ignored) on the
            # closed object, so it does not cut off the child's copy of the
            # connection.
            self.shutdown_request(request)

    def _poll_once(self, poll_interval):
        """Wait up to ``poll_interval`` seconds and serve every ready listener.

        :returns: ``False`` if a shutdown was requested through the pipe,
            otherwise ``True``.
        """
        watched = [self.__pipe[0]]
        watched.extend(s for s in (self.host_telnet_socket, self.host_ssh_socket) if s)
        watched.extend(self.mtn_sockets)

        result = _eintr_retry(select.select, watched, [], [], poll_interval)[0]
        if self.__pipe[0] in result:
            # Shutdown requested: drain the wake-up byte(s).
            os.read(self.__pipe[0], 1024)
            return False

        for sock in result:
            self._handle_request(sock)

        return True

    def serve_forever(self, poll_interval=5.):
        """Serve connections until ``shutdown`` is called. Blocks the caller."""
        self.__is_shut_down.clear()
        try:
            while self._poll_once(poll_interval):
                pass
        finally:
            self.__is_shut_down.set()

    def shutdown(self):
        """Ask ``serve_forever`` to stop and wait until it has returned.

        This blocks forever if ``serve_forever`` has never been started.
        """
        os.write(self.__pipe[1], b'1')
        self.__is_shut_down.wait()

    def server_close(self):
        """Close every listening socket and the wake-up pipe.

        Safe to call more than once. Errors from closing sockets are ignored,
        and the pipe is closed only while it is still open. Each forked
        worker calls this so that it does not keep the listeners open.
        """
        for sock in chain(
            [self.host_telnet_socket, self.host_ssh_socket],
            list(self.mtn_sockets.keys())
        ):
            if sock is not None:
                try:
                    sock.close()
                except socket.error:
                    pass
        if self.__pipe is not None:
            os.close(self.__pipe[0])
            os.close(self.__pipe[1])
            self.__pipe = None

    def process_request(self, request, client_address, handler, kwargs):
        """Fork a worker process that runs ``handler`` for one connection.

        Parent: records the child pid in ``active_children``, closes its own
        copy of ``request`` and returns.

        Child: rebuilds its logger, closes the listening sockets and the pipe,
        runs ``handler(request, client_address, self, **kwargs)``, shuts the
        connection down, and always ends with ``os._exit(0)``. It never
        returns into the server loop.
        """
        self.log.info("incoming connection from %s", client_address)
        # Hold the logging module lock across fork() so the child cannot
        # inherit it in a locked state.
        logging._acquireLock()
        try:
            pid = os.fork()

            if not pid:
                # Almost every logging.Handler has a lock that has to be
                # recreated after fork. Python 3 deals with this, Python 2
                # doesn't.
                for log_handler_ref in logging._handlerList:
                    log_handler = log_handler_ref()
                    if log_handler is not None:
                        log_handler.createLock()

        finally:
            logging._releaseLock()

        if pid:
            self.log.debug("incoming connection %s -> %s, created pid %s", client_address, request.getsockname(), pid)
            self.active_children.add(pid)
            request.close()
            return

        try:
            self._set_log()
            self.log.debug("worker forked")
            self.server_close()
            handler(request, client_address, self, **kwargs)
            self.log.debug("request finished")
            self.shutdown_request(request)
            self.log.debug("request shut down")
        except EOFError as e:
            self.log.error("Client disconnected with %s: %s", type(e).__name__, client_address)
            self.shutdown_request(request)
        except BaseException as e:
            self.log.error("Exception happened during processing of request from %s", client_address)
            self.log.exception("handling error %r %s", e, e)
            self.shutdown_request(request)
        finally:
            # Never fall back into the parent's server loop.
            os._exit(0)


class ReverseTcpForwardServer(ThreadingTCPServer):
    """Threaded TCP server that accepts on a caller-supplied socket and tunnels
    each accepted connection back through an SSH transport.

    Only ``BaseServer.__init__`` is invoked, so the ``TCPServer`` constructor
    never creates, binds or activates a socket of its own. The given ``sock``
    becomes ``self.socket``.
    """

    # Per-connection handler threads are daemon threads.
    daemon_threads = True
    allow_reuse_address = True

    def __init__(self, log, sock, transport, *args, **kwargs):
        # The handler uses this transport to open forwarded-tcpip channels.
        self.ssh_transport = transport
        BaseServer.__init__(self, *args, **kwargs)
        self.logger = log
        self._set_log()
        self.socket = sock

    def _set_log(self):
        """(Re)build ``self.log``: the ``'server'`` child logger with a ``[pid]`` prefix."""
        child_logger = self.logger.getChild('server')
        prefix_data = dict(pid=os.getpid())
        self.log = logger.logging.MessageAdapter(
            child_logger, fmt='[%(pid)s] %(message)s', data=prefix_data,
        )

    class ReverseTcpForwardHandler(BaseRequestHandler):
        """Relay one accepted connection over a newly opened forwarded-tcpip channel."""

        def handle(self):
            srv = self.server
            origin = self.request.getpeername()
            listen_addr = srv.socket.getsockname()
            channel = srv.ssh_transport.open_forwarded_tcpip_channel(origin, listen_addr)
            forward(channel, self.request)


def forward(channel, sock, bufsize=16384, timeout=120.):
    """Relay bytes in both directions between an SSH channel and a socket.

    The relay stops when any of these happens:

    - ``channel.eof_received`` becomes true;
    - ``channel.recv`` returns an empty string (the channel stream is closed);
    - ``sock.recv`` returns an empty string (the peer closed its write side).

    ``channel`` is closed on return, including when an exception escapes.
    ``sock`` is left for the caller to close.

    :param channel: channel-like object with ``fileno`` (so ``select`` can
        watch it), ``recv``, ``sendall``, ``close`` and an ``eof_received``
        attribute. Presumably a paramiko ``Channel``.
    :param sock: connected stream socket.
    :param bufsize: maximum number of bytes read per ``recv`` call.
    :param timeout: ``select`` timeout in seconds. When it expires, the loop
        just re-checks ``channel.eof_received``.
    """
    try:
        while not channel.eof_received:
            # Watch for readability only. A connected socket is writable
            # almost all the time, so putting it in the write set would make
            # select() return immediately and the loop would busy-spin.
            readable = select.select([channel, sock], [], [], timeout)[0]

            if channel in readable:
                data = channel.recv(bufsize)
                if not data:
                    break
                # Write through immediately. The blocking sendall gives
                # back-pressure instead of an unbounded in-memory queue.
                sock.sendall(data)

            if sock in readable:
                try:
                    data = sock.recv(bufsize, socket.MSG_DONTWAIT)
                except socket.error as e:
                    if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                        # Spurious readiness: nothing to read after all.
                        continue
                    raise
                if not data:
                    break
                channel.sendall(data)
    finally:
        channel.close()
