""" OS-specific wrappers and helpers. """

from __future__ import absolute_import

import os
import re
import sys
import stat
import time
import ctypes
import platform
import collections

import functools as ft

import six

from .. import enum
from .. import config
from .. import patterns

from .user import *  # noqa
from .subprocess import *  # noqa


class CGroup(object):
    """
    Representation of linux control groups (cgroups)

    Usage examples:

    .. code-block:: python

        ## cgroups of the current process
        cgroup = CGroup()

        ## cgroups of the specific process
        cgroup = CGroup(pid=<process id>)

        ## cgroups with path relative to cgroups of the current process
        cgroup = CGroup("custom/path")

        ## cgroups with path relative to cgroups of the specific process
        cgroup = CGroup("custom/path", pid=<process id>)

        ## cgroups with specific absolute path
        cgroup = CGroup("/custom/path")

        ## OS does not support cgroups
        assert CGroup() is None

        ## selecting subsystems of the cgroup
        cgroup.cpu
        cgroup.cpuacct
        cgroup.devices
        cgroup.freezer
        assert cgroup.freezer.name == "freezer"
        cgroup.memory
        cgroup["memory"]
        ...
        for subsystem in cgroup:  # iterate over all subsystems of the cgroup
            ...

        ## changing path of the cgroup
        cgroup = CGroup("/custom/path")
        assert cgroup.name == "/custom/path"
        cgroup2 = cgroup >> "subpath"
        assert cgroup2.name == "/custom/path/subpath"
        assert (cgroup << 1).name == "/custom"
        assert (cgroup2 << 2).name == "/custom"
        assert (cgroup.memory >> "subpath").cgroup.name == "/custom/path/subpath"
        assert (cgroup2.cpu << 2).cgroup.name == "/custom"

        ## create cgroup for all subsystems, does nothing if already exists
        cgroup = CGroup().create()
        assert cgroup.exists

        ## create cgroup for specific subsystem, does nothing if already exists
        freezer = CGroup().freezer.create()
        assert freezer.exists

        ## delete cgroup for all subsystems, does nothing if not exists
        cgroup = CGroup().delete()
        assert not cgroup.exists

        ## delete cgroup for specific subsystem, does nothing if not exists
        freezer = CGroup().freezer.delete()
        assert not freezer.exists

        ## check the occurrence of the process in the cgroup
        cgroup = CGroup()
        assert os.getpid() in cgroup  # process there is at least in one of subsystem

        ## add process to all subsystems of the cgroup
        cgroup = CGroup("custom/path")
        cgroup += <pid>
        assert <pid> in cgroup
        assert all(<pid> in subsystem for subsystem in cgroup)

        ## add process to specific subsystem of the cgroup
        freezer = CGroup("custom/path").freezer
        freezer += <pid>
        assert <pid> in freezer
        # or
        freezer.tasks = <pid>
        assert <pid> in freezer.tasks

        ## changing cgroup's limits
        cgroup = CGroup()
        cgroup.memory["limit_in_bytes"] = "11G"
        cgroup.memory["limit_in_bytes"]  # list of lines of memory.limit_in_bytes, e.g. ["11811160064"]

        ## common usage example
        import subprocess as sp
        from sandbox import common
        cg = common.os.CGroup("my_group")
        cg.memory["low_limit_in_bytes"] = 16911433728
        cg.memory["limit_in_bytes"] = 21421150208
        cg.cpu["smart"] = 1
        sp.Popen(<cmd>, preexec_fn=cg.set_current)
    """
    ROOT = "/sys/fs/cgroup"
    EXCLUDE_SUBSYSTEMS = ["cpuset"]

    # `property` wraps the class itself: reading `cgroup.Subsystem` instantiates `Subsystem(cgroup)`,
    # and calling that instance (`cgroup.Subsystem(name, pid, owner)`) selects the concrete subsystem.
    @property
    class Subsystem(object):
        """ One subsystem (controller) of a `CGroup`, e.g. ``memory`` or ``freezer`` """
        __name = None

        # *_ added just for dummy PyCharm
        def __init__(self, cgroup, *_):
            self.__cgroup = cgroup

        def __call__(self, name, pid=None, owner=None):
            """
            Bind this object to subsystem `name`.

            If `pid` is given and ``/proc/<pid>/cgroup`` exists, the cgroup path is resolved relative to the
            cgroup that process belongs to in this subsystem, and the subsystem of a new (absolute) `CGroup`
            is returned -- or None if the process has no entry for the subsystem. Otherwise returns self.
            """
            assert name
            path = os.path.join(CGroup.ROOT, name)
            self.__name = real_name = name
            # the subsystem directory may be a symlink (presumably to a combined mount such as
            # "cpu,cpuacct"); ``/proc/<pid>/cgroup`` lists the link target, so match on that
            if os.path.exists(path) and os.path.islink(path):
                real_name = os.readlink(path)
            self.__owner = owner
            if pid is not None:
                name_substr = ":{}:".format(real_name)
                filename = "/proc/{}/cgroup".format(pid)
                if os.path.exists(filename):
                    with open(filename) as f:
                        # lines are "<hierarchy id>:<subsystems>:<path>"; split at most twice so that
                        # a path containing ':' is kept intact
                        cgroup_name = next(iter(filter(
                            lambda _: name_substr in _, f.readlines()
                        )), "").split(":", 2)[-1].strip()
                    return type(self.__cgroup)(
                        os.path.join(cgroup_name, self.__cgroup.name.lstrip("/")),
                        owner=owner
                    )[self.__name] if cgroup_name else None

            return self

        def __repr__(self):
            return "<{}.{}({}: {})>".format(
                type(self.cgroup).__name__, type(self).__name__, self.__name, self.__cgroup.name
            )

        @property
        def name(self):
            """ Subsystem name as requested, e.g. ``"memory"`` """
            return self.__name

        @property
        def cgroup(self):
            """ The `CGroup` this subsystem belongs to """
            return self.__cgroup

        @property
        def path(self):
            """ Filesystem path of the cgroup inside this subsystem under `CGroup.ROOT` """
            return os.path.join(CGroup.ROOT, self.__name, self.__cgroup.name.lstrip("/"))

        @property
        def exists(self):
            return os.path.exists(self.path)

        @property
        def tasks(self):
            """ Pids listed in the ``tasks`` file of the cgroup """
            with open(os.path.join(self.path, "tasks")) as f:
                return [int(_) for _ in f.readlines()]

        @tasks.setter
        def tasks(self, pid):
            """ Create the cgroup if needed and write `pid` into its ``tasks`` file """
            self.create()
            with open(os.path.join(self.path, "tasks"), "wb") as f:
                f.write(six.ensure_binary("{}\n".format(pid)))

        def create(self):
            """ Create the cgroup directory (chown'ed to the owner, if any) unless it exists; returns self """
            if not self.exists:
                os.makedirs(self.path, mode=0o755)
                if self.__owner:
                    import pwd
                    os.chown(self.path, pwd.getpwnam(self.__owner).pw_uid, -1)
            return self

        def delete(self):
            """ Recursively remove the cgroup and its child cgroups if it exists; returns self """
            if self.exists:
                for _ in self:
                    _.delete()
                os.rmdir(self.path)
            return self

        def set_current(self):
            """ Place current process to the subsystem of cgroup """
            if self.exists:
                self.tasks = os.getpid()

        def __resource_path(self, resource):
            # resource files are named "<subsystem>.<resource>", e.g. "memory.limit_in_bytes"
            return os.path.join(self.path, ".".join((self.__name, resource)))

        def __getitem__(self, resource):
            """ Read a resource file; returns its lines with surrounding whitespace stripped """
            with open(self.__resource_path(resource)) as f:
                return [_.strip() for _ in f.readlines()]

        def __setitem__(self, resource, value):
            """ Create the cgroup if needed and write `value` into the resource file """
            self.create()
            with open(self.__resource_path(resource), "wb") as f:
                f.write(six.ensure_binary("{}\n".format(value)))

        def __iter__(self):
            """ Iterate over the child cgroups of this cgroup within the same subsystem """
            path = self.path
            for _ in os.listdir(path):
                if os.path.isdir(os.path.join(path, _)):
                    yield type(self.__cgroup)(os.path.join(self.__cgroup.name, _), owner=self.__owner)[self.__name]

        def __rshift__(self, name):
            return (self.__cgroup >> name)[self.__name]

        def __lshift__(self, level):
            return (self.__cgroup << level)[self.__name]

        def __contains__(self, pid):
            # truthy result: self if `pid` is a task here, else the first child subsystem containing it
            return (
                self
                if pid in self.tasks else
                next(six.moves.filter(None, six.moves.map(lambda _: _.__contains__(pid), self)), None)
            )

        def __iadd__(self, pid):
            """ ``subsystem += pid`` moves `pid` into the cgroup; returns self so the name stays bound """
            self.tasks = pid
            return self

    def __new__(cls, *args, **kws):
        # when cgroups are not mounted, construction yields None (``__init__`` is then not called)
        if cls.mounted:
            return super(CGroup, cls).__new__(cls)
        assert not config.Registry().client.cgroup_available, "cgroups available but not mounted"

    def __init__(self, name=None, pid=None, owner=None, iterate_all_subsystems=False):
        """
        :param name: cgroup path; absolute if it starts with "/", otherwise relative to the cgroups of `pid`
        :param pid: process whose cgroups relative names are resolved against (default: current process)
        :param owner: user name to chown created cgroup directories to
        :param iterate_all_subsystems: do not skip `EXCLUDE_SUBSYSTEMS` when iterating
        """
        self.__name = (name or "").rstrip("/") if name != "/" else name
        self.__pid = pid if pid is not None or self.__name.startswith("/") else os.getpid()
        self.__owner = owner
        self.__iterate_all_subsystems = iterate_all_subsystems

    def __repr__(self):
        return "<{}({})>".format(type(self).__name__, self.__name)

    def __rshift__(self, name):
        """ Cgroup for the sub-path `name` of this one; keeps the owner, like `__lshift__` """
        return type(self)(
            os.path.join(self.__name, name) if self.__name else name,
            owner=self.__owner, iterate_all_subsystems=self.__iterate_all_subsystems
        )

    def __lshift__(self, level):
        """ Cgroup `level` path components above this one """
        return type(self)(
            ft.reduce(lambda p, _: os.path.dirname(p), range(level), self.__name),
            owner=self.__owner, iterate_all_subsystems=self.__iterate_all_subsystems
        )

    def __getitem__(self, subsystem):
        return self.Subsystem(subsystem, self.__pid, self.__owner)

    def __iter__(self):
        """ Iterate over mounted subsystems (directories of `ROOT` having a ``tasks`` file) """
        for subsys_name in os.listdir(self.ROOT):
            if not self.__iterate_all_subsystems and subsys_name in self.EXCLUDE_SUBSYSTEMS:
                continue
            subsys_path = os.path.join(self.ROOT, subsys_name)
            if os.path.isdir(subsys_path) and os.path.exists(os.path.join(subsys_path, "tasks")):
                subsys = self.Subsystem(subsys_name, self.__pid, self.__owner)
                if subsys is not None:
                    yield subsys

    def __contains__(self, pid):
        return any(six.moves.map(lambda subsys: pid in subsys, self))

    def __iadd__(self, pid):
        """ ``cgroup += pid`` moves `pid` into all subsystems; returns self so the name stays bound """
        for subsys in self:
            subsys.__iadd__(pid)
        return self

    @patterns.classproperty
    def mounted(self):
        return os.path.isdir(self.ROOT) and bool(os.listdir(self.ROOT))

    @property
    def name(self):
        return self.__name

    @property
    def exists(self):
        return any(six.moves.map(lambda subsys: subsys.exists, self))

    def create(self):
        for subsys in self:
            subsys.create()
        return self

    def delete(self):
        for subsys in self:
            subsys.delete()
        return self

    def set_current(self):
        """ Place current process to all subsystems of the cgroup """
        for subsys in self:
            subsys.set_current()

    @property
    def cpu(self):
        return self.Subsystem("cpu", self.__pid, self.__owner)

    @property
    def cpuacct(self):
        return self.Subsystem("cpuacct", self.__pid, self.__owner)

    @property
    def devices(self):
        return self.Subsystem("devices", self.__pid, self.__owner)

    @property
    def freezer(self):
        return self.Subsystem("freezer", self.__pid, self.__owner)

    @property
    def memory(self):
        return self.Subsystem("memory", self.__pid, self.__owner)


