from __future__ import absolute_import, print_function, division

import os
import sys
import time
import errno
import socket
import bisect
import random
import inspect
import functools
import traceback as tb

import msgpack

import gevent
import gevent.queue
import gevent.socket

from .. import log

from ..joint import utils as joint_utils
from ..joint import errors as joint_errors
from ..joint import socket as joint_socket


class Connection(object):
    """Server-side end of one RPC client connection.

    Wraps a ``joint_socket`` socket together with its session id and serializes
    the RPC protocol replies with msgpack: ``REGISTERED``, ``ERROR``,
    ``COMPLETE``, ``FAILED``, ``STATE`` and ``PONG``.
    """

    class Handler(joint_utils.ConnectionHandler):
        """Connection handler that passes handshaken sockets to the RPC instance."""

        def __init__(self, ctx, rpc, age):
            self.rpc = rpc
            self.log = ctx.log.getChild('rpc.conn.handler')
            super(Connection.Handler, self).__init__(ctx.cfg.rpc, age)

        def handle_session(self, sock, sid=None):
            """Finish the handshake on ``sock`` and register it with the RPC.

            Returns whatever the base-class ``handle_session`` returned. If that
            is None, or the peer disconnects during the handshake, the socket is
            closed and None is returned without registering the connection.
            """
            try:
                ret = super(Connection.Handler, self).handle_session(sock, sid)
                if ret is None:
                    self.log.error('Peer %r: handshake timed out (failed to write)', sock.peer)
                    sock.close()
                    return
            except joint_socket.EOF:
                self.log.error('Peer %r disconnected during handshake (we were attempting to send)', sock.peer)
                sock.close()
                return

            self.rpc.add_connection(sock, sid)
            return ret

    def __init__(self, ctx, sock, sid):
        """
        :param ctx: application context (``ctx.log`` is used).
        :param sock: connected ``joint_socket`` socket; ``sock.peer`` is ``''``
            for unix sockets, otherwise an address tuple.
        :param sid: session id assigned by the server.
        """
        self.id = sid
        self.ctx = ctx

        self.log = log.MessageAdapter(
            ctx.log.getChild('rpc.conn'),
            fmt='{%(sid)s} %(message)s',
            data={'sid': joint_utils.sid2str(sid)},
        )

        self.__sock = sock
        if sock.peer != '':
            self.log.info('Connection accepted on %s:%d', sock.peer[0], sock.peer[1])
        else:
            self.log.info('Connection accepted on unix socket')

    @staticmethod
    def _pack_registered(jid):
        """Serialize a ``REGISTERED`` reply for job ``jid``."""
        return msgpack.dumps(('REGISTERED', jid))

    @staticmethod
    def _pack_error(msg, *args):
        """Serialize an ``ERROR`` reply.

        ``msg`` is %-formatted with ``args`` only when arguments are given (the
        same convention as :mod:`logging`). A preformatted message that happens
        to contain ``%`` (e.g. exception text) is therefore sent verbatim instead
        of raising a formatting error.
        """
        text = msg % args if args else msg
        return msgpack.dumps(("ERROR", text), use_bin_type=True)

    @staticmethod
    def _pack_complete(jid, result):
        """Serialize a ``COMPLETE`` reply carrying the handler ``result``."""
        return msgpack.dumps(("COMPLETE", jid, result), use_bin_type=True)

    @staticmethod
    def _pack_failed(jid, ei):
        """Serialize a ``FAILED`` reply from a ``sys.exc_info()`` triple ``ei``.

        The payload carries the exception class module and name, its ``args``
        and the formatted traceback. If msgpack cannot serialize the payload
        (TypeError/ValueError), ``args`` is replaced by a pair: the string form
        of the first argument (or of the exception itself when it has no
        arguments) and a marker describing the serialization error.
        """
        exc_type, exc_value = ei[0], ei[1]
        try:
            return msgpack.dumps((
                "FAILED", jid,
                exc_type.__module__, exc_type.__name__,
                exc_value.args, tb.format_exception(*ei)
            ), use_bin_type=True)
        except (TypeError, ValueError) as ex:
            # Guard against an empty ``args`` so that reporting the failure
            # cannot itself fail with IndexError.
            first = exc_value.args[0] if exc_value.args else exc_value
            return msgpack.dumps((
                "FAILED", jid,
                exc_type.__module__, exc_type.__name__,
                (str(first), "<UNKNOWN ARGUMENTS LIST: {!s}>".format(ex)), tb.format_exception(*ei)
            ), use_bin_type=True)

    @staticmethod
    def _pack_state(jid, state):
        """Serialize a ``STATE`` (intermediate progress) reply."""
        return msgpack.dumps(("STATE", jid, state), use_bin_type=True)

    @staticmethod
    def _pack_pong(jid):
        """Serialize a ``PONG`` reply echoing the optional ping ``jid``."""
        return msgpack.dumps(('PONG', jid))

    @property
    def peerid(self):
        """``peerid`` of the underlying socket."""
        return self.__sock.peerid

    def peer(self):
        """Peer address of the underlying socket (``''`` for unix sockets)."""
        return self.__sock.peer

    def read(self, buffsize):
        """Return the socket's msgpack message iterator, reading ``buffsize`` bytes at a time."""
        return self.__sock.read_msgpack(buffsize)

    def register(self, jid):
        """Acknowledge job ``jid`` to the client."""
        self.__sock.write(self._pack_registered(jid))

    def error(self, msg, *args):
        """Send an ``ERROR`` message; ``msg`` is %-formatted only when ``args`` are given."""
        self.__sock.write(self._pack_error(msg, *args))

    def fail(self, jid, ei):
        """Report job ``jid`` as failed with exception info ``ei`` (``sys.exc_info()``)."""
        self.__sock.write(self._pack_failed(jid, ei))

    def state(self, jid, state):
        """Send an intermediate ``state`` for job ``jid``."""
        self.__sock.write(self._pack_state(jid, state))

    def finish(self, jid, result):
        """Report job ``jid`` as complete with ``result``."""
        self.__sock.write(self._pack_complete(jid, result))

    def pong(self, jid=None):
        """Answer a ``PING``."""
        self.__sock.write(self._pack_pong(jid))

    def fatal(self, msg, *args):
        """Best-effort: send an ``ERROR`` message, then always close the connection."""
        try:
            self.error(msg, *args)
        except (socket.error, joint_socket.EOF):
            # Ignore possible socket exceptions here
            pass
        finally:
            self.close()

    def close(self):
        """Close the underlying socket."""
        self.__sock.close()


