from __future__ import absolute_import

import os
import sys
import time
import fcntl
import Queue
import socket
import select
import threading
import subprocess
import collections

import bz2
import zlib

import py
import gevent
import gevent.event
import gevent.socket
import msgpack

from kernel.util.net import socketpair

from ...utils.stats import TimeFrameCounter


class JobBroker(object):
    """
    Delivers job completion results from worker threads to the main (gevent) thread.

    Worker threads call `notify()`, which enqueues a `(callback, data)` pair and writes a
    wake-up byte into one end of a socket pair. A watcher greenlet waits (cooperatively)
    on the other end and, running in the main thread, invokes every queued callback with
    its data.

    Since the server works with green threads, plain threading events can't be used here:
    the greenlet has to wait on a descriptor the gevent hub can poll.

    `start()`/`stop()` are reference counted: only the first `start()` creates the socket
    pair and the watcher, and only the matching last `stop()` tears them down.
    """

    # Options for the socket pair used to wake up the watcher greenlet.
    SOCKET_OPTS = dict(family=socket.AF_UNIX, type=socket.SOCK_STREAM)

    def __init__(self):
        super(JobBroker, self).__init__()

        self.log = None
        # Socket pair: notifiers write into `slave`, the watcher reads from `master`.
        self._sockPair = None
        # Watcher green thread
        self._watcher = None
        # Queue of `(callback, data)` pairs waiting to be delivered.
        self._queue = None
        # Number of `start()` calls not yet matched by `stop()`.
        self._users = 0

    def start(self, log):
        """
        Registers a broker user. The first call creates the wake-up socket pair and spawns
        the watcher greenlet.

        :param log:     Parent logger; the broker logs to its 'broker' child (first call only).
        """
        self._users += 1
        if self._users > 1:
            return
        self.log = log.getChild('broker')

        self._queue = Queue.Queue(maxsize=10)
        self._sockPair = collections.namedtuple('SocketPair', ['master', 'slave'])(
            *socketpair.socketPair(socketModule=socket, **self.SOCKET_OPTS)
        )

        self._watcher = gevent.spawn(self._watcherLoop)
        # Give the watcher a chance to perform its initial run.
        gevent.sleep()

    def stop(self):
        """
        Releases one `start()` registration. The last one stops the watcher (after it has
        delivered everything still queued) and closes the socket pair.
        Calls without a matching `start()` are ignored.
        """
        if self._users <= 0:
            # Unbalanced call: nothing is running. Letting the counter go negative would
            # break the pairing of later start()/stop() calls.
            return
        self._users -= 1
        if self._users:
            return
        self.log.debug('Asking service green thread to stop.')
        # Closing the sockets was found not to wake the watcher up, so send it the
        # special '\0' stop byte instead.
        self._sockPair.slave.send('\0')
        self._watcher.join()
        for sock in self._sockPair:
            sock.close()

        # AKA `self._queue.clear()`
        self._queue = Queue.Queue()

    def notify(self, ctx, data):
        """
        Job completion notifier, called from the worker threads. The watcher greenlet will
        call `ctx(data)` in the main thread.

        Blocks while the queue already holds its maximum of undelivered results.

        :param ctx:     Callback created by the task initiator.
        :param data:    Data to pass to the callback.
        :return:        None
        """
        # Pass the callback and the data to the watcher's input queue...
        self._queue.put((ctx, data, ))
        # ...and wake the watcher up.
        self._sockPair.slave.send('\1')

    def _deliverQueued(self):
        """Invokes every queued callback. A failing callback is logged and does not stop the others."""
        while True:
            try:
                ctx, data = self._queue.get(False)
            except Queue.Empty:
                return
            try:
                ctx(data)
            except Exception:
                # A raising callback must not kill the watcher: every later result would
                # then never be delivered and its initiator would wait forever.
                self.log.exception('Job completion callback %r raised an exception.', ctx)

    def _watcherLoop(self):
        """Watcher greenlet body: waits for wake-up bytes and delivers queued results."""
        self.log.info('Service green thread started.')
        # `fromfd()` duplicates the descriptor, so this socket has to be closed on its own.
        sock = gevent.socket.fromfd(self._sockPair.master.fileno(), *map(self.SOCKET_OPTS.get, ['family', 'type']))
        try:
            running = True
            while running:
                # Wait for a write on the other side of the pair.
                data = sock.recv(0xFFFF)
                # '\1' means "results are queued"; '\0' (or EOF) means "stop".
                if not data or ('\0' in data):
                    running = False
                # Deliver everything queued so far - including on the way out.
                self._deliverQueued()
        finally:
            sock.close()

        self.log.info('Service green thread stopped.')