class FreezerState(enum.Enum):
    """
    States of a cgroup freezer.

    NOTE(review): presumably these correspond to the values of the freezer subsystem's ``state`` resource
    (see `CGroup.freezer`); member values are left as None, presumably for the project `enum` module
    to fill in -- confirm.
    """
    FROZEN = None
    THAWED = None


class Namespace(object):
    """
    Representation of Linux namespaces

    An instance wraps the target of a ``/proc/<pid>/ns/<type>`` symlink, i.e. a string of the form
    ``"<type>:[<id>]"``.
    """

    NS_ID_REGEXP = re.compile(r"([a-z]+):\[(\d+)]")
    ParsedNamespace = collections.namedtuple("ParsedNamespace", "type id")

    class Type(enum.Enum):
        enum.Enum.lower_case()

        IPC = None
        MNT = None
        NET = None
        PID = None
        USER = None
        UTS = None

    class CloneFlags(enum.Enum):
        CLONE_NEWUSER = 0x10000000
        CLONE_NEWNS = 0x00020000

    def __init__(self, namespace):
        """
        :param namespace: namespace identifier of the form ``"<type>:[<id>]"``
        :raises ValueError: if `namespace` does not have that form
        """
        match = self.NS_ID_REGEXP.match(namespace)
        if match is None:
            raise ValueError("Invalid namespace identifier: {!r}".format(namespace))
        self.__namespace = namespace
        self.__parsed_namespace = self.ParsedNamespace(*match.groups())

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self.__namespace)

    @patterns.singleton_classproperty
    def _libc(self):
        # use_errno=True lets `unshare` report why the call failed via ctypes.get_errno()
        return ctypes.CDLL("libc.so.6", use_errno=True)

    @property
    def parsed(self):
        """ `ParsedNamespace` (type, id) extracted from the identifier """
        return self.__parsed_namespace

    @classmethod
    def from_pid(cls, pid, ns_type):
        """
        Namespace of type `ns_type` that process `pid` belongs to.

        :param pid: process id
        :param ns_type: member of `Type`
        :return: namespace object, or None if ``/proc/<pid>/ns/<type>`` cannot be read
        :raises ValueError: if `ns_type` is not a member of `Type`
        """
        try:
            if ns_type not in cls.Type:
                # noinspection PyTypeChecker
                raise ValueError("ns_type must be one of: {}".format(list(cls.Type)))
            ns_path = "/proc/{}/ns/{}".format(pid, ns_type)
            return cls(os.readlink(ns_path))
        except (OSError, IOError):
            return None

    @property
    def pids(self):
        """ Pids of all processes visible in ``/proc`` whose namespace of this type is this one """
        pids = []
        root_path = "/proc"
        for pid in os.listdir(root_path):
            if not pid.isdigit():
                continue
            path = os.path.join(root_path, pid, "ns", self.__parsed_namespace.type)
            try:
                if os.readlink(path) == self.__namespace:
                    pids.append(int(pid))
            except (OSError, IOError):
                # the process may have exited or be inaccessible; skip it
                pass
        return pids

    @classmethod
    def unshare(cls, flags, uid_map=None, gid_map=None):
        """
        Replaces command `unshare` but with additional mapping ability

        :param flags: bit mask of ``CLONE_NEW*`` flags for unshare(2), e.g. values of `CloneFlags`
        :param uid_map: text to write into ``/proc/<pid>/uid_map`` of the current process (skipped if empty)
        :param gid_map: text to write into ``/proc/<pid>/gid_map`` (skipped if empty); ``setgroups`` is set to
            ``deny`` first
        :raises OSError: if the unshare(2) call fails
        """

        pid = os.getpid()
        if cls._libc.unshare(ctypes.c_int(flags)) != 0:
            err = ctypes.get_errno()
            raise OSError(err, os.strerror(err))
        if uid_map:
            with open("/proc/{}/uid_map".format(pid), "w") as f:
                f.write(uid_map)
        if gid_map:
            # "deny" must be written to setgroups before an unprivileged process may write gid_map
            with open("/proc/{}/setgroups".format(pid), "w") as f:
                f.write("deny")
            with open("/proc/{}/gid_map".format(pid), "w") as f:
                f.write(gid_map)