class RPCJob(object):
    """One remote call (job) executed in a dedicated greenlet.

    Created by the connection worker for a ``CALL`` message. When started, the
    handler runs with ``job=<this RPCJob>`` added to its keyword arguments. The
    result is sent to the client through the owning :class:`Connection`:
    ``COMPLETE`` on normal return, ``FAILED`` when the handler raises.
    """

    def __init__(self, ctx, jid, conn, codename, handler, args, kwargs):
        """Prepare the job; nothing runs until :meth:`start`.

        :param ctx: application context (``cfg.rpc``, ``stalled_jobs`` and ``log`` are read).
        :param jid: job id taken from the client's ``CALL`` message.
        :param conn: owning :class:`Connection`, used for all replies.
        :param codename: handler name; only used in log messages.
        :param handler: callable executed by :meth:`start`.
        :param args: positional arguments for ``handler``.
        :param kwargs: keyword arguments for ``handler``; ``job`` is inserted
            into this dict when the job runs.
        """
        self.id = jid
        self.cfg = ctx.cfg.rpc
        self.__stalled_jobs = ctx.stalled_jobs  # stalled-jobs registry, or None when disabled

        self.__conn = conn
        self.__handler = handler
        self.__args = args
        self.__kwargs = kwargs

        self.__worker_grn = None
        self.__start_time = None
        self.__finish_time = None
        # NOTE(review): this attribute is never updated anywhere in this class.
        self.__state = 0  # 0:stopped, 1:running, 2:complete, 3:failed
        # Values delivered by FEEDBACK messages, consumed via the ``feedback`` property.
        self.__feedback_queue = gevent.queue.Queue()

        self.log = log.MessageAdapter(
            ctx.log.getChild('rpc.job'),
            fmt='{%(uid)s:%(codename)s} %(message)s',
            data={'uid': self.jid, 'codename': codename}
        )

    @property
    def started(self):
        """``time.time()`` at :meth:`start`, or None if not started."""
        return self.__start_time

    @property
    def stopped(self):
        """``time.time()`` when :meth:`finish` ran, or None."""
        return self.__finish_time

    @property
    def peer(self):
        """Peer address of the owning connection."""
        return self.__conn.peer()

    @property
    def connection(self):
        """Owning :class:`Connection`."""
        return self.__conn

    @property
    def sid(self):
        """Session id of the owning connection, as a string."""
        return joint_utils.sid2str(self.__conn.id)

    @property
    def jid(self):
        """Globally unique job key ``"<sid>:<id>"``."""
        return "{}:{}".format(self.sid, self.id)

    def __call(self, func, args, kwargs):
        """Greenlet body: run ``func`` and report the outcome to the client.

        A normal return is reported via :meth:`finish`. GreenletExit (the job was
        killed) is only logged. Any other exception is sent as ``FAILED``, and is
        logged at info level for ``SilentException`` or with a traceback otherwise.
        """
        try:
            assert 'job' not in kwargs
            kwargs['job'] = self
            self.log.info('Started')
            result = func(*args, **kwargs)
        except gevent.GreenletExit:
            self.log.warning('Killed (caught GreenletExit)')
        except BaseException as ex:
            self.__conn.fail(self.id, sys.exc_info())
            if not isinstance(ex, joint_errors.SilentException):
                self.log.exception('Unhandled exception during handler run')
            else:
                ext = type(ex)
                self.log.info('Method returned exception %s.%s(%s)', ext.__module__, ext.__name__, ex)
        else:
            self.finish(result)

    def register(self):
        """Send the ``REGISTERED`` acknowledgement for this job."""
        self.__conn.register(self.id)

    def start(self):
        """Record the start time and run the handler in a new greenlet."""
        self.log.debug('Start')
        self.__start_time = time.time()
        self.__worker_grn = gevent.Greenlet(self.__call, self.__handler, self.__args, self.__kwargs)
        self.__worker_grn.start()

    def stop(self, error=gevent.GreenletExit, reason=None):
        """Kill the worker greenlet with ``error`` (a no-op before :meth:`start`).

        When the stalled-jobs registry is enabled, the greenlet is recorded in it
        (keyed by :attr:`jid`) before being killed. ``reason`` is only logged.
        """
        self.log.debug('Stop (reason: %r)', reason)

        if self.__worker_grn is not None:
            try:
                if self.__stalled_jobs is not None:
                    self.__stalled_jobs[self.jid] = self.__worker_grn
                self.__worker_grn.kill(error)
            except error:
                # NOTE(review): presumably covers stop() being called from inside
                # the job's own greenlet (e.g. via state()), where the kill is
                # delivered to the caller itself -- confirm.
                pass

    def state(self, state):
        """Send an intermediate ``state`` to the client.

        If the socket is already closed (EBADF), the job is stopped instead.
        Other socket errors propagate.
        """
        try:
            self.__conn.state(self.id, state)
        except socket.error as ex:
            if ex.errno in (errno.EBADF, ):
                self.stop()
                return
            raise

    @property
    def feedback(self):
        """Next value sent by the client with ``FEEDBACK``; blocks until one is queued."""
        return self.__feedback_queue.get()

    @feedback.setter
    def feedback(self, value):
        """Queue a value received from the client for the running handler."""
        self.__feedback_queue.put(value)

    def finish(self, result):
        """Record the finish time, log the duration and send ``COMPLETE`` with ``result``.

        A closed socket (EBADF) is ignored; other socket errors propagate.
        """
        self.__finish_time = time.time()
        self.log.info('Finished (total time %0.4fs)', self.__finish_time - self.__start_time)

        try:
            self.__conn.finish(self.id, result)
        except socket.error as ex:
            if ex.errno in (errno.EBADF, ):
                return
            raise

    def destroy(self, reason=None):
        """Stop the job if its greenlet is still running, then release its references.

        The object must not be used afterwards: ``log``, the worker greenlet, the
        connection and the handler attributes are deleted, so e.g. :attr:`jid`
        no longer works.
        """
        self.log.debug('Destroy (reason: %r)', reason)

        if self.__worker_grn is not None and not self.__worker_grn.ready():
            self.stop(reason=reason)

        del self.log
        del self.__worker_grn
        del self.__conn
        del self.__handler