class Stats(object):
    """
    Per-worker statistics over a sliding window of the last `reportPeriod` seconds.

    Samples are accumulated in a ring buffer with one `Entry` per second. Second `ts` goes
    into slot `int(ts) % reportPeriod`, tagged with `int(ts) // reportPeriod`, so that
    slots left over from older rounds can be recognized and ignored.
    """

    class Entry(object):
        """Accumulated samples of a single one-second slot."""
        __slots__ = ['requests', 'psize', 'upsize', 'tsdiv']

        def __init__(self, requests=0, psize=0., upsize=0., tsdiv=0):
            super(Stats.Entry, self).__init__()
            for a, v in zip(self.__slots__, [requests, psize, upsize, tsdiv]):
                setattr(self, a, v)

        def __repr__(self):
            return 'SE(r: %r, p: %r, u: %r, t: %r)' % tuple([getattr(self, a) for a in self.__slots__])

        def add(self, requests=0, psize=0, upsize=0):
            """Accumulates the counters (every slot except `tsdiv`)."""
            for a, v in zip(self.__slots__[0:-1], [requests, psize, upsize]):
                setattr(self, a, v + getattr(self, a))

    __slots__ = ['log', 'idleCounter', 'reportPeriod']

    def __init__(self, reportPeriod):
        self.reportPeriod = reportPeriod
        self.idleCounter = TimeFrameCounter(reportPeriod)
        self.log = [self.Entry() for _ in range(reportPeriod)]

    def update(self, start, plen=0., uplen=0., requests=0, now=None):
        """
        Records a sample and returns the moving statistics over the last `reportPeriod` seconds.

        :param start:       Timestamp the current idle interval started at (passed to `idleCounter`).
        :param plen:        Packed report size to account.
        :param uplen:       Unpacked report size to account.
        :param requests:    Number of requests to account.
        :param now:         Current timestamp; `time.time()` when omitted.
        :return:            `(1 - idle, requests per second, avg packed size, avg unpacked size)`.
                            The averages are 0 when no requests are in the window.
        """
        if now is None:
            now = time.time()
        self.idleCounter.start(start)
        idle = self.idleCounter.stop(now)

        period = self.reportPeriod
        second = int(now)
        tsdiv, index = divmod(second, period)

        # Update the slot of the current second; a slot from an older round starts afresh.
        st = self.log[index]
        if st.tsdiv != tsdiv:
            self.log[index] = self.Entry(requests, plen, uplen, tsdiv)
        else:
            st.add(requests, plen, uplen)

        # Only slots covering the last `period` seconds count. Any other slot is a stale
        # leftover of a round in which its second was never updated.
        window = [
            entry for slot, entry in enumerate(self.log)
            if entry.tsdiv * period + slot > second - period
        ]

        reqs = float(sum(x.requests for x in window))
        res = [
            1 - idle,           # 1 - idle ratio reported by `idleCounter`
            reqs / period,      # rps
            0, 0,               # avg packed/unpacked report size
        ]
        if reqs > 0:
            res[-2] = sum(x.psize for x in window) / reqs
            res[-1] = sum(x.upsize for x in window) / reqs

        return tuple(res)