class SystemStatistics(object):
    """
    Manually accumulate system resources usage, such as RAM and CPU percentage, and yield them on demand.
    Updates that occur more than once in a second are ignored

    Usage:

        .. code-block:: python

            import time

            meter = SystemStatistics()
            for _ in range(10):
                meter.checkpoint()
                time.sleep(1)

            # every time __iter__ is called, the accumulated points are popped
            points = list(meter)

    Available statistics:
        - CPU load, percents of process CPU time relative to the elapsed wall-clock time (so it may exceed 100
          for multi-threaded processes); read from /proc/{pid}/stat
        - RSS (memory currently occupied by a process) and VMS (total memory allocated), in bytes;
          read from /proc/{pid}/statm
    """

    Point = collections.namedtuple("Point", "time user_cpu system_cpu rss vms")

    def __init__(self, pid=None):
        """ :param pid: process to monitor; defaults to the current process """
        self.__pid = pid or os.getpid()
        self.__proc_stat = "/proc/{}/stat".format(self.__pid)
        self.__proc_statm = "/proc/{}/statm".format(self.__pid)
        self.__last_time = None
        self.__last_user_ticks = None
        self.__last_system_ticks = None
        self.__last_ticks = None
        self.__points = []
        self.CLOCK_TICKS = os.sysconf("SC_CLK_TCK")
        self.PAGESIZE = os.sysconf("SC_PAGE_SIZE")

    def __update(self, now, user_ticks, system_ticks, ticks):
        self.__last_time = now
        self.__last_user_ticks = user_ticks
        self.__last_system_ticks = system_ticks
        self.__last_ticks = ticks

    def checkpoint(self):
        """
        Take a sample and, if a previous sample exists, append a `Point` for the interval since it.

        The first call only records a baseline. Calls within the same (integer) second as the previous
        sample are ignored. A point's `time` is the integer timestamp of the previous sample.
        """
        now = time.time()
        ticks = int(now * self.CLOCK_TICKS)
        now = int(now)
        # on the first call there is no previous sample; comparing None with an int fails on Python 3
        if self.__last_time is not None and self.__last_time >= now:
            return
        with open(self.__proc_stat) as f:
            st = f.read().strip()
            # fields after "(comm) " start at field 3 (state); utime/stime are fields 14 and 15
            user_ticks, system_ticks = map(int, st[st.find(")") + 2:].split()[11:13])
        if self.__last_time is None:
            self.__update(now, user_ticks, system_ticks, ticks)
            return
        with open(self.__proc_statm) as f:
            # statm: size (VMS) and resident (RSS), both in pages
            vms, rss = map(lambda _: int(_) * self.PAGESIZE, f.readline().split()[:2])
        ticks_delta = ticks - self.__last_ticks
        user_cpu = (user_ticks - self.__last_user_ticks) * 100 / ticks_delta
        system_cpu = (system_ticks - self.__last_system_ticks) * 100 / ticks_delta
        self.__points.append(
            self.Point(time=self.__last_time, user_cpu=user_cpu, system_cpu=system_cpu, rss=rss, vms=vms)
        )
        self.__update(now, user_ticks, system_ticks, ticks)

    def __iter__(self):
        """ Pop and yield all accumulated points """
        points, self.__points = self.__points, []
        for point in points:
            yield point