class StalledJobs(object):
    """Registry of job greenlets that were asked to stop, ordered by age.

    Every entry lives in two places: a ``jid -> greenlet`` mapping, and a list of
    ``(timestamp, jid)`` pairs kept sorted so that iteration can end at the
    first entry that is not yet older than ``timeout`` seconds.
    """

    def __init__(self, timeout):
        self.__greenlets = {}
        self.__timeout = timeout
        self.__timestamps = []  # sorted list of (registration time, jid)
        self.__jid_timestamps = {}  # jid -> registration time

    def __setitem__(self, jid, greenlet):
        """Register ``greenlet`` under ``jid``; an already registered jid is left untouched."""
        if jid not in self.__greenlets:
            registered_at = time.time()
            self.__greenlets[jid] = greenlet
            self.__jid_timestamps[jid] = registered_at
            bisect.insort_left(self.__timestamps, (registered_at, jid))

    def __delitem__(self, jid):
        """Forget ``jid``; unknown jids are silently ignored."""
        if jid in self.__greenlets:
            self.__greenlets.pop(jid)
            entry = (self.__jid_timestamps.pop(jid), jid)
            position = bisect.bisect_left(self.__timestamps, entry)
            self.__timestamps.pop(position)

    def __iter__(self):
        """Yield ``(jid, greenlet, timestamp)`` for entries older than the timeout, oldest first."""
        cutoff = time.time() - self.__timeout
        for registered_at, jid in self.__timestamps:
            if registered_at < cutoff:
                yield jid, self.__greenlets[jid], registered_at
            else:
                break


