# cython: language_level=2

import os
import sys
import array
import errno
import signal
import socket
import logging
import traceback
import subprocess
from itertools import chain

from library.python.capabilities import capabilities
from library.python.nstools import nstools
from library.python.sendmsg import sendmsg, recvmsg, ScmRights
from kernel.util.sys.user import UserPrivileges


cdef bytes b(x):
    """Return ``str(x)`` as a byte string on both Python 2 and Python 3.

    On Python 2 ``str`` already is the byte type; on Python 3 the text is
    encoded with the default (UTF-8) codec.
    """
    text = str(x)
    if sys.version_info.major >= 3:
        return text.encode()
    return text


cdef void join_cgroups(int root_pid) except *:
    """
    Move the current process into the cgroups of process ``root_pid``.

    Reads ``/proc/<root_pid>/cgroup`` (``hierarchy-id:controllers:path`` lines)
    and joins:

    * for cgroup v1, only hierarchies containing the ``freezer`` or
      ``devices`` controller (by writing our PID to their ``tasks`` file);
    * for cgroup v2 (empty controller list), the unified hierarchy, assumed
      to be mounted at ``/sys/fs/cgroup/unified``.

    Declared ``except *`` so errors reach the caller: a plain ``cdef void``
    function in Cython 0.x prints and swallows exceptions, which would
    silently defeat the breach check below.

    :raises RuntimeError: if a hierarchy id or controller name is listed twice.
    :raises IOError/OSError: if a cgroup file cannot be read or written.
    """
    with open('/proc/%s/cgroup' % root_pid) as cgroup_file:
        cgroups = [line.strip().split(':', 2) for line in cgroup_file]

    seen_cgroup_ids = set()
    seen_cgroup_names = set()
    for cgroup in cgroups:
        if len(cgroup) < 3:
            continue
        cgroup_names = cgroup[1].split(',')
        # Every hierarchy id and controller must appear at most once; a
        # duplicate means the file does not look like the kernel's own output.
        if cgroup[0] in seen_cgroup_ids or any(cgroup_name in seen_cgroup_names for cgroup_name in cgroup_names):
            raise RuntimeError("Possible breach attempt!")
        seen_cgroup_ids.add(cgroup[0])
        seen_cgroup_names.update(cgroup_names)
        if any(cgroup_name in ('freezer', 'devices') for cgroup_name in cgroup_names):
            with open('/sys/fs/cgroup/%s%s/tasks' % (cgroup[1], cgroup[2]), 'w') as f:
                f.write(str(os.getpid()))
        elif cgroup[1] == '':  # cgroup v2
            # here we are bound to the common /sys/fs/cgroup/unified/ path of cgroup2 mountpoint in yandex
            with open('/sys/fs/cgroup/unified/%s/cgroup.procs' % (cgroup[2],), 'w') as f:
                f.write(str(os.getpid()))


cdef list collect_nses(int root_pid):
    """
    Open the namespaces of process ``root_pid`` that differ from our own.

    Namespaces are compared through the targets of the ``/proc/<pid>/ns/<name>``
    symlinks. Returns ``(ns_type, file)`` pairs in the fixed order mnt, ipc,
    cgroup, user, uts, net; the caller is responsible for closing the files.
    """
    cdef list opened = []
    wanted = (
        ('mnt', nstools.Mount),
        ('ipc', nstools.Ipc),
        # CLONE_NEWCGROUP is not exported as it's not supported in ancient
        # linuxes, hence the literal 0.
        ('cgroup', 0),
        ('user', nstools.User),
        ('uts', nstools.Uts),
        ('net', nstools.Network),
    )
    for name, flag in wanted:
        target = '/proc/%s/ns/%s' % (root_pid, name)
        ours = os.readlink('/proc/self/ns/%s' % (name,))
        theirs = os.readlink(target)
        if ours == theirs:
            # Already in this namespace; nothing to join.
            continue
        opened.append((flag, open(target)))
    return opened