def real_pid(pid, namespace):
    """
    Returns pid on host for pid from namespace.

    Walks the processes of `namespace` and inspects the ``NSpid:`` line of each ``/proc/<pid>/status``.
    When the last entry of that line equals `pid`, the first entry is returned. Processes whose status
    cannot be read are skipped. Returns None when nothing matches.
    """
    for candidate in namespace.pids:
        status_path = "/proc/{}/status".format(candidate)
        try:
            with open(status_path) as status:
                nspid_fields = next(
                    (fields for fields in (ln.split() for ln in status.readlines()) if fields[0] == "NSpid:"),
                    None
                )
        except (OSError, IOError):
            continue
        if nspid_fields is not None and int(nspid_fields[-1]) == pid:
            return int(nspid_fields[1])


def path_env(value, prepend=True, key="PATH"):
    """
    Mostly used to start 3rd-party binaries subprocesses. Additionally cut-off skynet binaries path.

    Builds a new value for the ``os.pathsep``-separated environment variable `key` with `value` added to it.
    Entries starting with ``/skynet/python/bin`` and empty entries are dropped. ``os.environ`` is not modified.

    :param value: path entry to add
    :param prepend: add `value` at the front if True, otherwise at the end
    :param key: name of the environment variable to read
    :return: the joined string
    """
    entries = [
        entry for entry in os.environ.get(key, "").split(os.pathsep)
        if not entry.startswith("/skynet/python/bin")
    ]
    if prepend:
        entries.insert(0, value)
    else:
        entries.append(value)
    # drop empty entries, e.g. from an unset variable or "::"
    return os.pathsep.join(filter(None, entries))