class RPC(object):
    """Registry and dispatcher of remotely callable handlers.

    Handlers are registered at import time with the decorator classmethods
    (``simple``, ``generator``, ``dupgenerator``, ``full``) and stored per
    defining module. An instance serves the handlers registered from the module
    its class is defined in. Every accepted connection is served by an
    :class:`RPC.Worker` running in its own greenlet.
    """

    # Registered methods, which will be accessible remotely:
    # {module name: {handler name: wrapped function}}
    __handlers = {}

    class Worker(object):
        """Message loop of one client connection; owns that connection's jobs."""

        def __init__(self, conn, rpc):
            self.rpc = rpc
            self.conn = conn
            self.log = conn.log
            self.cfg = conn.ctx.cfg.rpc
            self.active_jobs = {}  # active jobs by uid for this connection
            # Armed by a DROP that leaves no active jobs; cancelled on CALL/FEEDBACK.
            self.idle_timer = gevent.Timeout(self.cfg.idle_timeout)

            self.log.info('New connection (worker spawned)')

        def loop(self):
            """Read and process messages until EOF, error, idle timeout or kill.

            On exit every job still registered for this connection is destroyed,
            and the connection is dropped from the RPC and closed.
            """
            # Keep the caught exception in our own variable: Python 3 unbinds the
            # ``except ... as name`` target when the handler ends, so it cannot be
            # inspected from the ``finally`` block below.
            error = None
            ex_desc = None
            try:
                while True:
                    iteration = None
                    for iteration, message in enumerate(self.conn.read(self.cfg.receive_buffer)):
                        self.process(message)
                    if not self.active_jobs:
                        break
                    if iteration is None:
                        self.log.warning("No data received from socket - connection closed by remote side.")
                        break

            except gevent.GreenletExit as ex:
                error = ex
                self.log.info('Connection greenlet killed: %s', ex)
                self.conn.fatal('RPC worker killed: %s', str(ex))

            except (Exception, gevent.Timeout) as ex:
                error = ex
                if ex is self.idle_timer:
                    self.log.info('Connection was idle for %d seconds, closing', self.cfg.idle_timeout)
                    ex_desc = 'Idle for %d seconds, closing connection' % (self.cfg.idle_timeout, )
                elif isinstance(ex, socket.error) and str(ex) == 'timed out':
                    self.log.info('Connection timed out (%d seconds)', self.cfg.socket_timeout)
                    ex_desc = 'Connection timed out'
                else:
                    self.log.exception('Unhandled exception on main loop')
                    # Pass the text as an argument: it may contain '%'.
                    self.conn.error('%s', str(ex))

            finally:
                self.idle_timer.cancel()

                if ex_desc is not None:
                    reason = ex_desc
                elif error is not None:
                    reason = 'Error: %s' % (str(error), )
                else:
                    reason = 'EOF'

                for job in list(self.active_jobs.values()):
                    job.destroy(reason=reason)

                self.stop()
                if error is None:
                    self.log.debug('No more data to process, connection closed')

        def process(self, message):
            """Handle one decoded client message.

            Supported types are ``PING``, ``CALL``, ``FEEDBACK`` and ``DROP``; any
            other type is answered with an ``ERROR``. Exceptions raised while
            handling a message are logged and reported to the client rather than
            ending the loop.
            """
            try:
                message_type = message[0]
                self.log.debug('Processing %r message', message_type)
                if message_type == 'PING':
                    pong_sleep = self.cfg.pingpong_sleep_seconds
                    if pong_sleep > 0:
                        gevent.sleep(pong_sleep)
                    self.conn.pong(jid=None if len(message) < 2 else message[1])

                elif message_type == 'CALL':
                    self.idle_timer.cancel()
                    jid = message[1]
                    if self.rpc.stopping:
                        raise joint_errors.Reconnect

                    name, args, kwargs = message[2]
                    handler = self.rpc.handler(name)
                    if not handler:
                        self.conn.error('Handler for %r not registered', name)
                        return

                    if jid in self.active_jobs:
                        self.conn.error('Job #%d already registered.', jid)
                        return

                    try:
                        job = RPCJob(self.conn.ctx, jid, self.conn, name, handler, args, kwargs)
                    except Exception as ex:
                        self.log.exception('Unhandled exception during creating job object.')
                        self.conn.error('Failed to create job: %s', str(ex))
                        return

                    self.active_jobs[job.id] = job

                    try:
                        job.register()
                        job.start()
                    except Exception as ex:
                        self.log.exception('Unhandled exception during starting job.')
                        self.conn.error('Failed to add job: %s', str(ex))
                        self.active_jobs.pop(job.id)
                        job.destroy(reason='Error: %s' % (str(ex), ))
                        return

                elif message_type == 'FEEDBACK':
                    self.idle_timer.cancel()
                    jid = message[1]
                    value = message[2]
                    job = self.active_jobs.get(jid)
                    if job is None:
                        self.conn.error('Job #%d not registered.', jid)
                        return
                    job.feedback = value

                elif message_type == 'DROP':
                    jid = message[1]
                    job = self.active_jobs.pop(jid, None)
                    if job:
                        job.destroy(reason='DROP requested')

                    if len(self.active_jobs) == 0:
                        self.idle_timer.start()
                else:
                    self.log.error('Got invalid message: %r', message)
                    self.conn.error('Not supported message %r', message)

            except joint_errors.Reconnect:
                self.log.warning('Server currently is in stopping state. Asking client to reconnect.')
                self.conn.fail(jid, sys.exc_info())
            except Exception as ex:
                self.log.exception('Failed to handle message %r', message)
                self.conn.error('Failed to handle message: %s', str(ex))

        def stop(self, grn=None):
            """Drop this worker's connection from the RPC and close it; safe to call twice.

            ``grn`` is accepted so the method can be used as a greenlet link callback.
            """
            if not self.conn:
                return
            self.log.debug('Connection drop (worker stopped)')
            self.rpc.drop_connection(self.conn.id)
            self.conn.close()
            self.conn = None

    def __init__(self, ctx):
        self.ctx = ctx
        self.cfg = ctx.cfg.rpc
        self.log = ctx.log.getChild('rpc')
        self.log.debug('Initializing')

        self.__worker_grn = None
        self.__workers = {}  # sid -> (greenlet, Worker)
        self.__stalled_jobs_timeout = getattr(self.ctx, 'stalled_jobs_timeout', None)
        self.__stalled_jobs_watchdog_grn = None
        self.ctx.stalled_jobs = None
        cls = self.__class__
        self.stopping = False
        self.__handlers = cls.__handlers[cls.__module__]
        assert self.__handlers

    @classmethod
    def _register_handler(cls, func, name=None):
        """Store ``func`` under ``name`` (default: its ``__name__``) for its module; return it."""
        if func.__module__ not in cls.__handlers:
            cls.__handlers[func.__module__] = {}
        cls.__handlers[func.__module__][name if name else func.__name__] = func
        return func

    # Decorators {{{
    @classmethod
    def simple(cls, name=None, _=None):
        """Register a handler that does not receive the job object.

        Usable bare (``@RPC.simple``) or with a name (``@RPC.simple('name')``).
        """
        func, name = (None, name) if not name or isinstance(name, str) else (name, None)

        def _decorator(func):
            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                kwargs.pop('job')
                return func(*args, **kwargs)
            return cls._register_handler(_wrapper, name)
        return _decorator if not func else _decorator(func)

    @classmethod
    def __job_arg(cls, func):
        """Locate the ``job`` parameter of ``func``.

        :returns: ``(idx, job_in_kwargs)``. ``idx`` is the position of ``job`` among
            the named parameters, or None if there is none. ``job_in_kwargs`` is
            True when ``job`` has a default value.
        """
        # inspect.getargspec was removed in Python 3.11; getfullargspec exposes
        # the same ``args``/``defaults`` fields.
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        arg_spec = getargspec(func)
        try:
            idx = arg_spec.args.index('job')
        except ValueError:
            idx = None
        job_in_kwargs = idx is not None and idx >= len(arg_spec.args) - (
            len(arg_spec.defaults) if arg_spec.defaults else 0
        )
        return idx, job_in_kwargs

    @staticmethod
    def __bind_job(idx, job_in_kwargs, args, kwargs):
        """Route the ``job`` keyword argument to where the handler expects it.

        ``idx``/``job_in_kwargs`` come from ``__job_arg``. ``kwargs`` is modified in
        place.

        :returns: ``(job, args)`` with ``args`` possibly extended.
        """
        if idx is None:
            # Handler has no ``job`` parameter: do not pass it on.
            return kwargs.pop('job'), args
        if job_in_kwargs:
            # ``job`` has a default value: keep passing it by keyword.
            return kwargs.get('job'), args
        # Positional ``job`` parameter: insert it at its position and remove the
        # keyword, otherwise Python reports "multiple values for argument 'job'".
        job = kwargs.pop('job')
        return job, args[:idx] + (job,) + args[idx:]

    @classmethod
    def generator(cls, name=None, original_func=None):
        """Register a generator handler whose yielded values are sent as ``STATE``.

        The generator's return value (the ``StopIteration`` argument) becomes the
        job result. ``original_func`` is the function inspected for the ``job``
        parameter when ``func`` is itself a wrapper.
        """
        func, name = (None, name) if not name or isinstance(name, str) else (name, None)

        def _decorator(func):
            idx, job_in_kwargs = cls.__job_arg(original_func or func)

            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                job, args = cls.__bind_job(idx, job_in_kwargs, args, kwargs)
                try:
                    it = func(*args, **kwargs)
                    while True:
                        job.state(next(it))
                except StopIteration as ex:
                    return None if not ex.args else ex.args[0]
            return cls._register_handler(_wrapper, name)
        return _decorator if not func else _decorator(func)

    @classmethod
    def dupgenerator(cls, name=None, original_func=None):
        """Register a duplex generator handler.

        Each yielded value is sent as ``STATE``; the wrapper then blocks on
        ``job.feedback`` and sends any non-None feedback into the generator.
        NOTE(review): the value the generator yields in response to ``send()``
        is not reported via ``job.state`` -- confirm this matches the client
        protocol.
        """
        func, name = (None, name) if not name or isinstance(name, str) else (name, None)

        def _decorator(func):
            idx, job_in_kwargs = cls.__job_arg(original_func or func)

            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                job, args = cls.__bind_job(idx, job_in_kwargs, args, kwargs)
                generator = func(*args, **kwargs)
                for value in generator:
                    job.state(value)
                    feedback = job.feedback
                    if feedback is not None:
                        try:
                            generator.send(feedback)
                        except StopIteration:
                            break
            return cls._register_handler(_wrapper, name)
        return _decorator if not func else _decorator(func)

    @classmethod
    def full(cls, name=None, original_func=None):
        """Register a handler that receives the job object.

        A positional ``job`` parameter is filled at its position. A ``job``
        parameter with a default, or a ``**kwargs`` catch-all, receives it by
        keyword.
        """
        func, name = (None, name) if not name or isinstance(name, str) else (name, None)

        def _decorator(func):
            idx, job_in_kwargs = cls.__job_arg(original_func or func)

            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                # Without a named ``job`` parameter (idx is None) leave the keyword
                # in place: slicing with None would duplicate every argument.
                if idx is not None and not job_in_kwargs:
                    args = args[:idx] + (kwargs.pop('job'),) + args[idx:]
                return func(*args, **kwargs)
            return cls._register_handler(_wrapper, name)
        return _decorator if not func else _decorator(func)
    # Decorators }}}

    def start(self):
        """Enable the stalled-jobs registry and its watchdog when a timeout is configured; return self."""
        if self.__stalled_jobs_timeout:
            self.ctx.stalled_jobs = StalledJobs(self.__stalled_jobs_timeout)
            self.__stalled_jobs_watchdog_grn = gevent.spawn(self.__stalled_jobs_watchdog)
        return self

    def stop(self, graceful=False):
        """Enter stopping state and kill connection workers; return self.

        With ``graceful=True`` workers that still have active jobs are left
        running (new CALLs are refused with Reconnect). The stalled-jobs
        watchdog is always stopped.
        """
        self.stopping = True
        # Iterate over a snapshot: killing a worker runs its link callback,
        # which removes it from ``__workers`` via drop_connection().
        for grn, worker in list(self.__workers.values()):
            if not graceful or not len(worker.active_jobs):
                grn.kill(gevent.GreenletExit('RPC stopping'))
        if self.__stalled_jobs_watchdog_grn is not None:
            self.__stalled_jobs_watchdog_grn.kill(gevent.GreenletExit)
            self.__stalled_jobs_watchdog_grn = None
        return self

    def join(self):
        """Wait for the stalled-jobs watchdog greenlet, if any."""
        if self.__stalled_jobs_watchdog_grn:
            self.__stalled_jobs_watchdog_grn.join()

    def get_connection_handler(self):
        """Return a :class:`Connection.Handler` feeding connections to this RPC."""
        return Connection.Handler(self.ctx, self, Server.AGE)

    def handler(self, name):
        """Return a callable running registered handler ``name`` bound to this RPC, or None.

        The wrapper removes the job from the stalled-jobs registry (when enabled)
        once the handler returns or raises.
        """
        h = self.__handlers.get(name, None)
        if not h:
            return

        @functools.wraps(h)
        def _wrapper(*args, **kwargs):
            try:
                return h(self, *args, **kwargs)
            finally:
                if self.ctx.stalled_jobs:
                    del self.ctx.stalled_jobs[kwargs['job'].jid]
        return _wrapper

    def add_connection(self, sock, sid):
        """Apply socket options and spawn a :class:`RPC.Worker` greenlet for ``sock``."""
        if sock.peer != '':  # nodelay not supported on unix sockets
            sock.nodelay = self.cfg.socket_nodelay

        sock.send_buffer = self.cfg.socket_send_buffer
        sock.receive_buffer = self.cfg.socket_receive_buffer
        sock.timeout = self.cfg.socket_timeout

        conn = Connection(self.ctx, sock, sid)
        worker = RPC.Worker(conn, self)

        grn = gevent.spawn(worker.loop)
        self.__workers[sid] = (grn, worker)
        grn.link(worker.stop)

        # Link any unhandled error, so they will try to report error.
        # Exception text goes in as arguments, since it may contain '%'.
        grn.link_exception(lambda grn: conn.fatal(
            '(connection worker fatal) %s: %s',
            grn.exception.__class__.__name__,
            str(grn.exception),
        ))

    def drop_connection(self, sid):
        """Forget worker ``sid``; finish stopping once the last worker is gone."""
        self.__workers.pop(sid, None)
        if self.stopping and not self.__workers:
            self.stop()

    @property
    def counters(self):
        """Number of active jobs per connection, keyed by string session id."""
        return {
            joint_utils.sid2str(sid): len(worker.active_jobs)
            for sid, (_, worker) in self.__workers.items()
        }

    def kill_active_jobs(self, reason=''):
        """Kill every connection worker (and thereby its jobs) with ``reason``."""
        # Snapshot: killed workers remove themselves from ``__workers``.
        for grn, _ in list(self.__workers.values()):
            grn.kill(gevent.GreenletExit(reason))

    def __stalled_jobs_watchdog(self):
        """Periodically hand jobs stuck longer than the timeout to :meth:`on_stalled_jobs`."""
        log = self.log.getChild('stalled_jobs')
        log.info('Started')

        try:
            while True:
                gevent.sleep(self.__stalled_jobs_timeout)
                stalled = list(self.ctx.stalled_jobs)
                if stalled:
                    for job_id, greenlet, timestamp in stalled:
                        del self.ctx.stalled_jobs[job_id]
                    self.on_stalled_jobs(stalled)
        finally:
            log.info('Stopped')

    def on_stalled_jobs(self, stalled):
        """Hook for subclasses; ``stalled`` is a list of ``(jid, greenlet, timestamp)``."""
        pass