cdef void drop_bounds() except *:
    """
    Remove every capability except CAP_DAC_OVERRIDE and CAP_SETGID from the
    bounding set of the current process.

    Capability numbers 0..63 are tried one by one: there is no Linux interface
    to keep only a chosen subset of the bounding set, and the running kernel
    may know capabilities that the installed libcap does not, which must be
    dropped as well.

    EINVAL is ignored, presumably meaning "not a capability on this kernel".
    Any other error is re-raised, because silently keeping a capability in the
    bounding set would weaken the sandbox. The function is declared
    ``except *`` so that such errors actually reach the caller.
    """
    cdef int value
    keep = (capabilities.cap_dac_override, capabilities.cap_setgid)
    for value in range(64):
        if value in keep:
            continue
        try:
            capabilities.Capabilities.drop_bound(value)
        except OSError as e:
            if e.errno != errno.EINVAL:
                raise


cdef void drop_logging() except *:
    """
    Close and detach every logging handler whose class is named
    ``SkynetLoggingHandler``.

    Closing is best effort: errors raised by a handler's ``close()`` are
    ignored. The logging module lock is held while the handler and logger
    registries are inspected, and it is always released afterwards. This keeps
    code that runs later in this process (possibly in other threads) able to
    use logging.
    """
    # logging._lock is the module-level lock behind logging._acquireLock();
    # using it directly also works on Python versions without _acquireLock.
    lock = getattr(logging, '_lock', None)
    if lock is not None:
        lock.acquire()
    try:
        # _handlerList holds weak references to every created handler.
        handler_refs = frozenset(logging._handlerList)
        loggers_with_handlers = [
            log
            for log in chain(list(logging.Logger.manager.loggerDict.values()), [logging.getLogger()])
            if isinstance(log, logging.Logger) and log.handlers
        ]

        for handler_ref in handler_refs:
            handler = handler_ref()
            if handler is not None and type(handler).__name__ == 'SkynetLoggingHandler':
                try:
                    handler.close()
                except Exception:
                    pass

        for log in loggers_with_handlers:
            log.handlers = [handler for handler in log.handlers if type(handler).__name__ != 'SkynetLoggingHandler']
    finally:
        if lock is not None:
            lock.release()


def join_container_namespace(int communication_fd, target_user, int root_pid):
    """
    Turn the current process into a member of the container of ``root_pid``.

    Meant to run in a forked child (``run_process`` and ``make_fds`` call it
    in a freshly forked process); nothing done here is undone afterwards.
    In order, it:

    * joins the cgroups of ``root_pid`` (``join_cgroups``);
    * enters every namespace of ``root_pid`` that differs from ours;
    * closes every fd except 0-2 and ``communication_fd``;
    * reduces capabilities to cap_setgid, cap_setuid and cap_dac_override
      (effective + permitted) and trims the bounding set (``drop_bounds``);
    * detaches SkynetLoggingHandler handlers (``drop_logging``);
    * if ``target_user`` is truthy, switches to that user via
      ``UserPrivileges``.

    :param int communication_fd: descriptor to keep open, e.g. the socket
        used to talk to the parent.
    :param Optional[str] target_user: in-container user to drop privileges
        to; a false value keeps the current user.
    :param int root_pid: PID of the container's root process.
    """
    # Join cgroups first, presumably because join_cgroups writes through the
    # current mount namespace's /sys/fs/cgroup, whose view changes once the
    # container's mount namespace is entered below.
    join_cgroups(root_pid)
    # Open all namespace files before entering any of them (presumably so that
    # /proc/<root_pid> is still resolved through our own mount namespace).
    nses = collect_nses(root_pid)

    for ns_type, ns in nses:
        # NOTE(review): unshare before setns, presumably required by nstools
        # for some namespace types; confirm against nstools.
        nstools.unshare_ns(ns_type)
        nstools.move_to_ns(ns, ns_type)
        ns.close()

    # Close everything except fds 0-2 and communication_fd.
    os.closerange(3, communication_fd)
    os.closerange(communication_fd + 1, subprocess.MAXFD)

    # Build the reduced capability set first, trim the bounding set, then apply.
    caps = capabilities.Capabilities.from_text("cap_setgid,cap_setuid,cap_dac_override+ep")
    drop_bounds()
    caps.set_current_proc()

    drop_logging()

    if target_user:
        # __enter__ without a matching __exit__: presumably intentional, so the
        # privilege drop lasts for the rest of this process's life.
        UserPrivileges(target_user, limit=False, store=False, modifyGreenlet=False).__enter__()