def system_log_path(prefix="/"):
    """
    Returns the path to system log, depending on platform.

    Linux: the first existing of ``var/log/messages`` and ``var/log/syslog`` under `prefix`, or None.
    macOS: always ``/var/log/system.log`` (`prefix` is not used). Other platforms: None.
    """
    if sys.platform == "darwin":
        return "/var/log/system.log"
    if not sys.platform.startswith("linux"):
        return None
    candidates = (os.path.join(prefix, tail) for tail in ("var/log/messages", "var/log/syslog"))
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)


def processes():
    """
    Yield basic information (PID, real UID, real GID, executable) about running processes.

    On Cygwin the `/proc` directory is listed directly (``exe`` is always None). On other platforms the
    information comes from :py:mod:`psutil`; processes that vanish or deny access while being inspected
    are skipped, and if only the executable path is inaccessible, the process name is used instead.
    """

    proc_t = collections.namedtuple("Proc", "pid uid gid exe")
    if platform.system().startswith("CYGWIN"):
        for dname in os.listdir("/proc"):
            try:
                pid = int(dname)
                st = os.stat(os.path.join("/proc", dname))
            except (ValueError, OSError):
                continue
            # S_ISDIR checks the whole file-type field; a bare `& S_IFDIR` also matches sockets and block devices
            if stat.S_ISDIR(st.st_mode):
                yield proc_t(pid, st.st_uid, st.st_gid, None)
    else:
        # noinspection PyUnresolvedReferences
        import psutil

        def _safe_exe(p):
            # exe/name are attributes in some psutil versions and methods in others
            try:
                return p.exe() if callable(p.exe) else p.exe
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                return p.name() if callable(p.name) else p.name

        for proc in psutil.process_iter():
            try:
                if callable(proc.uids):
                    uids, gids = proc.uids(), proc.gids()
                else:
                    uids, gids = proc.uids, proc.gids
                entry = proc_t(proc.pid, uids.real, gids.real, _safe_exe(proc))
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            yield entry


