from __future__ import print_function

from functools import wraps
import argparse
import os
import py
import errno
import random
import pickle  # Pyro4 will be forced to use it, not cPickle
import signal
import stat
import sys
import logging

from gevent import spawn, sleep, socket
from gevent.event import Event
try:
    from gevent.coros import RLock
except ImportError:
    from gevent.lock import RLock

import Pyro4
import Pyro4.util
import Pyro4.socketutil
import Pyro4.socketserver.geventserver

from porto import Connection as PortoConnection
from porto.exceptions import SocketError

from kernel.pyro.server import TCPLock

from .kernel_util.console import setProcTitle
from .kernel_util.errors import formatException, saveTraceback
from .kernel_util.sys.getpeerid import getpeerid
from .kernel_util.sys.portoslave import same_container

from . import errors, utils
from .process import Proc
from .mailbox import mailbox
from .portoutils import PortoAdapter


# Name used both as the process title and as the Pyro object name the
# handler is registered under.
PROGNAME = 'skynet.procman'


class Stop(Exception):
    """
    Sentinel exception used as a "stop" message for the mailbox loop.

    It is raised from a mailbox job to break out of the job iteration.
    """


def _stop_raiser():
    """Mailbox job that unconditionally raises :class:`Stop`."""
    stop_signal = Stop()
    raise stop_signal


class AffinityGroups(object):
    """
    Storage of collected CPU affinity groups.

    A group that has not been requested yet is created on first lookup from
    a random selection of the available CPUs. Affinity is applied on FreeBSD
    only: on every other OS lookups return ``None``, because there is no need
    to stick SkyNet processes to specific CPUs there.
    """

    def __init__(self):
        self.groups = {}  # group name -> sorted list of CPU numbers
        self.cpus = []    # shuffled CPUs not yet handed out in the current round

    @staticmethod
    def _cpu_count():
        """
        Return the number of CPUs reported by psutil, or 1 if unknown.

        Supports both ``psutil.cpu_count()`` and the legacy
        ``psutil.NUM_CPUS`` constant, which newer psutil releases removed.
        """
        try:
            import psutil
        except ImportError:
            return 1

        cpu_count = getattr(psutil, 'cpu_count', None)
        if cpu_count is not None:
            count = cpu_count()
        else:
            count = getattr(psutil, 'NUM_CPUS', None)
        # cpu_count() may return None when the count can not be determined.
        return count or 1

    def __getitem__(self, item):
        """
        The main entry point. Converts affinity metadata to list of CPUs to stick on.

        :param item: Tuple of group name and amount CPUs to be assigned to the
            group, or a falsy value meaning "no affinity".
        :return: A sorted list of CPU core numbers to stick on, or ``None``
            when no pinning should be done (nothing requested, not FreeBSD,
            or the group would cover all CPUs anyway).
        """
        if not item or os.uname()[0].lower() != 'freebsd':
            # No need to perform CPU affinity on Linux (and Windows:)
            return None

        name, amount = item
        assert amount > 0
        existing = self.groups.get(name)
        if existing:
            return existing

        # Only look up the CPU count when it is actually needed.
        num_cpus = self._cpu_count()
        if amount >= num_cpus:
            return None

        res = []
        while amount:
            if not self.cpus:
                pool = list(range(num_cpus))
                random.shuffle(pool)
                # Move CPUs already picked for this group to the end of the
                # fresh pool. Since amount < num_cpus, the rest of the group is
                # then taken only from CPUs it does not have yet (no duplicates).
                self.cpus = [cpu for cpu in pool if cpu not in res] + [cpu for cpu in pool if cpu in res]
            chunk = self.cpus[:amount]
            self.cpus = self.cpus[amount:]
            amount -= len(chunk)
            res += chunk

        self.groups[name] = res = sorted(res)
        return res