cdef bytes _recv_exactly(sock, int size):
    """Receive exactly ``size`` bytes from ``sock``.

    ``recv`` on a stream socket may return fewer bytes than requested, so keep
    reading until ``size`` bytes have arrived. If the peer closes the
    connection first, the shorter data received so far is returned; callers
    must validate the result.
    """
    cdef int remaining = size
    chunks = []
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def run_process(callback, target_user, int root_pid):
    """
    Run job in container namespaces.

    Forks twice. The intermediate child unshares the PID namespace, forks the
    job process, reports the job's PID to the parent and exits. The job
    process joins the container (``join_container_namespace``), waits for the
    parent's ``ready`` confirmation and then runs ``callback``. Neither child
    ever returns into the caller's code: every path in them ends in
    ``os._exit``, and their errors are printed to stderr.

    :type callback: Callable[[socket.socket], None]
    :param callback: function to exec in the container. This function should
                     accept argument with socket that is used for communication
                     with parent.
    :param Optional[str] target_user: in-container user to drop privileges to
    :param int root_pid: PID of the root_process of container to join
    :rtype: Tuple[socket.socket, int]
    :return: socket for communication with the process spawned, and its PID
    :raises: whatever the handshake raises (e.g. AssertionError on a bad
             status byte, socket.timeout after 60 seconds); the socket is
             closed and the intermediate child reaped before re-raising.
    """
    cdef sock_parent, sock_child
    sock_parent, sock_child = socket.socketpair()

    sock_parent.settimeout(60)
    sock_child.settimeout(60)

    cdef pid1 = os.fork()
    cdef pid2
    if not pid1:
        # Intermediate child.
        try:
            signal.alarm(60)
            nstools.unshare_ns(nstools.Pid)

            pid2 = os.fork()
            if not pid2:
                # Job process.
                try:
                    signal.alarm(60)
                    join_container_namespace(sock_child.fileno(), target_user, root_pid)
                    # Do not start the job until the parent has received our PID.
                    if _recv_exactly(sock_child, 5) != b'ready':
                        raise RuntimeError('parent did not confirm the handshake')
                    signal.alarm(0)
                    callback(sock_child)
                except Exception:
                    traceback.print_exc(file=sys.stderr)
                finally:
                    os._exit(0)

            # Status byte followed by the job PID as 8 decimal digits, sent in
            # one go (``send`` may transmit only part of its buffer).
            sock_child.sendall(b'\x00' + b'%08d' % pid2)
        except Exception:
            traceback.print_exc(file=sys.stderr)
        finally:
            os._exit(0)

    sock_child.close()
    signal.alarm(60)
    try:
        try:
            status = _recv_exactly(sock_parent, 1)
            assert status == b'\x00'
            pid2 = int(_recv_exactly(sock_parent, 8), 10)
            sock_parent.sendall(b'ready')
        except BaseException:
            # The caller only receives the socket on success.
            sock_parent.close()
            raise
    finally:
        # Reap the intermediate child and cancel our alarm on every path.
        os.waitpid(pid1, 0)
        signal.alarm(0)
    return sock_parent, pid2


# Width of the zero-padded ASCII decimal length header that prefixes each
# reply body sent from the fd-creating child to the parent in ``make_fds``.
_FRAME_HEADER_LEN = 16


cdef bytes _frame(data):
    """Prefix ``data`` with its length as ``_FRAME_HEADER_LEN`` decimal digits."""
    return b'%0*d' % (_FRAME_HEADER_LEN, len(data)) + data


cdef bytes _recv_at_least(sock, data, Py_ssize_t size):
    """Return ``data`` extended with reads from ``sock`` until it holds at least ``size`` bytes.

    :raises Exception: if the peer closes the connection first.
    """
    cdef Py_ssize_t have = len(data)
    chunks = [data]
    while have < size:
        chunk = sock.recv(size - have)
        if not chunk:
            raise Exception('Fd creation process closed the connection before sending a complete reply')
        chunks.append(chunk)
        have += len(chunk)
    return b''.join(chunks)