class Capabilities(object):
    """
    Linux capabilities manager

    Also usable as a context manager: on enter, the bits given to the constructor are added to the effective
    set of the caller; on exit, the capabilities saved on enter are restored.

    See:
        https://man7.org/linux/man-pages/man7/capabilities.7.html
        https://man7.org/linux/man-pages/man3/cap_get_proc.3.html
    """

    class Cap(patterns.Abstract):
        DEFAULT_CAP_VERSION = 0x20080522  # _LINUX_CAPABILITY_VERSION_3

        class Bits(enum.Enum):
            """
            See https://github.com/torvalds/linux/blob/master/include/uapi/linux/capability.h
            """

            CAP_CHOWN = 1
            CAP_DAC_OVERRIDE = 1 << 1
            CAP_DAC_READ_SEARCH = 1 << 2
            CAP_FOWNER = 1 << 3
            CAP_FSETID = 1 << 4
            CAP_KILL = 1 << 5
            CAP_SETGID = 1 << 6
            CAP_SETUID = 1 << 7
            CAP_SETPCAP = 1 << 8
            CAP_LINUX_IMMUTABLE = 1 << 9
            CAP_NET_BIND_SERVICE = 1 << 10
            CAP_NET_BROADCAST = 1 << 11
            CAP_NET_ADMIN = 1 << 12
            CAP_NET_RAW = 1 << 13
            CAP_IPC_LOCK = 1 << 14
            CAP_IPC_OWNER = 1 << 15
            CAP_SYS_MODULE = 1 << 16
            CAP_SYS_RAWIO = 1 << 17
            CAP_SYS_CHROOT = 1 << 18
            CAP_SYS_PTRACE = 1 << 19
            CAP_SYS_PACCT = 1 << 20
            CAP_SYS_ADMIN = 1 << 21
            CAP_SYS_BOOT = 1 << 22
            CAP_SYS_NICE = 1 << 23
            CAP_SYS_RESOURCE = 1 << 24
            CAP_SYS_TIME = 1 << 25
            CAP_SYS_TTY_CONFIG = 1 << 26
            CAP_MKNOD = 1 << 27
            CAP_LEASE = 1 << 28
            CAP_AUDIT_WRITE = 1 << 29
            CAP_AUDIT_CONTROL = 1 << 30
            CAP_SETFCAP = 1 << 31
            CAP_MAC_OVERRIDE = 1 << 32
            CAP_MAC_ADMIN = 1 << 33
            CAP_SYSLOG = 1 << 34
            CAP_WAKE_ALARM = 1 << 35
            CAP_BLOCK_SUSPEND = 1 << 36
            CAP_AUDIT_READ = 1 << 37
            CAP_PERFMON = 1 << 38
            CAP_BPF = 1 << 39

        # 64-bit masks; `__defs__` presumably supplies per-slot defaults to `patterns.Abstract` -- confirm
        __slots__ = "version", "effective", "permitted", "inheritable"
        __defs__ = (DEFAULT_CAP_VERSION,) + (0,) * (len(__slots__) - 1)

        def set_effective(self, bits):
            self.effective |= bits

        def unset_effective(self, bits):
            self.effective &= ~bits

        def check_effective(self, bits):
            return self.effective & bits

        def set_permitted(self, bits):
            self.permitted |= bits

        def unset_permitted(self, bits):
            self.permitted &= ~bits

        def check_permitted(self, bits):
            return self.permitted & bits

        def set_inheritable(self, bits):
            self.inheritable |= bits

        def unset_inheritable(self, bits):
            self.inheritable &= ~bits

        def check_inheritable(self, bits):
            return self.inheritable & bits

    # noinspection PyPep8Naming
    class cap_t(ctypes.Structure):
        """
        .. code-block:: c

        struct {
            __u32 version;
            int pid;
            __u32 effective;
            __u32 permitted;
            __u32 inheritable;
            __u32 effective2;
            __u32 permitted2;
            __u32 inheritable2;
        };

        NOTE(review): this mirrors libcap's internal (non-public) capability struct; the layout may differ
        between libcap releases -- confirm against the installed libcap version.
        """

        _fields_ = [
            ("version", ctypes.c_uint),
            ("pid", ctypes.c_int),
            ("effective", ctypes.c_uint),
            ("permitted", ctypes.c_uint),
            ("inheritable", ctypes.c_uint),
            ("effective2", ctypes.c_uint),
            ("permitted2", ctypes.c_uint),
            ("inheritable2", ctypes.c_uint),
        ]

    def __init__(self, effective_bits=0):
        """ :param effective_bits: capability bits to raise in the effective set while inside the context """
        self.__effective_bits = effective_bits
        self.__saved_cap = None

    def __enter__(self):
        self.__saved_cap = self.get()
        cap = self.__saved_cap.copy()
        cap.set_effective(self.__effective_bits)
        self.set(cap)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.set(self.__saved_cap)

    @patterns.singleton_classproperty
    def _libcap(self):
        return ctypes.CDLL("libcap.so.2")

    @classmethod
    def get(cls, pid=0):
        """
        Read capabilities of process `pid` (0 -- the calling process) via libcap's ``cap_get_pid``.

        :return: `Cap` with the low and high 32-bit halves combined into 64-bit masks
        """
        cap_get_pid = cls._libcap.cap_get_pid
        cap_get_pid.restype = ctypes.POINTER(cls.cap_t)
        ret = cap_get_pid(pid)
        cap = ret.contents
        cap = cls.Cap(
            version=cap.version,
            effective=cap.effective2 << 32 | cap.effective,
            permitted=cap.permitted2 << 32 | cap.permitted,
            inheritable=cap.inheritable2 << 32 | cap.inheritable,
        )
        cls._libcap.cap_free(ret)
        return cap

    @classmethod
    def set(cls, cap):
        """
        Apply `cap` to the caller via libcap's ``cap_set_proc``.

        :return: True on success, False otherwise
        """
        cap_init = cls._libcap.cap_init
        cap_init.restype = ctypes.POINTER(cls.cap_t)
        new_cap_p = cap_init()
        new_cap = new_cap_p.contents
        new_cap.version = cap.version
        new_cap.pid = 0
        # split 64-bit masks into the low and high 32-bit words
        new_cap.effective = cap.effective & 0xffffffff
        new_cap.permitted = cap.permitted & 0xffffffff
        new_cap.inheritable = cap.inheritable & 0xffffffff
        new_cap.effective2 = cap.effective >> 32
        new_cap.permitted2 = cap.permitted >> 32
        new_cap.inheritable2 = cap.inheritable >> 32
        ret = cls._libcap.cap_set_proc(new_cap_p)
        cls._libcap.cap_free(new_cap_p)
        return not ret

    @staticmethod
    def _select_bits(cap_bits, predicate):
        """ Subset of the bits set in `cap_bits` whose capability number (bit index) satisfies `predicate` """
        selected = 0
        bit_index = 0
        while cap_bits:
            if cap_bits & 1 and predicate(bit_index):
                # keep the bit at its original position (bit N <=> capability number N)
                selected |= 1 << bit_index
            cap_bits >>= 1
            bit_index += 1
        return selected

    @classmethod
    def get_bound(cls, cap_bits):
        """ Subset of `cap_bits` whose capabilities are in the bounding set (``cap_get_bound`` returns 1) """
        cap_get_bound = cls._libcap.cap_get_bound
        return cls._select_bits(cap_bits, lambda index: cap_get_bound(index) == 1)

    @classmethod
    def drop_bound(cls, cap_bits):
        """ Drop `cap_bits` from the bounding set; returns the subset successfully dropped """
        cap_drop_bound = cls._libcap.cap_drop_bound
        return cls._select_bits(cap_bits, lambda index: cap_drop_bound(index) == 0)