class Procs(object):
    """
    Registry of processes managed by procman, exposed over Pyro RPC.

    Keeps the ``uuid -> Proc`` mapping in ``self.procs``. On construction it
    reconnects to liner sockets and ``*.porto`` container info files found in
    ``rundir``. If porto is reachable, it also spawns a greenlet that destroys
    stopped procman containers no longer tracked here.
    """

    # Per-process methods that ``__getattr__`` may proxy to a ``Proc``.
    _ALLOWED_PROC_METHODS = frozenset([
        'stop_retries', 'keep_running', 'add_tags', 'delete_tags', 'get_tags', 'get_context',
        'stopRetries', 'executing', 'running',
        'keepRunning', 'send_signal', 'signal', 'kill',
        'terminate',
        'addTags', 'deleteTags', 'getTags', 'getContext',
        'stat',
    ])

    # Subset of the above that any peer may call regardless of process owner
    # (the original code named these "insecure" methods).
    _UNRESTRICTED_PROC_METHODS = frozenset([
        'executing', 'running', 'get_tags', 'getTags', 'get_context', 'getContext', 'stat',
    ])

    # Seconds between two scans of the porto watcher.
    _PORTO_WATCH_INTERVAL = 60 * 60

    def __init__(self, rundir):
        self.log = logging.getLogger('procs')
        self.rundir = rundir
        self.procs = {}
        self.agroups = AffinityGroups()
        self.portoconn = None
        self.portowatcher = None
        self.nobody_uid = utils.geteuid()

        self._connect_to_porto()

        # Reconnect each liner process if exists
        for path in self.rundir.listdir():
            if path.basename == 'skynet.procman.sock':
                # Our own RPC socket, not a liner.
                continue

            fstat = path.stat()
            if stat.S_ISSOCK(fstat.mode):
                self.log.debug('Reconnecting liner socket at %s', path)
                Proc.reconnect(self, self.nobody_uid, sockname=path.strpath)
            elif fstat.isfile() and path.ext == '.porto':
                self.log.debug('Reconnecting porto container %s', path.basename)
                Proc.reconnect(self, self.nobody_uid, container_info_path=path.strpath)
            else:
                self.log.debug('Not reconnecting to %s: not a socket', path)

        if self.portoconn:
            self.portowatcher = spawn(self._porto_watcher)

    def _porto_watcher(self):
        """
        Forever: every ``_PORTO_WATCH_INTERVAL`` seconds, stop and destroy
        non-running procman containers whose uuid is not in ``self.procs``.

        All porto errors are ignored (best effort, retried on the next pass).
        The greenlet is killed by ``shutdown()``.
        """
        while True:
            try:
                conts = self.portoconn.ListContainers()
            except Exception:
                conts = []

            for cont in conts:
                try:
                    private = cont.GetProperty('private')
                    state = cont.GetData('state')
                except Exception:
                    continue

                name = cont.name.split('/')[-1]

                is_procman = private == 'PROCMAN' or name.startswith('procman-')
                if not is_procman or state == 'running':
                    continue

                # NOTE(review): assumes the name starts with 'procman-' even
                # when only the 'private' marker matched.
                uuid = name[8:]  # strip 'procman-'
                if uuid in self.procs:
                    continue

                try:
                    cont.Stop()
                except Exception:
                    pass

                try:
                    cont.Destroy()
                except Exception:
                    pass
            sleep(self._PORTO_WATCH_INTERVAL)

    def _connect_to_porto(self):
        """
        Open a porto connection as root and store it in ``self.portoconn``.

        On failure, ``self.portoconn`` stays ``None`` and the reason is logged.
        """
        try:
            with utils.auto_user_privileges('root', limit=False):
                portoconn = PortoConnection(
                    socket_constructor=socket.socket, lock_constructor=RLock, timeout=120, auto_reconnect=False
                )
                self.portoconn = PortoAdapter(portoconn, log=logging.getLogger('porto'))
        except ImportError as e:
            self.portoconn = None
            self.log.info("Porto is not available on host: %s", e)
        except (EnvironmentError, SocketError) as e:
            self.portoconn = None
            self.log.warning("Connection to porto failed: %s", e)

    def _gather_proc_cgroups(self, pid):
        """
        Map cgroup controllers to the cgroup path of process *pid*.

        Only controllers mounted as ``/sys/fs/cgroup/<name>`` with a ``tasks``
        file are considered. For a line listing several comma-separated
        controllers, only the first known one is recorded. Errors are logged
        and whatever was gathered so far is returned.

        :return: dict ``{controller: cgroup path without leading '/'}``.
        """
        possible_cgroups = set()
        process_cgroups = {}

        try:
            for entry in os.listdir('/sys/fs/cgroup'):
                # isfile() on "<entry>/tasks" also implies <entry> is a directory.
                if os.path.isfile(os.path.join('/sys/fs/cgroup', entry, 'tasks')):
                    possible_cgroups.add(entry)

            # Text mode so splitting on '\n' works on both Python 2 and 3;
            # 'with' makes sure the file is closed.
            with open('/proc/%d/cgroup' % (pid, )) as f:
                data = f.read()

            for line in data.split('\n'):
                if not line:
                    continue

                _, controller, cgroup = line.split(':', 2)

                cgroup = cgroup.lstrip('/')

                for controller in controller.split(','):
                    if '=' in controller:
                        # e.g. "name=systemd" -> "systemd"
                        controller = controller.split('=', 1)[1]

                    if controller in possible_cgroups:
                        process_cgroups[controller] = cgroup
                        break
        except Exception as ex:
            # Exception (not BaseException) so a greenlet kill still propagates.
            self.log.info("Failed to gather cgroups: %s", ex)

        return process_cgroups

    def create(self, pyro_connection, *args, **kwargs):
        """
        Start a new managed process on behalf of the RPC caller.

        The caller's uid/pid are taken from *pyro_connection* when it is a
        ``Pyro4.socketutil.SocketConnection``. When running with root on Linux,
        a caller from another container, or whose process changed since it
        connected, is rejected unless porto shows it is in our container.

        :param pyro_connection: connection the call arrived on (passed by Pyro
            because ``create._passPyroConnection`` is set).
        :param affinity: optional ``(group, cpus)`` tuple resolved through
            :class:`AffinityGroups` into the ``cpus`` keyword.
        :return: uuid of the new process.
        :raises Exception: if the caller is not allowed to create processes.
        """
        peer_uid = None
        peer_pid = None
        peer_changed = False
        other_container = False

        if isinstance(pyro_connection, Pyro4.socketutil.SocketConnection):
            peer_uid = pyro_connection.peer_uid
            peer_pid = pyro_connection.peer_pid
            peer_changed = pyro_connection.peer_changed()
            other_container = pyro_connection.other_container

        if (
            utils.has_root()
            and os.uname()[0].lower() == 'linux'
            and peer_uid is not None
            and peer_pid
            and (other_container or peer_changed)
        ):
            with utils.auto_user_privileges():
                if self.portoconn and not peer_changed:
                    try:
                        job_container = self.portoconn.LocateProcess(peer_pid).name
                    except Exception:
                        raise Exception(
                            'You are not allowed to run processes from another container '
                            '(was unable to detect containers with porto)'
                        )

                    try:
                        our_container = self.portoconn.LocateProcess(os.getpid()).name
                    except Exception:
                        raise Exception(
                            'You are not allowed to run processes from another container '
                            '(our: %s, was unable to detect caller container)' % (job_container, )
                        )

                    if job_container != our_container:
                        raise Exception(
                            'You are not allowed to run processes '
                            'from another container (%s)' % (job_container, )
                        )
                else:
                    raise Exception('You are not allowed to run processes from other container')

        kwargs['cpus'] = self.agroups[kwargs.pop('affinity', None)]

        if peer_pid is not None:
            kwargs['cgroup_inherited'] = self._gather_proc_cgroups(peer_pid)
        else:
            kwargs['cgroup_inherited'] = {}

        return Proc(self, peer_pid, peer_uid, self.nobody_uid, *args, **kwargs).uuid

    create._passPyroConnection = True

    def run_debug_shell(self, pyro_connection, port):
        """
        Start a reverse Python shell connecting to ``localhost:port``.

        :raises EnvironmentError: (EPERM) unless the caller's uid is 0.
        """
        peer_uid = None

        if isinstance(pyro_connection, Pyro4.socketutil.SocketConnection):
            peer_uid = pyro_connection.peer_uid

        if peer_uid != 0:
            raise EnvironmentError(errno.EPERM, 'Don`t have root privileges')

        from .reverse import ReversePythonShell
        ReversePythonShell('localhost', port).start()

    run_debug_shell._passPyroConnection = True

    def find_by_pid(self, pid):
        """Return the uuid of the executing process with *pid*."""
        for proc in self.enumerate_executing():
            if proc.pid == pid:
                return proc.uuid

        raise errors.ProcessLookupException(127, 'can not find pid %s' % str(pid))

    def find_by_uuid(self, uuid):
        """Return *uuid* if it is a known process, else raise ProcessLookupException."""
        if uuid in self.procs:
            return uuid

        raise errors.ProcessLookupException(127, 'can not find uuid %s' % str(uuid))

    def find_by_tags_impl(self, tags, excludeTags=None):  # noqa
        """
        Yield running processes that have at least one of *tags* and none of
        *excludeTags*. Each process is yielded at most once.
        """
        tags = frozenset(tags)
        exclude_tags = frozenset(excludeTags or ())

        for proc in self.enumerate_running():
            proc_tags = frozenset(proc.tags)
            if proc_tags & tags and not proc_tags & exclude_tags:
                yield proc

    def find_by_tags_entry_impl(self, tags):
        """Yield running processes that have all of *tags*."""
        tags = frozenset(tags)
        for proc in self.enumerate_running():
            if frozenset(proc.tags).issuperset(tags):
                yield proc

    def find_by_tags(self, tags, excludeTags=None):  # noqa
        """Return uuids of running processes matching any tag (see ``find_by_tags_impl``)."""
        return self.adapted(self.find_by_tags_impl(tags, excludeTags))

    def find_by_tags_entry(self, tags):
        """Return uuids of running processes having all *tags*."""
        return self.adapted(self.find_by_tags_entry_impl(tags))

    def list_tags(self):
        """Return the distinct tags of all running processes."""
        tags = set()
        for proc in self.enumerate_running():
            tags.update(proc.tags)
        return list(tags)

    def enumerate_executing(self):
        """Yield processes whose ``executing()`` is true."""
        # Iterate over a snapshot: the generator may be suspended while
        # self.procs changes.
        for proc in list(self.procs.values()):
            if proc.executing():
                yield proc

    def enumerate_running(self):
        """Yield processes whose ``running()`` is true."""
        for proc in list(self.procs.values()):
            if proc.running():
                yield proc

    # old style rpc meths
    findByPid = find_by_pid
    findByUUID = find_by_uuid
    findByTags = find_by_tags
    findByTagsImpl = find_by_tags_impl
    findByTagsEntry = find_by_tags_entry
    findByTagsEntryImpl = find_by_tags_entry_impl
    listTags = list_tags
    enumerateExecuting = enumerate_executing
    enumerateRunning = enumerate_running

    # callbacks
    def enumerate(self):
        """Return the uuids of all known processes as a list."""
        return list(self.procs.keys())

    def stats(self, uid=None):
        """Return ``stat()`` of every process, or only of the one whose uuid equals *uid*."""
        return [proc.stat() for (uuid, proc) in list(self.procs.items()) if uid is None or uuid == uid]

    def is_valid(self, uuid):
        """Return whether *uuid* names a known process."""
        try:
            self.find_by_uuid(uuid)

            return True
        except errors.ProcessLookupException:
            return False

    def __getattr__(self, name):
        """
        Build an RPC proxy ``func(pyro_connection, uuid, *args, **kwargs)``
        that calls ``Proc.<name>(*args, **kwargs)`` on the process *uuid*.

        Only names in ``_ALLOWED_PROC_METHODS`` are proxied. Methods outside
        ``_UNRESTRICTED_PROC_METHODS`` on processes with a ``user`` may only be
        called by root or by the process owner.

        :raises AttributeError: if *name* is not an allowed method.
        """
        if name not in self._ALLOWED_PROC_METHODS:
            raise AttributeError('Method "%s" is not allowed' % (name, ))

        restricted = name not in self._UNRESTRICTED_PROC_METHODS

        def func(pyro_connection, uuid, *args, **kwargs):
            peer_uid = None
            if isinstance(pyro_connection, Pyro4.socketutil.SocketConnection):
                peer_uid = pyro_connection.peer_uid

            proc = self.procs[self.find_by_uuid(uuid)]

            if (
                restricted and                    # only for owner-restricted methods
                proc.user is not None and         # only for processes with strict user privileges
                peer_uid != 0 and                 # only if we are not root
                peer_uid != proc.uid              # and we are not the process owner
            ):
                # %s, not %d: peer_uid is None for non-socket connections and
                # %d would turn this permission error into a TypeError.
                raise Exception(
                    'You are not allowed to use method "%s" on '
                    'process "%s" (your uid: %s, process uid: %s)' % (
                        name, proc.uuid, peer_uid, proc.uid
                    )
                )

            return getattr(proc, name)(*args, **kwargs)

        func._passPyroConnection = True
        return func

    @staticmethod
    def adapted(lst):
        """Return the distinct uuids of the processes in *lst*."""
        return list(frozenset([x.uuid for x in lst]))

    def shutdown(self):
        """Stop the porto watcher, then disable restarts and kill non-liner running processes."""
        if self.portowatcher is not None:
            self.portowatcher.kill()
        for proc in [proc for proc in self.enumerate_running() if proc.liner is None]:
            proc.keepRunning(False)

            try:
                proc.kill()
            except Exception:
                pass