cdef list _fds_from_ancdata(ancdata):
    """Extract the file descriptors carried by SCM_RIGHTS entries of ``ancdata``."""
    fds = array.array('i')
    # array.fromstring was removed in Python 3.9; frombytes does not exist on Python 2.
    load = getattr(fds, 'frombytes', None) or fds.fromstring
    for cmsg_level, cmsg_type, cmsg_data in ancdata:
        if cmsg_level == socket.SOL_SOCKET and cmsg_type == ScmRights:
            # Ignore a trailing partial int, e.g. from a truncated control message.
            load(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
    return list(fds)


def make_fds(callback, target_user, int root_pid):
    """
    Create fds in container namespaces.

    Forks an intermediate child that unshares the PID namespace and forks a
    worker. The worker joins the container (``join_container_namespace``),
    runs ``callback`` and replies with one status byte (``\\x00`` success,
    ``\\x01`` failure). On success the fds travel as SCM_RIGHTS ancillary data
    with the status byte. The status byte is followed by a length-prefixed
    body: ``aux_data`` on success, the error text on failure. No child ever
    returns into the caller's code.

    :type callback: Callable[[], Tuple[Iterable[int], bytes]]
    :param callback: function to exec in the container.
                     It should return all fds created and any
                     auxiliary data as bytestring.
    :param Optional[str] target_user: in-container user to drop privileges to
    :param int root_pid: PID of the root_process of container to join
    :return: Tuple[List[int], bytes]
    :raises Exception: with the child's error text (bytes) if joining the
        container or ``callback`` failed, or with a description if the child
        exited without sending a complete reply. Any fds already received are
        closed before raising.
    """
    cdef sock_parent, sock_child
    sock_parent, sock_child = socket.socketpair()
    cdef pid1 = os.fork()
    cdef pid2
    if not pid1:
        # Intermediate child.
        try:
            signal.alarm(60)
            nstools.unshare_ns(nstools.Pid)

            pid2 = os.fork()
            if not pid2:
                # Worker process.
                try:
                    signal.alarm(60)
                    join_container_namespace(sock_child.fileno(), target_user, root_pid)
                    fds, aux_data = callback()
                    # Materialize first: a one-shot iterator would otherwise be
                    # exhausted by the validation below and no fds would be sent.
                    fds = list(fds)
                    assert all(isinstance(fd, int) for fd in fds)
                    assert isinstance(aux_data, bytes)
                except Exception as e:
                    sock_child.sendall(b'\x01' + _frame(b('Fd creation failed with: {}'.format(e))))
                else:
                    fd_array = array.array('i', fds)
                    # array.tostring was removed in Python 3.9; tobytes does not exist on Python 2.
                    dump = getattr(fd_array, 'tobytes', None) or fd_array.tostring
                    sendmsg(sock_child, b'\x00', [(socket.SOL_SOCKET, ScmRights, dump())])
                    sock_child.sendall(_frame(aux_data))
                finally:
                    os._exit(0)
        except Exception:
            traceback.print_exc(file=sys.stderr)
        finally:
            os._exit(0)

    sock_child.close()
    received_fds = []
    try:
        msg, ancdata, _ = recvmsg(sock_parent)
        received_fds = _fds_from_ancdata(ancdata)
        try:
            if not msg:
                raise Exception('Fd creation process exited without sending a reply')
            # The status byte may arrive together with the start of the body.
            data = _recv_at_least(sock_parent, msg[1:], _FRAME_HEADER_LEN)
            body_len = int(data[:_FRAME_HEADER_LEN], 10)
            data = _recv_at_least(sock_parent, data, _FRAME_HEADER_LEN + body_len)
            body = data[_FRAME_HEADER_LEN:_FRAME_HEADER_LEN + body_len]
            if msg[:1] != b'\x00':
                raise Exception(body)
        except BaseException:
            # Do not leak descriptors that were already passed to us.
            for fd in received_fds:
                try:
                    os.close(fd)
                except OSError:
                    pass
            raise
    finally:
        sock_parent.close()
        signal.alarm(60)
        os.waitpid(pid1, 0)
        signal.alarm(0)
    return received_fds, body