class Supervisor(object):
    """
    The class is designed to handle specific bulldozer running session.
    It also collects statistics for HBS state report.
    """

    # Job broker object.
    _broker = JobBroker()
    # Sub-process start lock - it seems as this operation isn't safe.
    _execLock = threading.Lock()

    Entry = collections.namedtuple('Entry', ['expires', 'size', 'callback', 'source', 'name', 'report'])

    def __init__(self, name, cfg, log):
        """
        :param name:    Plugin name; it is also appended to the plugin aliases.
        :param cfg:     Plugin configuration (`executable`, optional `aliases`, ...).
        :param log:     Logger to use.
        :raises Exception: If the plugin executable file does not exist.
        """
        super(Supervisor, self).__init__()

        self.cfg, self.log, self.name = cfg, log, name

        # Sub-process restart counters:
        self.restarts = 0           # after some reports have been processed
        self.restarts_critical = 0  # after the greeting, before any report was processed
        self.restarts_initial = 0   # before the greeting completed

        self.aliases = getattr(cfg, 'aliases', []) + [name]
        self.bin = py.path.local(cfg.executable)

        self._threads = []          # `[thread, stats tuple]` pair per worker
        self.terminated = False

        # Reports waiting for a worker, plus the total size of their payloads.
        self._queue, self._queue_bytes = Queue.Queue(maxsize=1000), 0

        self._stat_requests = self._stat_discards = 0

        log.info('Initializing. Aliases: %r.', self.aliases)
        executable_found = self.bin.check(file=1)
        if not executable_found:
            raise Exception('Plugin %r executable file %r not found.' % (name, self.bin))

    def start(self):
        """
        Starts `cfg.instances` daemon worker threads, each running `_thread()`, and resets
        the restart counters. Also registers this supervisor as a job broker user.
        """
        # Make sure the job broker is running (actually needed only for tests).
        self._broker.start(self.log)

        for number in range(self.cfg.instances):
            worker = threading.Thread(target=self._thread, name='#%d' % number, args=(number, ))
            worker.daemon = True
            # Initial (busy, rate, psize, upsize) statistics of the worker, read by `status()`.
            # NOTE(review): the worker reads its slot as `self._threads[number]`, which assumes
            # `_threads` was empty before this call (it is never cleared by `stop()`).
            self._threads.append([worker, (1.0, 0, 0, 0, )])
            worker.start()

        self.restarts = 0
        self.restarts_critical = 0
        self.restarts_initial = 0

    def stop(self):
        """
        Stops the worker threads and releases the job broker.

        New reports are refused from now on. Reports still waiting in the queue are
        discarded: their callbacks receive a "Server terminating" message (called directly,
        as `process()` does for its own discards), so that callers blocked in `process()`
        are released instead of waiting forever. Every worker then gets a `None` exit request
        and is joined.
        """
        self.terminated = True

        if not self._threads:
            return

        while True:
            try:
                entry = self._queue.get_nowait()
            except Queue.Empty:
                break
            if entry is None:
                continue
            self._queue_bytes -= entry.size
            self.discard(
                entry.callback,
                'Server terminating, report of type %r from %r dropped' % (
                    entry.name, entry.source
                )
            )

        # One exit request per worker.
        for _ in self._threads:
            self._queue.put(None)

        for thread, _ in self._threads:
            thread.join()

        # Make sure job scheduler stopped (actually needed only for tests).
        self._broker.stop()

    def status(self):
        """
        Returns a snapshot of the supervisor state.

        Per-worker statistics are the `(busy, rate, psize, upsize)` tuples stored next to
        each thread in `self._threads`. Rate is summed over the workers; busy and the sizes
        are averaged.
        """
        workers = self._threads
        count = len(workers)

        def column(pos):
            # Sum of one statistics field over all workers.
            return sum(worker[1][pos] for worker in workers)

        if count:
            busy = column(0) / count
            rate = column(1)
            psize = column(2) / count
            upsize = column(3) / count
        else:
            busy, rate, psize, upsize = 1., 0., 0., 0.

        alive = len([worker for worker in workers if worker[0].isAlive()])

        return {
            # Average busy ratio of the workers over the report period.
            'busy': busy,
            # Requests rate per second for the report period.
            'rate': rate,
            # Average packed report size for the report period.
            'psize': psize,
            # Average unpacked report size for the report period.
            'upsize': upsize,
            # Amount of queued (still not processed) reports.
            'queue': self._queue.qsize(),
            # Amount of queued (still not processed) reports data.
            'queue_data': self._queue_bytes,
            # Amount of requests processed since the plugin start.
            'requests': self._stat_requests,
            # Amount of discarded packages since the plugin start.
            'discards': self._stat_discards,
            # Running plugin instances.
            'instances': alive,
            # Amount of restarts of underlying sub-process.
            'restarts': self.restarts,
            # Amount of restarts during initial greeting with sub-process.
            'restarts_initial': self.restarts_initial,
            # Amount of restarts before processing first report.
            'restarts_critical': self.restarts_critical,
        }

    def discard(self, cb, msg, scheduler=None, log=None):
        """
        Drops a report: counts it, logs the reason and hands the reason to the callback.

        :param cb:          Result callback of the dropped report.
        :param msg:         Reason message; logged and passed to the callback as the result.
        :param scheduler:   Optional job broker. If given, the callback is invoked through its
                            `notify()` instead of directly.
        :param log:         Logging function to use; `self.log.warn` by default.
        """
        self._stat_discards += 1
        emit = log if log else self.log.warn
        emit(msg)
        if not scheduler:
            cb(msg)
            return
        scheduler.notify(cb, msg)

    def process(self, src, name, report, expires=float('inf'), callback=None):
        """
        Queues a report for processing by one of the worker threads.

        With a callback, the method returns right after queueing. Without one, it blocks
        (cooperatively, on a gevent `AsyncResult`) until the result is available.
        The report is discarded, and the callback receives the reason message, in three cases:
        its payload is neither a dict nor a string, `stop()` was already called, or the
        queue is full.

        :param src:         Source host name.
        :param name:        Report type (name).
        :param report:      Report envelope; its `report` key holds the payload.
        :param expires:     Report expiration timestamp.
        :param callback:    Callable invoked with the task result. With `None`, the call blocks
                            until the result is ready.
        :return:            `None` when a callback is provided. Otherwise the result: `None`
                            on success or the error message on discard.
        """
        waiter = None
        if not callback:
            waiter = gevent.event.AsyncResult()
            callback = waiter.set

        if not isinstance(report['report'], (dict, basestring, )):
            self.discard(callback, 'Malformed report of type %r from %r: %r' % (name, src, report))
        elif self.terminated:
            self.discard(
                callback,
                'Server terminating, report of type %r from %r dropped' % (name, src)
            )
        else:
            entry = self.Entry(
                expires=expires,
                size=len(report['report']),
                callback=callback,
                source=src,
                name=name,
                report=report,
            )
            try:
                self._queue.put_nowait(entry)
            except Queue.Full:
                self.discard(
                    callback,
                    'Queue overflow, report of type %r from %r dropped' % (name, src),
                    log=self.log.debug
                )
            else:
                self._queue_bytes += entry.size

        # Queued reports get their result from the job broker, which sets the waiter
        # in the main thread.
        if waiter:
            return waiter.get()
        return None

    def _proc_create_poller(self, proc):
        """
        Switches the sub-process STDOUT/STDERR pipes to non-blocking mode and registers both
        in a new `select.poll()` object.

        :param proc:    `subprocess.Popen` object created with piped stdout/stderr.
        :return:        Poller context dict consumed by `_proc_poll()` / `_proc_stop_poller()`.
        """
        events = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR
        out_fd, err_fd = proc.stdout.fileno(), proc.stderr.fileno()
        pipes = (out_fd, err_fd)

        for fd in pipes:
            flags = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

        poller = select.poll()
        for fd in pipes:
            poller.register(fd, events)

        return dict(
            poller=poller,
            stdout_fd=out_fd,
            stderr_fd=err_fd,
            stdout_buf=[],           # pending incomplete STDOUT line chunks
            stderr_buf=[],           # pending incomplete STDERR line chunks
            stdout_first_read=True,  # next STDOUT read is the one-byte greeting
        )

    def _proc_stop_poller(self, ctx):
        """
        Unregisters the STDOUT/STDERR descriptors still listed in the poller context.

        The `stdout_fd` / `stderr_fd` entries are deliberately left in `ctx`: the caller
        uses their presence to decide which pipes still need to be drained and closed.
        Descriptors that are no longer registered are skipped instead of letting
        `poll.unregister` raise KeyError. The method is therefore safe to call more than once;
        the worker calls it both on the normal path and again on its error path.

        :param ctx:     Poller context created by `_proc_create_poller()`.
        """
        poller = ctx['poller']
        for key in ('stdout_fd', 'stderr_fd'):
            if key not in ctx:
                continue
            try:
                poller.unregister(ctx[key])
            except KeyError:
                # Already unregistered - nothing left to do for this descriptor.
                pass

    @staticmethod
    def _proc_take_lines(buf, chunk):
        """
        Adds a freshly read chunk to a line buffer and extracts every completed line.

        :param buf:     List of pending chunks not yet terminated by a newline. It is updated
                        in place so that it only holds the trailing incomplete line, if any.
        :param chunk:   Non-empty data just read from a pipe.
        :return:        List of all complete lines (newline stripped), or `None` if the
                        chunk did not complete any line.
        """
        buf.append(chunk)
        if '\n' not in chunk:
            return None
        complete, _, partial = ''.join(buf).rpartition('\n')
        buf[:] = [partial] if partial else []
        return complete.split('\n')

    def _proc_poll(self, ctx, log, timeout):
        """
        Waits for sub-process STDOUT output, forwarding complete STDERR lines to the log.

        Once a stream is finished (EOF, hang-up, error or invalid descriptor), its descriptor
        is unregistered and its `stdout_fd` / `stderr_fd` entry is removed from `ctx`.

        :param ctx:     Poller context created by `_proc_create_poller()`.
        :param log:     Logger receiving STDERR lines and stream state warnings.
        :param timeout: Maximum time to wait, in seconds.
        :return:        One of:
                        - the single greeting character, on the first STDOUT read;
                        - a list of complete STDOUT lines (newline stripped), on later reads;
                        - `None`, if the STDOUT stream is finished;
                        - `-1`, if the timeout expired before any of the above.
        """
        poller = ctx['poller']
        ret = noret = object()
        deadline = time.time() + timeout

        while ret is noret:
            remaining = deadline - time.time()
            if remaining <= 0:
                ret = -1
                break

            for fd, flag in poller.poll(remaining * 1000):
                # Look the descriptors up for every event: a stream may already have been
                # dropped from the context, possibly by an earlier event of this batch.
                stdout_fd, stderr_fd = ctx.get('stdout_fd'), ctx.get('stderr_fd')

                if flag & (select.POLLIN | select.POLLPRI):
                    if fd == stderr_fd:
                        data = os.read(fd, 1024)
                        if not data:
                            poller.unregister(fd)
                            ctx.pop('stderr_fd', None)
                        else:
                            for line in self._proc_take_lines(ctx['stderr_buf'], data) or []:
                                log.warn('Subprocess STDERR:%s', line)
                    elif fd == stdout_fd:
                        if ctx['stdout_first_read']:
                            # The greeting is exactly one byte.
                            data = os.read(fd, 1)
                            if not data:
                                poller.unregister(fd)
                                ctx.pop('stdout_fd', None)
                                ret = None
                            else:
                                ctx['stdout_first_read'] = False
                                ret = data
                        else:
                            data = os.read(fd, 1024)
                            if not data:
                                # EOF: drop the stream from the context, like the other
                                # end-of-stream paths do.
                                poller.unregister(fd)
                                ctx.pop('stdout_fd', None)
                                ret = None
                            else:
                                lines = self._proc_take_lines(ctx['stdout_buf'], data)
                                if lines is not None:
                                    ret = lines
                elif flag & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
                    # POLLNVAL is handled too: left registered, it would make poll() return
                    # immediately again and again until the deadline.
                    if fd == stdout_fd:
                        if flag & select.POLLHUP:
                            log.warn('Subprocess stdout received SIGHUP')
                        elif flag & select.POLLERR:
                            log.warn('Subprocess stdout received POLLERR')
                        else:
                            log.warn('Subprocess stdout received POLLNVAL')
                        poller.unregister(fd)
                        ctx.pop('stdout_fd', None)
                        ret = None
                    elif fd == stderr_fd:
                        if flag & select.POLLHUP:
                            log.warn('Subprocess stderr received SIGHUP')
                        elif flag & select.POLLERR:
                            log.warn('Subprocess stderr received POLLERR')
                        else:
                            log.warn('Subprocess stderr received POLLNVAL')
                        poller.unregister(fd)
                        ctx.pop('stderr_fd', None)

        return ret

    def _unpack_report(self, report, log):
        """
        Decompresses and decodes a string report payload in place.

        :param report:  Report envelope dict. `report['report']` holds the payload. The
                        optional `compression` key (None, 'bz2' or 'zip') and the optional
                        `format` key ('msgpack' when missing or empty, or 'raw') describe it.
        :param log:     Logger for unpacking failures.
        :return:        `(ok, datalen)` tuple:
                        - `ok` is False if the payload could not be decompressed or decoded;
                        - `datalen` is `(packed size, unpacked size)`, or `(packed size, 0)`
                          when decompression failed;
                        - `(True, None)` is returned, with nothing done, when the payload is
                          not a string.
        """
        payload = report['report']
        if not isinstance(payload, basestring):
            return True, None

        compression = report.get('compression', None)
        datalen = (len(payload), 0)
        try:
            decompress = {
                None: lambda x: x,
                'bz2': bz2.decompress,
                'zip': zlib.decompress,
            }[compression]
            ddata = decompress(payload)
        except Exception as ex:
            # Unknown compression (KeyError) or corrupted data.
            log.warning('Unable to unpack report: %s (comp: %s)', ex, compression)
            return False, datalen

        datalen = (datalen[0], len(ddata))
        # Exact match: the format name must not be tested as a substring of 'msgpack'.
        fmt = report.get('format') or 'msgpack'
        if fmt == 'msgpack':
            try:
                report['report'] = msgpack.unpackb(ddata)
            except Exception as ex:
                log.warning('Unable to decode msgpack report: %s', ex)
                return False, datalen
        elif fmt == 'raw':
            report['report'] = ddata
        # NOTE(review): any other format leaves `report['report']` untouched, i.e. still
        # compressed if it was - confirm whether that is intended.

        return True, datalen

    def _thread(self, index):
        """
        Worker thread body: keeps one plugin sub-process alive and feeds it queued reports.

        The sub-process is spawned and its one-byte greeting awaited on STDOUT. Then entries
        are taken from `self._queue`, each sent msgpack-encoded to the sub-process STDIN, and
        a single reply line is awaited. An empty line means success; any other line is a
        rejection reason. Results reach the entry callbacks through the job broker.
        When the sub-process dies, misbehaves or exceeds the processing limit, it is
        restarted, possibly after a pause. A `None` queue entry makes the thread exit.

        :param index:   Position of this worker's `[thread, stats]` pair in `self._threads`.
                        The stats tuple there is replaced on every statistics update.
        """
        log = self.log.getChild('thread.%s' % threading.currentThread().name)

        # Timeout for queue and pipe waits: idle statistics and the processing deadline
        # are checked at least this often.
        wakeup_period = min(self.cfg.queue_limits_check, self.cfg.processing_limit)
        # Per-second statistics over the last `report_period` seconds.
        stats = Stats(int(self.cfg.report_period))

        log.info('Thread started.')
        ths = self._threads[index]

        # Python plugins are run by the current interpreter in unbuffered (-u) mode.
        args = [sys.executable, '-u'] if self.bin.ext == '.py' else []
        args.append(self.bin.strpath)
        confArgs = getattr(self.cfg, 'arguments', None)

        if confArgs:
            # `arguments` may be a whitespace-separated string or a sequence.
            if isinstance(self.cfg.arguments, basestring):
                args += self.cfg.arguments.split()
            else:
                args += self.cfg.arguments

        entry = None            # report taken from the queue and not finished yet
        critical = 0            # incremented on restarts while `count` is 0, reset otherwise
        count = 0               # reports handled by this thread (never reset)
        fails_in_a_row = 0      # failed greetings since the last successful report

        while True:
            greeting_completed = False

            try:
                # Sub-process start is serialized over all workers by the class-wide lock.
                with self._execLock:
                    log.debug('Executing %r', args)
                    # The sub-process gets our environment plus `sys.path` and service settings.
                    env = os.environ.copy()
                    env['PYTHONPATH'] = ':'.join(sys.path)
                    env['DATABASE_URI'] = self.cfg.database_uri
                    if self.cfg.lacmus_uri:
                        env['LACMUS_URI'] = self.cfg.lacmus_uri
                        env['LACMUS_TIMEOUT'] = str(self.cfg.lacmus_timeout)
                        env['LACMUS_RETRIES'] = str(self.cfg.lacmus_retries)
                    env['APPLICATION_ROOT'] = self.cfg.application_root
                    proc = subprocess.Popen(
                        args,
                        env=env,
                        stdin=subprocess.PIPE,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                    )
            except BaseException as ex:
                # Spawn failed: retry every second until the supervisor is terminated.
                self.log.error('Unable to start process: %s', ex)
                if self.terminated:
                    return
                time.sleep(1)
                continue

            try:
                poller_ctx = self._proc_create_poller(proc)

                # Greeting: the first STDOUT read of `_proc_poll` yields a single character,
                # or None (stream finished) / -1 (timeout).
                ret = self._proc_poll(poller_ctx, log, timeout=self.cfg.greetings_wait)
                if ret and ret != -1:
                    # NOTE(review): the first read returns exactly one character, so this
                    # check presumably accepts any single byte and the `else` is not reached.
                    if len(ret) == 1 and len(ret[0]) == 1:
                        log.info(
                            'Sub-process started (PID is %d, stdout pipe #%d), going listen on the queue.',
                            proc.pid, proc.stdout.fileno()
                        )
                        greeting_completed = True
                    else:
                        log.info(
                            'Sub-process invalid greeting: %r', ret[0]
                        )
                        proc.terminate()
                elif ret == -1:
                    log.critical(
                        'Sub-process did not sent ready signal in %d seconds. Killing it.',
                        self.cfg.greetings_wait
                    )
                    proc.terminate()
                else:
                    log.critical(
                        'Sub-process invalid greeting: %r', ret
                    )
                    proc.terminate()

                if not greeting_completed:
                    fails_in_a_row += 1

                # Serve reports until the sub-process fails (each `break` leads to a restart).
                while greeting_completed:
                    # Wait for the next report, refreshing idle statistics every `wakeup_period`.
                    while not entry:
                        start = time.time()
                        try:
                            entry = self._queue.get(timeout=wakeup_period)
                            if not entry:
                                # `None` is the exit request queued by `stop()`.
                                log.debug('Exit requested. Terminating sub-process.')
                                proc.stdin.close()
                                return
                            else:
                                self._stat_requests += 1
                                self._queue_bytes -= entry.size
                        except Queue.Empty:
                            log.debug('Updating idle statistics.')
                            ths[1] = stats.update(start)

                    now = time.time()
                    deadline = now + self.cfg.processing_limit

                    log.debug('Got new task %r from %r.', entry.name, entry.source)

                    # An unpack failure also leaves the loop, which restarts the sub-process.
                    unpacked, datalen = self._unpack_report(entry.report, log)
                    if not unpacked:
                        log.warning('Failed to unpack task %s from %s', entry.name, entry.source)
                        break

                    data2send = msgpack.packb((entry.source, entry.name, entry.report, ))
                    # Non-string payloads were not unpacked: account the packed message size.
                    if not datalen:
                        datalen = (len(data2send), ) * 2
                    # Update statistics before passing data to the plugin
                    ths[1] = stats.update(start, *datalen, requests=1, now=now)

                    proc.stdin.write(data2send)
                    log.debug('Task sent for the processing.')

                    # Wait for the single reply line, checking the processing deadline
                    # every `wakeup_period`.
                    read = None
                    while True:
                        ret = self._proc_poll(poller_ctx, log, timeout=wakeup_period)
                        if ret and ret != -1:
                            assert len(ret) == 1
                            read = ret[0] + '\n'
                            break
                        elif ret == -1:
                            if time.time() > deadline:
                                log.warn(
                                    'Report was not processed in %d seconds. Terminating process.',
                                    self.cfg.processing_limit
                                )
                                proc.terminate()
                                break
                            continue
                        else:
                            log.warn(
                                'Process died while processing report'
                            )
                            proc.terminate()
                            break

                    # No reply (deadline exceeded or stream finished): restart the sub-process.
                    if not read or not read.endswith('\n'):
                        break

                    if len(read) > 1:  # A message from client means it has a problem processing a message
                        self.discard(
                            entry.callback,
                            'Discarding report from %r as invalid by the following reason: '
                            '%r. Report content is %r.' %
                            (entry.source, read.strip(), entry.report),
                            scheduler=self._broker,
                            log=log.warn
                        )
                    else:
                        fails_in_a_row = 0
                        # Notify scheduler green thread the task has been completed successfully.
                        log.debug('Task successfully finished. Notifying initiator.')
                        self._broker.notify(entry.callback, None)

                    entry = None
                    count += 1

                # The sub-process is done with: close its input and collect the exit code
                # if it has already exited.
                proc.stdin.close()
                proc.poll()

                # Drain and close only the pipes still listed in the poller context.
                self._proc_stop_poller(poller_ctx)
                if 'stdout_fd' in poller_ctx:
                    poller_ctx.pop('stdout_fd')
                    try:
                        proc.stdout.read()
                        proc.stdout.close()
                    except IOError:
                        pass

                if 'stderr_fd' in poller_ctx:
                    poller_ctx.pop('stderr_fd')
                    try:
                        proc.stderr.read()
                        proc.stderr.close()
                    except IOError:
                        pass

                log_entry = log.critical
                log_entry(
                    'Subprocess terminated with exit code #%s after %d requests processed.',
                    str(proc.returncode), count
                )
                # A report still in flight is treated as the cause of the crash: discard it.
                if entry:
                    log_entry(
                        'Report sent by %r of type %r, which leads to process crash: %r',
                        entry.source, entry.name, entry.report
                    )
                    self.discard(
                        entry.callback,
                        'Discarding report from %r because we crashed during processing it.' %
                        (entry.source, ),
                        scheduler=self._broker,
                        log=log.warn
                    )
                    entry = None

                # Restart accounting.
                if greeting_completed:
                    if count == 0:  # Sub-process crashed on first
                        self.restarts_critical += 1

                        critical += 1
                        # NOTE(review): `entry` was already discarded and reset to None just
                        # above, so this branch can never fire - looks like dead code.
                        if not critical % self.cfg.discard_restarts and entry:
                            self.discard(
                                entry.callback,
                                'Discarding report from %r because of it leads %d times to sub-process restart.' %
                                (entry.source, self.cfg.discard_restarts),
                                scheduler=self._broker,
                                log=log.warn
                            )
                            entry = None
                    else:
                        # We did processed some reports, so reset critical counter
                        # NOTE(review): `count` is never reset, so after the first handled
                        # report every later restart of this thread ends up here.
                        critical = 0
                        self.restarts += 1
                else:
                    # The greeting was not completed: the sub-process failed before it got ready.
                    self.restarts_initial += 1

                # Pause grows linearly with consecutive failed greetings (factor capped at 20).
                sleep = self.cfg.restart_pause * min(fails_in_a_row, 20)
                log.info('Sleeping %d seconds before restarting the sub-process..', sleep)

                # Sleep in one-second steps so that termination is noticed promptly.
                # NOTE(review): `range()` presumably needs `restart_pause` to be an integer.
                for i in range(sleep):
                    time.sleep(1)
                    if self.terminated:
                        return

            except Exception as ex:
                # Unexpected failure: kill the sub-process, release its pipes, discard the
                # in-flight report and restart after one second.
                # NOTE(review): if `_proc_create_poller` itself raised, `poller_ctx` here is
                # unbound (first iteration) or left over from the previous sub-process.
                import traceback
                log.critical('Unhandled exception in thread worker:')
                log.critical(traceback.format_exc())
                try:
                    proc.terminate()
                except:
                    pass

                self._proc_stop_poller(poller_ctx)
                proc.poll()

                if 'stdout_fd' in poller_ctx:
                    try:
                        proc.stdout.read()
                        proc.stdout.close()
                    except IOError:
                        pass

                if 'stderr_fd' in poller_ctx:
                    try:
                        proc.stderr.read()
                        proc.stderr.close()
                    except IOError:
                        pass

                if entry:
                    self.discard(
                        entry.callback,
                        'Discarding report from %r because we crashed during processing it with %s.' %
                        (entry.source, ex),
                        scheduler=self._broker,
                        log=log.warn
                    )
                    entry = None
                time.sleep(1)

        # NOTE(review): unreachable - the loop above is only left via `return`.
        log.debug('Terminated normally.')