class Server(object):
    """Listening server that accepts connections and dispatches them by magic.

    Each accepted socket is served by a greenlet that reads the client greeting
    (age and magic), allocates a random session id and calls the handler
    registered for that magic (see :meth:`register_connection_handler`).
    """
    AGE = 1  # Server age (or epoch or version) to be reported to client

    def __init__(self, ctx):
        self.ctx = ctx
        self.cfg = ctx.cfg.server

        self.log = ctx.log.getChild('server')
        self.log.debug('Initializing server aged %r', self.AGE)

        self.__sock = None
        self.__sessions = {}  # sid -> client age
        self.__worker_grn = None
        self.__connection_handlers = {
            # Probe sent by start() when the unix socket path already exists.
            'CHK ': lambda sock, sid: sock.close() or self.log.debug('Got CHK packet, closing connection')
        }

    def register_connection_handler(self, handler):
        """Dispatch connections greeting with ``handler.magic`` to ``handler.handle_session``."""
        assert isinstance(handler, Connection.Handler)
        self.__connection_handlers[handler.magic] = handler.handle_session

    def start(self):
        """Bind and listen on the configured address, then spawn the accept loop.

        For a unix socket path, the parent directory is created if needed. If the
        path already exists, a ``CHK `` probe is sent to it. If the probe connects,
        another instance is assumed to be running and SystemExit(2) is raised.
        If it fails (other than by timeout), the stale path is removed. A probe
        timeout is re-raised.

        :returns: self
        """
        assert self.__worker_grn is None
        log = self.log.getChild('start')

        unix = self.cfg.unix

        host = self.cfg.host
        port = int(self.cfg.port) if self.cfg.port is not None else None

        backlog = int(self.cfg.backlog)

        if unix:
            assert unix and port is None and host is None
        # Address families are not bit flags: ``AF_INET6 | AF_INET`` is not a
        # "dual stack" selector and only coincidentally equals AF_INET6 on some
        # platforms, so request AF_INET6 explicitly for TCP.
        family = socket.AF_UNIX if unix else socket.AF_INET6

        self.__sock = gevent.socket.socket(family, socket.SOCK_STREAM)
        self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        if not unix:
            self.__sock.bind((host, port))
        else:
            sock_dir = os.path.dirname(unix)
            # A bare file name has no directory part; makedirs('') would fail.
            if sock_dir and not os.path.exists(sock_dir):
                os.makedirs(sock_dir)

            if os.path.exists(unix):
                sock2 = gevent.socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

                # We can wait a lot if other process will not accept connection
                # (e.g. if it is stopped by SIGTSTP).
                sock2.settimeout(10)
                try:
                    sock2.connect(unix)
                    # Must be bytes: on Python 3 a text str raises TypeError, which
                    # would be swallowed below and make us unlink a live socket.
                    sock2.sendall(b'CHK ')
                except Exception as ex:
                    if isinstance(ex, socket.timeout):
                        raise
                else:
                    log.critical('We connected to socket \'%s\' -- probably another instance running', unix)
                    raise SystemExit(2)
                finally:
                    sock2.close()

                os.unlink(unix)

            self.__sock.bind(unix)
            os.chmod(unix, self.cfg.mode)

        self.__sock.listen(backlog)

        if not unix:
            log.info('Listening on: %s:%d (backlog: %d)', host, port, backlog)
        else:
            log.info('Listening on: unix:%s (backlog: %d)', unix, backlog)

        self.__worker_grn = gevent.spawn(self.__worker_loop)
        return self

    @property
    def port(self):
        """Second element of the listening socket's ``getsockname()`` (the port for TCP)."""
        return self.__sock.getsockname()[1]

    @property
    def running(self):
        """True while the accept loop greenlet is set."""
        return bool(self.__worker_grn)

    def on_fork(self):
        """Kill the accept loop and close the listening socket (without unlinking it)."""
        if self.__worker_grn is not None:
            self.__worker_grn.kill(gevent.GreenletExit)
            self.__sock.close()
        self.__worker_grn = None

    def stop(self):
        """Stop accepting, close the listening socket and remove the unix socket path; return self."""
        if self.__worker_grn is not None:
            self.__worker_grn.kill(gevent.GreenletExit)
            try:
                self.__sock.shutdown(socket.SHUT_RDWR)
            except (socket.error, OSError):
                # Best effort: a listening socket is usually not connected.
                pass
            self.__sock.close()
            unix = self.cfg.unix
            if unix:
                try:
                    os.unlink(unix)
                except OSError:
                    # Already removed (or never created); nothing to clean up.
                    pass
        self.__worker_grn = None
        return self

    def join(self):
        """Wait for the accept loop greenlet to finish."""
        if self.__worker_grn:
            self.__worker_grn.join()

    def age(self, sid):
        """Client age reported in the greeting of session ``sid``, or None."""
        return self.__sessions.get(sid)

    def __worker(self, conn, addr):
        """Serve one accepted connection: read the greeting, assign a session id and dispatch."""
        logger = self.log.getChild('worker')
        sock = joint_socket.Socket(conn, gevent_mode=True)
        try:
            age, magic = joint_utils.ConnectionHandler.handle_greetings(
                sock, self.ctx.cfg.rpc.handshake_receive_timeout
            )
        except gevent.Timeout:
            logger.error('Magic receive timeout for client %r', sock.addr)
            sock.close()
            return
        except joint_socket.EOF:
            logger.error('Got EOF while receiving magic from %r', sock.addr)
            sock.close()
            return

        if magic not in self.__connection_handlers:
            # %r rather than %x: magics are looked up among string keys such as 'CHK '.
            logger.error('No handlers for magic %r for client %r', magic, sock.addr)
            sock.close()
            return

        sid = None
        for _ in range(self.ctx.cfg.rpc.uid_generation_tries):
            sid = random.getrandbits(joint_utils.SID_BITS)
            if sid not in self.__sessions:
                break

        if sid is None or sid in self.__sessions:
            logger.error('Unable to generate unique session ID for %r', sock.addr)
            sock.close()
            return
        self.__sessions[sid] = age

        try:
            self.__connection_handlers[magic](sock, sid)
        except Exception:
            logger.exception(
                'Got unhandled error while running connection handler %s, client %r',
                self.__connection_handlers[magic], sock.addr
            )
            sock.close()

        if addr != '':
            logger.debug('New connection aged %r from %r', age, addr)
        else:
            logger.debug('New connection aged %r (unix socket)', age)

    def __worker_loop(self):
        """Accept connections forever, serving each in its own greenlet, until killed."""
        log = self.log.getChild('worker')
        log.info('Started')

        while 1:
            try:
                gevent.spawn(self.__worker, *self.__sock.accept())
            except gevent.GreenletExit:
                log.info('Received stop signal')
                break
            except Exception:
                log.exception('Unhandled exception')
                gevent.sleep(0.1)  # avoid busy loops
