from __future__ import absolute_import, print_function, division

import gevent
import gevent.queue
import msgpack
import struct
import socket
import errno
import logging
import time
import random
import functools

from kernel.util.errors import formatException
from kernel.util.logging import MessageAdapter

from ..server import ConnectionHandler
from ..utils.socket import Socket, EOF


class RPCConnectionHandler(ConnectionHandler):
    """Handshake newly accepted sockets and hand them over to an ``RPC``.

    The handshake is symmetric: we send our flags as a network-order uint32
    (always 0) and read the peer's flags in the same format, which must also
    be 0.
    """

    def __init__(self, ctx, rpc):
        self.log = ctx.log.getChild('rpc.connhandler')

        self.__magic = ctx.cfg.rpc.connection_magic
        self.__rpc = rpc

        self.__flagsPacker = struct.Struct('!I')  # uint32, network byte order

        self.cfg = ctx.cfg.rpc
        super(RPCConnectionHandler, self).__init__()

    def getMagic(self):
        """Return the configured connection magic, packed via ``packMagic``."""
        return self.packMagic(self.__magic)

    def handleConnection(self, sock):
        """Run the handshake on ``sock``; register it with the RPC on success.

        On timeout, peer disconnect (``EOF``) or non-zero peer flags the
        socket is closed and the connection is dropped.
        """
        # Simple handshake.
        # Send flags (uint32) and receive client flags (uint32)

        try:
            # write()/read() return None when the timeout expires.
            if sock.write(self.__flagsPacker.pack(0), timeout=self.cfg.handshake_send_timeout) is None:
                self.log.warning('Peer %r: handshake timed out (failed to write)', sock.peer)
                sock.close()
                return
        except EOF:
            self.log.warning('Peer %r disconnected during handshake (we were attemping to send)', sock.peer)
            sock.close()
            return

        try:
            data = sock.read(self.__flagsPacker.size, timeout=self.cfg.handshake_receive_timeout)
            if data is None:
                self.log.warning('Peer %r: handshake timed out (failed to read)', sock.peer)
                sock.close()
                return
        except EOF:
            self.log.warning('Peer %r disconnected during handshake (we were attemping to read)', sock.peer)
            sock.close()
            return

        flags, = self.__flagsPacker.unpack(data)
        if flags != 0:
            # Peer-supplied data: validate explicitly (an assert is stripped
            # under -O) and close the socket instead of leaking it.
            self.log.warning('Peer %r: unsupported handshake flags %r', sock.peer, flags)
            sock.close()
            return

        self.__rpc.addConnection(sock)


class RPCConnection(object):
    """Protocol wrapper around a single peer socket.

    Every outgoing message is a msgpack-encoded tuple whose first element
    names the message type (REGISTERED, ERROR, COMPLETE, FAILED, STATE,
    PONG).
    """

    def __init__(self, ctx, sock):
        self.ctx = ctx
        if sock.peer != '':
            # Network peer; indexed as (host, port) for the log prefix.
            self.log = MessageAdapter(
                ctx.log.getChild('rpc.conn'),
                fmt='[%(host)s:%(port)d] %(message)s',
                data={
                    'host': sock.peer[0],
                    'port': sock.peer[1]
                }
            )
        else:
            # Unix socket connections have an empty peer address.
            self.log = MessageAdapter(
                ctx.log.getChild('rpc.conn'),
                fmt='[unix socket] %(message)s'
            )
        self.__sock = sock

    @staticmethod
    def _formatMessage(msg, args):
        """Return ``msg % args``, or ``msg`` verbatim when no args are given.

        Mirrors ``logging``: a message passed without arguments (e.g. raw
        exception text) may contain a literal '%' without breaking.
        """
        return msg % args if args else msg

    def _packRegistered(self, uid):
        return msgpack.dumps(('REGISTERED', uid))

    def _packError(self, msg, *args):
        return msgpack.dumps(('ERROR', self._formatMessage(msg, args)))

    def _packComplete(self, uid, result):
        return msgpack.dumps(('COMPLETE', uid, result))

    def _packFailed(self, uid, msg, *args):
        return msgpack.dumps(('FAILED', uid, self._formatMessage(msg, args)))

    def _packState(self, uid, state):
        return msgpack.dumps(('STATE', uid, state))

    def _packPong(self):
        return msgpack.dumps(('PONG', ))

    def peer(self):
        """Return the underlying socket's peer address ('' for unix sockets)."""
        return self.__sock.peer

    def read(self, buffsize):
        """Return the socket's msgpack reader (iterated for incoming messages)."""
        return self.__sock.readMsgpack(buffsize)

    def register(self, uid):
        """Send REGISTERED for job ``uid``."""
        self.__sock.write(self._packRegistered(uid))

    def error(self, msg, *args):
        """Send an ERROR message; ``msg`` is %-formatted only if args are given."""
        self.__sock.write(self._packError(msg, *args))

    def fail(self, uid, msg, *args):
        """Send FAILED for job ``uid``; ``msg`` is %-formatted only if args are given."""
        self.__sock.write(self._packFailed(uid, msg, *args))

    def state(self, uid, state):
        """Send an intermediate STATE value for job ``uid``."""
        self.__sock.write(self._packState(uid, state))

    def finish(self, uid, result):
        """Send COMPLETE with the final ``result`` of job ``uid``."""
        self.__sock.write(self._packComplete(uid, result))

    def pong(self):
        """Answer a PING."""
        self.__sock.write(self._packPong())

    def fatal(self, msg, *args):
        """Send an ERROR message, then close the connection even if sending fails."""
        try:
            self.error(msg, *args)
        finally:
            self.close()

    def close(self):
        """Close the underlying socket."""
        self.__sock.close()