def withtrace(func):
    """
    Wrap *func* so that any ``Exception`` it raises is first passed to
    ``saveTraceback`` and then re-raised unchanged.
    """
    @wraps(func)
    def traced(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            saveTraceback(exc)
            raise
        return result

    return traced


class Handler(object):
    """
    Object registered with the Pyro daemon.

    Attributes it does not define itself are looked up on its :class:`Procs`
    instance and wrapped with :func:`withtrace`.
    """

    def __init__(self, rundir):
        self.procs = Procs(rundir)

    def ping(self):
        """Liveness check; always returns ``True``."""
        return True

    def closefds(self):
        """Redirect stdin, stdout and stderr to the null device."""
        fd = os.open(getattr(os, 'devnull', '/dev/null'), os.O_RDWR)
        try:
            for target in (0, 1, 2):
                os.dup2(fd, target)
        finally:
            # If fd itself landed on 0/1/2 (a std stream was closed), closing
            # it would undo the redirect we just made.
            if fd > 2:
                os.close(fd)

    def __getattr__(self, name):
        # Pyro checks the '_pyroId' attribute, so it must not be delegated.
        # 'procs' only reaches here before __init__ has set it; delegating it
        # would recurse forever.
        if name in ('_pyroId', 'procs'):
            raise AttributeError(name)
        return withtrace(getattr(self.procs, name))


def watchdog(orig_ino, path, procs, timeout=30, kill=os.kill):
    """
    Watch the unix socket at *path* for changes and kill ourself if we are no
    longer reachable from outside.

    Every *timeout* seconds the inode of *path* is compared with *orig_ino*.
    If the file is missing (stat fails, or *orig_ino* is -1) or its inode
    differs, every executing process in *procs* is stopped (errors ignored).
    Then SIGKILL is sent to the current process through *kill*.
    """
    log = logging.getLogger('watchdog')

    def die(reason):
        log.error(reason)
        log.error('killing children and myself')
        try:
            for child in procs.enumerate_executing():
                child.stop()
        except BaseException:
            pass
        kill(os.getpid(), signal.SIGKILL)

    if orig_ino == -1:
        die('file %s doesnt exist, we are isolated' % (path, ))

    while True:
        try:
            current_ino = path.stat().ino
        except EnvironmentError as err:
            log.error('stat on %s failed: %s', path, err)
            current_ino = -1
        if current_ino != orig_ino:
            die('file %s changed inode %s => %s' % (path, orig_ino, current_ino))
        sleep(timeout)


def start_watchdog(path, handler):
    """
    Spawn a :func:`watchdog` greenlet for the RPC socket at *path*.

    The socket's current inode (-1 if it can not be stat'ed) is recorded now
    and handed to the watchdog as the reference value.
    """
    try:
        inode = path.stat().ino
    except EnvironmentError:
        inode = -1

    spawn(watchdog, inode, path, handler.procs)


def parse_args():
    """Parse procman command line options from ``sys.argv``."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('-l', '--lock_port', type=int, dest='lockPort', action='store', required=True, help='TCP port to lock on')
    add('--appdir')
    add('--workdir', required=True)
    add('--logdir', required=True)
    add('--runcov', action='store_true')

    options = parser.parse_args()
    return options


def setup_cov(args):
    """
    Start coverage collection if ``--runcov`` was given.

    :return: the started coverage object, or ``None`` when coverage was not
        requested or the ``coverage`` package can not be imported.
    """
    if not args.runcov:
        return None

    try:
        from coverage.control import coverage
    except ImportError:
        return None

    cov = coverage(data_suffix=True)
    cov.start()
    return cov


def loop_mailbox(log, daemon):
    """
    Serve Pyro requests in a background greenlet and run mailbox jobs in the
    current one until told to stop.

    When the request loop ends for any reason, a :class:`Stop`-raising job is
    posted so the mailbox loop exits too.

    :return: 1 if a mailbox job raised an unexpected ``Exception``, else 0.
    """
    def serve_requests():
        try:
            daemon.requestLoop()
        except Exception as exc:
            log.error(str(formatException(exc)))
        finally:
            mailbox().put(_stop_raiser)

    # ==== start the show ====
    spawn(serve_requests)

    exitcode = 0
    try:
        for job in mailbox().iterate():
            job()
    except Stop:
        log.info('Exiting on stop command')
    except KeyboardInterrupt:
        log.info('Exiting on user request')
    except Exception as err:
        log.error('Exiting on error %s', err)
        exitcode = 1

    return exitcode


def motd(log, rundir, logdir, rpcpath, cfg):
    """Log the startup banner: uid, work/log directories, RPC path and logging config."""
    log.info('Starting...')

    stored_root = utils.has_root()
    uid_fmt = '  uid: %d (stored root)' if stored_root else '  uid: %d'
    log.info(uid_fmt, os.getuid())

    for fmt, value in (
        ('  workdir: %s', rundir),
        ('  logdir: %s', logdir),
        ('  rpc: %s', rpcpath),
        ('  logging: %s', cfg),
    ):
        log.info(fmt, value)


class SocketConnection(Pyro4.socketutil.SocketConnection):
    """
    Pyro socket connection that also records who is on the other side.

    Besides the base connection state it stores the peer's uid and pid (when
    ``getpeerid`` can obtain them), the peer's process start time and whether
    the peer is in another container (Linux only). ``peer_changed()`` uses
    the start time to detect a reused pid.
    """
    __slots__ = ['peer_uid', 'peer_pid', 'peer_starttime', 'other_container']

    def __init__(self, sock, objectId=None):
        super(SocketConnection, self).__init__(sock, objectId)

        peer_uid = None
        peer_pid = None
        peer_starttime = None
        other_container = False

        if isinstance(sock, socket.SocketType):
            try:
                try:
                    peer_uid, _, peer_pid = getpeerid(sock, getpid=True)
                except NotImplementedError:
                    # Platform can not report the pid; uid only.
                    peer_uid = getpeerid(sock)[0]

            except (EnvironmentError, AttributeError):
                pass

        if peer_pid is not None and os.uname()[0].lower() == 'linux':
            other_container = not same_container(peer_pid, files=False, procs=True)
            peer_starttime = self._read_starttime(peer_pid)

        self.peer_uid = peer_uid
        self.peer_pid = peer_pid
        self.peer_starttime = peer_starttime
        self.other_container = other_container

    @staticmethod
    def _read_starttime(pid):
        """
        Return field 22 ("starttime") of ``/proc/<pid>/stat``.

        The ``comm`` field (2) is in parentheses and may itself contain spaces
        or parentheses, so parsing starts after its *last* closing paren.
        This does not depend on how many fields the kernel appends at the end.

        :raises EnvironmentError: if the stat file can not be read.
        """
        with open('/proc/{}/stat'.format(pid)) as f:
            stat_data = f.read()
        # Fields after comm start at field 3, so field 22 is at index 19.
        return int(stat_data.rpartition(')')[2].split()[19])

    def peer_changed(self):
        """
        Return ``True`` if the peer process is gone or its pid now belongs to
        a process with a different start time. Returns ``False`` when no
        start time was recorded.
        """
        if self.peer_pid is None or self.peer_starttime is None:
            return False

        try:
            return self._read_starttime(self.peer_pid) != self.peer_starttime
        except EnvironmentError:
            return True


def _main():
    """
    Run the procman daemon until it is asked to stop.

    Parses the command line, optionally starts coverage, and delegates the
    actual work to :func:`_serve`. Coverage data is saved however
    :func:`_serve` finishes.

    :return: process exit code.
    """
    args = parse_args()
    cov = setup_cov(args)
    try:
        return _serve(args)
    finally:
        # Also reached on early error returns, so coverage is never lost.
        if cov:
            cov.stop()
            cov.save()


def _serve(args):
    """
    Prepare directories, logging and Pyro, then serve RPC while holding the
    TCP port lock.

    :return: 0 on a clean stop, 1 on error, 127 if the daemon ended without
        reporting a result.
    """
    rundir = py.path.local(args.workdir)
    logdir = py.path.local(args.logdir)
    rpcpath = rundir.join('skynet.procman.sock')
    cdumpdir = rundir.join('coredumps')

    # legacy: remove a socket left in the parent of the work directory
    old_sock = rundir.dirpath().join('skynet.procman.sock')
    if old_sock.check(exists=1):
        old_sock.remove()

    utils.fixperms(rundir, logdir, cdumpdir)

    from library import config
    log_cfg = config.query('skynet.services.procman', 'service').get('logging', {})
    max_bytes = log_cfg.get('max_bytes', 10 * 1024 * 1024)
    backup_count = log_cfg.get('backup_count', 5)

    log = utils.setup_logger(logdir, max_bytes, backup_count)

    # Argument order must match motd(log, rundir, logdir, rpcpath, cfg).
    motd(log, rundir, logdir, rpcpath, log_cfg)

    random.seed()
    setProcTitle(PROGNAME)

    # ==== set up Pyro ====
    Pyro4.util.Serializer.pickle = pickle
    Pyro4.util.Serializer.pickle.Unpickler = utils.SafeUnpickler
    Pyro4.config.SERVERTYPE = 'gevent'
    Pyro4.threadutil.Lock = RLock
    Pyro4.threadutil.Event = Event
    Pyro4.socketutil.SocketConnection = SocketConnection
    Pyro4.socketserver.geventserver.SocketConnection = SocketConnection

    rpcpath.dirpath().ensure(dir=1)

    exitcode = 127  # unpredicted error
    try:
        with TCPLock(args.lockPort):
            # Remove a stale socket from a previous run; it may not exist.
            try:
                with utils.auto_user_privileges():
                    rpcpath.remove()
            except Exception:
                pass

            try:
                daemon = Pyro4.Daemon(unixsocket=rpcpath.strpath)
            except socket.error as err:
                log.error('Failed to create daemon on %s: %s', rpcpath, err)
                return 1

            rpcpath.chmod(0o777)  # set 0777 on socket path, so other people will able to connect us

            handler = Handler(rundir)
            daemon.register(handler, PROGNAME)

            start_watchdog(rpcpath, handler)

            log.info('Daemon is up and running')

            exitcode = loop_mailbox(log, daemon)
            daemon.shutdown()

    except Exception:
        import traceback
        log.error('Exiting on error:')
        trace = traceback.format_exc()
        for line in trace.split('\n'):
            if line:
                log.error(line)
        sys.stderr.write(trace)
        exitcode = 1

    return exitcode


def main():
    """
    Console entry point: run :func:`_main` as the procman user and exit with
    its return code.
    """
    # The user can be overridden via SKYNET_PROCMANUSER; defaults to 'skynet'.
    procman_user = os.environ.get('SKYNET_PROCMANUSER', 'skynet')
    with utils.user_privileges(procman_user, limit=False):
        exitcode = _main()
        raise SystemExit(exitcode)


# Allow running the module directly as a script.
if __name__ == '__main__':
    main()