class RPCJob(object):
    """One handler invocation requested by a peer, run in its own greenlet.

    The outcome is reported over the owning connection: COMPLETE with the
    handler's result, FAILED on an unhandled exception, and STATE messages
    for intermediate progress sent through ``state()``.
    """

    def __init__(self, ctx, uid, conn, codename, handler, args, kwargs):
        self.uid = uid
        self.cfg = ctx.cfg.rpc
        self.log = MessageAdapter(
            ctx.log.getChild('rpc.job'),
            fmt='[0x%(uid)s : %(codename)s] %(message)s',
            data={
                # Zero-padded lowercase hex; rstrip('L') drops the Python 2
                # long suffix from hex().
                'uid': hex(uid)[2:].rstrip('L').zfill(self.cfg.uid_bits // 4).lower(),
                'codename': codename
            }
        )

        self.__conn = conn
        self.__handler = handler
        self.__args = args
        self.__kwargs = kwargs
        self.__workerGrn = None
        self.__startTime = None
        self.__finishTime = None
        # NOTE(review): not updated anywhere in this class.
        self.__state = 0  # 0:stopped, 1:running, 2:complete, 3:failed

    @property
    def started(self):
        """``time.time()`` when ``start()`` was called, or None."""
        return self.__startTime

    @property
    def stopped(self):
        """``time.time()`` when the handler finished successfully, or None."""
        return self.__finishTime

    @property
    def peer(self):
        """Peer address of the owning connection."""
        return self.__conn.peer()

    def __call(self, func, args, kwargs):
        """Worker greenlet body: run the handler and report its outcome."""
        try:
            if 'job' in kwargs:
                # 'job' is injected below; check explicitly (an assert would
                # be stripped under -O).
                raise ValueError("'job' is a reserved keyword argument")
            # Copy rather than mutate the kwargs dict this job was created with.
            kwargs = dict(kwargs)
            kwargs['job'] = self
            self.log.info('Started')
            result = func(*args, **kwargs)
        except gevent.GreenletExit:
            self.log.debug('Killed (caught GreenletExit)')
        except BaseException as ex:
            # Log first so the traceback is recorded even if reporting to the
            # peer fails below.
            self.log.warning('Unhandled exception during handler run: %s', formatException())
            try:
                self.__conn.fail(self.uid, 'Got unhandled exception during handler run: %s', str(ex))
            except socket.error as sockErr:
                # Connection already closed: nobody to report to (same policy
                # as finish() and state()).
                if sockErr.errno not in (errno.EBADF, ):
                    raise
        else:
            self.finish(result)

    def register(self):
        """Tell the peer this job's uid (REGISTERED)."""
        self.__conn.register(self.uid)

    def start(self):
        """Record the start time and run the handler in a new greenlet."""
        self.log.debug('Start')
        self.__startTime = time.time()
        self.__workerGrn = gevent.Greenlet(self.__call, self.__handler, self.__args, self.__kwargs)
        self.__workerGrn.start()

    def stop(self, error=gevent.GreenletExit, reason=None):
        """Kill the worker greenlet (if started) with ``error``."""
        self.log.debug('Stop (reason: %r)', reason)

        if self.__workerGrn is not None:
            try:
                self.__workerGrn.kill(error)
            except error:
                # Presumably reached when stop() runs inside the worker
                # greenlet itself (e.g. via state()) -- TODO confirm.
                pass

    def state(self, state):
        """Send an intermediate STATE value; stop the job if the socket is closed."""
        try:
            self.__conn.state(self.uid, state)
        except socket.error as ex:
            if ex.errno in (errno.EBADF, ):
                self.stop()
                return
            raise

    def finish(self, result):
        """Record the finish time and send COMPLETE (ignored if the socket is closed)."""
        self.__finishTime = time.time()
        self.log.info('Finished (total time %0.4fs)', self.__finishTime - self.__startTime)

        try:
            self.__conn.finish(self.uid, result)
        except socket.error as ex:
            if ex.errno in (errno.EBADF, ):
                return
            raise

    def destroy(self, reason=None):
        """Stop a still-running worker and drop references; the job is unusable afterwards."""
        self.log.debug('Destroy (reason: %r)', reason)

        if self.__workerGrn is not None and not self.__workerGrn.ready():
            self.stop(reason=reason)

        del self.log
        del self.__workerGrn
        del self.__conn
        del self.__handler
        del self.__handler


class RPC(object):
    """Dispatch RPC calls arriving on peer connections to named handlers.

    Each connection is served by its own greenlet. Incoming messages are
    tuples whose first item is the message type:

    * ``('PING',)`` -- answered with PONG, optionally after
      ``cfg.pingpong_sleep_seconds``;
    * ``('CALL', uid, (name, args, kwargs))`` -- run handler ``name`` as an
      ``RPCJob``;
    * ``('DROP', uid)`` -- destroy a job started from the same connection.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        self.cfg = ctx.cfg.rpc
        self.log = ctx.log.getChild('rpc')
        self.log.debug('Initializing')

        self.__workerGrn = None  # NOTE(review): not used in this class
        self.__connectionWorkerGrns = []  # one greenlet per live connection
        self.__jobs = {}  # registered jobs of all connections, by uid
        self.__handlers = {}  # handler callables, by name

    # Decorators {{{
    @staticmethod
    def simple(func):
        """Wrap a handler that does not want the injected ``job`` kwarg."""
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            kwargs.pop('job')
            return func(*args, **kwargs)
        return _wrapper

    @staticmethod
    def yielder(func):
        """Wrap a generator handler.

        Every yielded value is sent to the peer as a STATE message; the last
        yielded value (None if nothing was yielded) is the job result.
        """
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            job = kwargs.pop('job')
            state = None
            # Forward the call arguments, as ``simple`` and ``full`` do.
            for state in func(*args, **kwargs):
                job.state(state)
            return state
        return _wrapper

    @staticmethod
    def full(func):
        """Wrap a handler that takes the ``RPCJob`` as its first positional arg."""
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            job = kwargs.pop('job')
            return func(job, *args, **kwargs)
        return _wrapper
    # Decorators }}}

    def start(self):
        return self

    def stop(self):
        """Kill every connection worker greenlet."""
        # Iterate over a snapshot: each greenlet's link callback removes it
        # from __connectionWorkerGrns while this loop is running.
        for grn in list(self.__connectionWorkerGrns):
            grn.kill(gevent.GreenletExit('RPC stopping'))

        return self

    def join(self):
        pass

    def getConnectionHandler(self):
        """Return a handshake handler that feeds sockets into this RPC."""
        return RPCConnectionHandler(self.ctx, self)

    def __generateJobUid(self):
        """Return a random ``cfg.uid_bits``-bit uid not used by any job.

        Raises:
            Exception: 'Bad entropy' when ``cfg.uid_generation_tries``
                attempts all collided (or no attempt was allowed).
        """
        tries = self.cfg.uid_generation_tries
        for attempt in range(tries):
            uid = random.getrandbits(self.cfg.uid_bits)
            if uid not in self.__jobs:
                return uid
            if attempt + 1 < tries:
                # Back off between attempts; pointless after the last one.
                gevent.sleep(self.cfg.uid_generation_retry_sleep)

        self.log.warning(
            'Failed to generate unique uid, tried %d times.',
            tries
        )
        raise Exception('Bad entropy')

    @staticmethod
    def __disposeTimer(timer):
        """Cancel ``timer`` and close it where the gevent version supports it."""
        timer.cancel()
        if hasattr(timer, 'close'):
            timer.close()

    def __startJob(self, conn, message, activeJobs):
        """Handle ``('CALL', uid, (name, args, kwargs))``: create, register and start a job.

        Problems are reported to the peer with an ERROR message; a started
        job is recorded both globally and in ``activeJobs``.
        """
        uid = message[1]
        name, args, kwargs = message[2]
        if name not in self.__handlers:
            conn.error('Handler for %r not registered', name)
            return

        if uid in self.__jobs:
            # We received bad UID, try to generate new one.
            try:
                uid = self.__generateJobUid()
            except Exception as ex:
                conn.error('Failed to generate unique id: %s', str(ex))
                return

        try:
            job = RPCJob(self.ctx, uid, conn, name, self.__handlers[name], args, kwargs)
        except Exception as ex:
            self.log.warning('Unhandled exception during creating job object: %s', formatException())
            conn.error('Failed to create job: %s', str(ex))
            return

        self.__jobs[job.uid] = activeJobs[job.uid] = job

        try:
            job.register()
            job.start()
        except Exception as ex:
            self.log.warning('Unhandled exception during starting job: %s', formatException())
            conn.error('Failed to add job: %s', str(ex))
            self.__jobs.pop(job.uid)
            activeJobs.pop(job.uid)
            job.destroy(reason='Error: %s' % (str(ex), ))

    def __dropJob(self, uid, activeJobs):
        """Destroy job ``uid`` if it belongs to this connection's ``activeJobs``.

        Jobs of other connections are left alone: dropping them here would
        also break the owning connection's cleanup.
        """
        job = activeJobs.pop(uid, None)
        if job is not None:
            self.__jobs.pop(uid, None)
            job.destroy(reason='DROP requested')

    def __connectionWorker(self, conn):
        """Serve one connection until EOF, error, idle timeout or kill.

        On exit every job started from this connection is destroyed and the
        connection is closed.
        """
        log = conn.log
        log.info('New connection (spawned)')

        activeJobs = {}  # active jobs by uid for this connection
        # Created unarmed; armed when a DROP leaves no active jobs and
        # replaced by a fresh unarmed timer on every CALL.
        idleTimer = gevent.Timeout(self.cfg.idle_timeout)

        # Tracked explicitly: Python 3 unbinds ``except ... as ex`` names at
        # the end of the handler, so ``ex`` cannot be read in ``finally``.
        failure = None  # exception that ended the loop; None on plain EOF
        reason = None   # reason given to job.destroy() during cleanup

        try:
            for message in conn.read(self.cfg.receive_buffer):
                if not isinstance(message, tuple):
                    raise TypeError('Invalid message received')

                try:
                    messageType = message[0]
                    if messageType == 'PING':
                        pongSleep = self.cfg.pingpong_sleep_seconds
                        if pongSleep > 0:
                            gevent.sleep(pongSleep)
                        conn.pong()

                    elif messageType == 'CALL':
                        self.__disposeTimer(idleTimer)
                        idleTimer = gevent.Timeout(self.cfg.idle_timeout)
                        self.__startJob(conn, message, activeJobs)

                    elif messageType == 'DROP':
                        self.__dropJob(message[1], activeJobs)
                        # An already armed Timeout cannot be started again.
                        if not activeJobs and not idleTimer.pending:
                            idleTimer.start()

                    else:
                        log.warning('Got invalid message: %r', message)
                        conn.error('Not supported message %r', message)

                except Exception as err:
                    log.warning('Failed to handle message %r: %s', message, formatException())
                    conn.error('Failed to handle message: %s', str(err))

        except gevent.GreenletExit as ex:
            failure = ex
            log.info('Connection greenlet killed: %s', ex)
            # Pass the text as an argument so a '%' in it cannot break formatting.
            conn.error('%s', str(ex))

        except (Exception, gevent.Timeout) as ex:
            failure = ex
            if ex is idleTimer:
                log.info('Connection was idle for %d seconds, closing', self.cfg.idle_timeout)
                reason = 'Idle for %d seconds, closing connection' % (self.cfg.idle_timeout, )
            elif isinstance(ex, socket.error) and str(ex) == 'timed out':
                log.info('Connection timed out (%d seconds)', self.cfg.socket_timeout)
                reason = 'Connection timed out'
            else:
                log.warning('Unhandled exception: %s', formatException())
                conn.error('%s', str(ex))

        finally:
            self.__disposeTimer(idleTimer)

            if reason is None:
                reason = 'EOF' if failure is None else 'Error: %s' % (str(failure), )

            for job in activeJobs.values():
                self.__jobs.pop(job.uid, None)
                job.destroy(reason=reason)

            conn.close()
            if failure is None:
                log.debug('No more data to process, connection closed')

    def addConnection(self, sock):
        """Apply socket options from config and spawn a worker serving ``sock``."""
        if sock.peer != '':  # nodelay not supported on unix sockets
            sock.nodelay = self.cfg.socket_nodelay

        sock.sendBuffer = self.cfg.socket_send_buffer
        sock.receiveBuffer = self.cfg.socket_receive_buffer
        sock.timeout = self.cfg.socket_timeout

        conn = RPCConnection(self.ctx, sock)
        conn.log.debug('New connection (initialized)')

        grn = gevent.spawn(self.__connectionWorker, conn)
        self.__connectionWorkerGrns.append(grn)
        # The link callback receives the finished greenlet and forgets it.
        grn.link(self.__connectionWorkerGrns.remove)

        # Link any unhandled error, so they will try to report error.
        # The exception text is passed as an argument (not pre-formatted
        # into the message) so a '%' in it cannot break %-formatting.
        grn.link_exception(lambda grn: conn.fatal(
            '(connection worker fatal) %s: %s',
            grn.exception.__class__.__name__,
            str(grn.exception)
        ))

    def registerHandler(self, name, handler):
        """Register ``handler`` for CALL messages naming ``name`` (replaces any previous one)."""
        self.__handlers[name] = handler
