from __future__ import absolute_import, print_function, division

import collections
import errno
import functools
import gevent
import gevent.queue
import gevent.socket
import os
import six
import random
import socket

from ya.skynet.util.errors import formatException
from ya.skynet.util.logging import MessageAdapter
from ya.skynet.util.sys.gettime import monoTime
from ya.skynet.util import pickle

# Exception types that RPCJob.fail() treats as expected errors: for these only
# the last traceback line is logged, anything else gets the full traceback.
# The api.cqueue package is optional; its CQueueError is included only when the
# import succeeds.
known_exceptions = ()
try:
    from api.cqueue import exceptions as ace
    known_exceptions = known_exceptions + (ace.CQueueError,)
except ImportError:
    pass

from .. import exceptions
known_exceptions = known_exceptions + (exceptions.CQueueError,)

from .socket import Socket, EOF
from .utils import ConnectionHandler, sid2str
from . import MAGIC


class Connection(object):
    """Server side of one accepted RPC connection.

    Wraps a :class:`Socket` and provides helpers that send protocol replies
    (``REGISTERED``, ``ERROR``, ``FAILED``, ``STATE``, ``COMPLETE``, ``PONG``)
    to the peer as msgpack tuples.
    """

    class Handler(ConnectionHandler):
        """Handshake handler that hands negotiated sockets over to an RPC object."""

        HANDSHAKE_RECEIVE_TIMEOUT = 10
        HANDSHAKE_SEND_TIMEOUT = 10

        def __init__(self, log, rpc):
            self.rpc = rpc
            self.log = log.getChild('handler')
            super(Connection.Handler, self).__init__(
                MAGIC,
                self.HANDSHAKE_SEND_TIMEOUT,
                self.HANDSHAKE_RECEIVE_TIMEOUT
            )

        def handle(self, sock, sid=None):
            """Run the base-class handshake, then register *sock* with the RPC.

            :param sock: accepted socket wrapper.
            :param sid: session id assigned by the server.
            :returns: the base handshake result, or ``None`` (after closing
                *sock*) when that result is ``None`` or the peer sent EOF.
            """
            try:
                ret = super(Connection.Handler, self).handle(sock, sid)
            except EOF:
                self.log.error('Peer %r disconnected during handshake (we were attempting to send)', sock.peer)
                sock.close()
                return None

            if ret is None:
                self.log.error('Peer %r: handshake timed out (failed to write)', sock.peer)
                sock.close()
                return None

            self.rpc.add_connection(sock, sid)
            return ret

    def __init__(self, log, sock, sid):
        self.id = sid

        self.log = MessageAdapter(
            log.getChild('conn'),
            fmt='[jb]{%(sid)s:0}  %(message)s',
            data={'sid': sid2str(sid, Server.UID_BITS)},
        )

        self.__sock = sock
        # An empty peer is treated as a unix-socket connection (no host/port).
        if sock.peer != b'':
            self.log.info('Connection accepted on %s:%d', sock.peer[0], sock.peer[1])
        else:
            self.log.info('Connection accepted on unix socket')

    def peer(self):
        """Return the peer address reported by the socket wrapper."""
        return self.__sock.peer

    def peer_id(self):
        """Return the peer id reported by the socket wrapper."""
        return self.__sock.peer_id

    def read(self, buffsize):
        """Return the socket's msgpack reader (iterated by the RPC worker loop)."""
        return self.__sock.read_msgpack(buffsize)

    def register(self, jid):
        """Tell the peer that job *jid* was accepted."""
        self.__sock.write_msgpack(('REGISTERED', jid))

    def error(self, msg, *args):
        """Send an ``ERROR`` reply.

        *msg* is %-formatted only when *args* are given. Without arguments the
        text is sent verbatim, so messages containing a literal ``%`` (for
        example ``str(exception)``) no longer raise while formatting.
        """
        self.__sock.write_msgpack(('ERROR', msg % args if args else msg))

    def fail(self, jid, ex):
        """Send a ``FAILED`` reply for job *jid* carrying the pickled exception."""
        self.__sock.write_msgpack(('FAILED', jid, pickle.dumps(ex)))

    def state(self, jid, state):
        """Send an intermediate ``STATE`` value for job *jid*."""
        self.__sock.write_msgpack(('STATE', jid, state))

    def finish(self, jid, result):
        """Send the final ``COMPLETE`` result for job *jid*."""
        self.__sock.write_msgpack(('COMPLETE', jid, result))

    def pong(self, jid=None):
        """Answer a ``PING`` (echoing its optional id)."""
        self.__sock.write_msgpack(('PONG', jid))

    def fatal(self, msg, *args):
        """Best-effort ``ERROR`` reply followed by closing the connection."""
        try:
            self.error(msg, *args)
        except (socket.error, EOF):
            # The peer may already be gone; closing below is what matters.
            pass
        finally:
            self.close()

    def close(self):
        """Close the underlying socket."""
        self.__sock.close()


class RPCJob(object):
    """A single remote call executed in its own greenlet.

    The job reports its lifecycle to the peer through the owning
    :class:`Connection` (``REGISTERED``, ``STATE``, ``COMPLETE``/``FAILED``)
    and can receive values sent by the peer through :attr:`feedback`.
    """

    def __init__(
        self, log, jid, conn, codename, handler,
        args, kwargs, on_finish=None, on_fail=None,
        silent=False
    ):
        self.id = jid

        self.__conn = conn
        self.__handler = handler
        self.__codename = codename
        self.__args = args
        self.__kwargs = kwargs

        self.__on_finish = on_finish  # called with no arguments on success
        self.__on_fail = on_fail  # called with the exception on failure

        self.__worker_grn = None
        self.__start_time = None
        self.__finish_time = None
        # NOTE(review): never updated anywhere in this class.
        self.__state = 0  # 0:stopped, 1:running, 2:complete, 3:failed
        self.__feedback_queue = gevent.queue.Queue()

        self.__silent = silent  # suppresses start/stop/finish info/debug logs

        self.log = MessageAdapter(
            log.logger.getChild('job'),
            fmt='[jb]{%(uid)s}  [%(codename)s]  %(message)s',
            data={
                'uid': '%s:%d' % (sid2str(conn.id, Server.UID_BITS), jid),
                'codename': codename
            }
        )

    @property
    def name(self):
        """Handler name the job was started for."""
        return self.__codename

    @property
    def uid(self):
        """``<connection id in hex>:<job id>``."""
        return '%x:%d' % (self.__conn.id, self.id)

    @property
    def started(self):
        """monoTime() at start, or ``None`` if not started."""
        return self.__start_time

    @property
    def stopped(self):
        """monoTime() at finish/fail, or ``None``."""
        return self.__finish_time

    @property
    def peer(self):
        return self.__conn.peer()

    @property
    def peer_id(self):
        return self.__conn.peer_id()

    def __set_slocal(self, reset=False):
        """Replace the current greenlet's ``slocal`` dict; fill it with job info unless *reset*."""
        gevent.getcurrent().slocal = {}  # reset from default

        if not reset:
            gevent.getcurrent().slocal['job_uid'] = '%s:%d' % (sid2str(self.__conn.id, Server.UID_BITS), self.id)
            gevent.getcurrent().slocal['job_name'] = self.__codename

    def __call(self, func, args, kwargs):
        """Greenlet body: run *func* and report the outcome via finish()/fail()."""
        try:
            assert 'job' not in kwargs
            # Handlers (through the RPC decorators) receive the job as kwarg.
            kwargs['job'] = self

            if not self.__silent:
                self.log.info('Started')

            self.__set_slocal()
            result = func(*args, **kwargs)
            self.__set_slocal(reset=True)
        except gevent.GreenletExit:
            # Killed via stop(): nothing is reported to the peer.
            if not self.__silent:
                self.log.debug('Killed (caught GreenletExit)')
        except BaseException as ex:
            self.__set_slocal(reset=True)
            self.fail(ex, formatException())
        else:
            self.finish(result)

    def register(self):
        """Acknowledge the job to the peer."""
        self.__conn.register(self.id)

    def start(self):
        """Record the start time and spawn the worker greenlet."""
        if not self.__silent:
            self.log.debug('Start')
        self.__start_time = monoTime()
        self.__worker_grn = gevent.Greenlet(self.__call, self.__handler, self.__args, self.__kwargs)
        self.__worker_grn.start()

    def stop(self, error=gevent.GreenletExit, reason=None):
        """Kill the worker greenlet (if any) with *error*.

        :param error: exception class or instance passed to ``kill()``; the
            same exception surfacing from ``kill()`` itself is swallowed.
        :param reason: only logged.
        """
        if not self.__silent:
            self.log.debug('Stop (reason: %r)', reason)

        if self.__worker_grn is None:
            return

        # ``except`` needs a class, while kill() also accepts an instance.
        error_type = error if isinstance(error, type) else type(error)
        try:
            self.__worker_grn.kill(error)
        except error_type:
            pass

    def state(self, state):
        """Send an intermediate state; stop the job if the socket is already closed (EBADF)."""
        try:
            self.__conn.state(self.id, state)
        except socket.error as err:
            if err.errno in (errno.EBADF, ):
                self.stop()
                return
            raise

    @property
    def feedback(self):
        """Block until the peer sends a FEEDBACK value and return it."""
        return self.__feedback_queue.get()

    @feedback.setter
    def feedback(self, value):
        self.__feedback_queue.put(value)

    def finish(self, result):
        """Log, run the on_finish callback and send ``COMPLETE`` (ignoring EBADF)."""
        self.__finish_time = monoTime()

        if not self.__silent:
            self.log.info('Finished (total time %0.4fs)', self.__finish_time - self.__start_time)

        if self.__on_finish:
            self.__on_finish()

        try:
            self.__conn.finish(self.id, result)
        except socket.error as err:
            if err.errno in (errno.EBADF, ):
                return
            raise

    def fail(self, ex, tb):
        """Log the failure, run the on_fail callback and send ``FAILED`` (ignoring EBADF).

        :param ex: the exception raised by the handler.
        :param tb: formatted traceback text.
        """
        self.__finish_time = monoTime()

        if isinstance(ex, known_exceptions):
            # Expected error types: log only the last traceback line.
            # [-1] also works for a single-line traceback, where [1] raised IndexError.
            self.log.error('Error: %s', tb.strip().rsplit('\n', 1)[-1])
        else:
            self.log.error('Unhandled exception during handler run')
            for line in tb.split('\n'):
                if line:
                    self.log.error(line)

        self.log.error('Failed (total time %0.4fs)', self.__finish_time - self.__start_time)

        if self.__on_fail:
            self.__on_fail(ex)

        try:
            self.__conn.fail(self.id, ex)
        except socket.error as err:
            if err.errno in (errno.EBADF, ):
                return
            raise

    def destroy(self, reason=None):
        """Stop a still-running worker and drop references to it, the connection and the handler."""
        if not self.__silent:
            self.log.debug('Destroy (reason: %r)', reason)

        if self.__worker_grn is not None and not self.__worker_grn.ready():
            self.stop(reason=reason)

        del self.__worker_grn
        del self.__conn
        del self.__handler


class RPC(object):
    """Registry of remotely callable handlers plus per-connection workers.

    Handlers are registered with :meth:`mount`. Every accepted connection gets
    an :class:`RPC.Worker` running in its own greenlet, which reads messages
    and runs ``CALL`` requests as :class:`RPCJob` objects.
    """

    SOCKET_NODELAY = True
    SOCKET_NODELAY_CUTOFF = 2 * 1024 * 1024
    SOCKET_RECEIVE_BUFFER = 16 * 1024
    SOCKET_SEND_BUFFER = 16 * 1024
    SOCKET_TIMEOUT = 600

    # Statistics. This is a class attribute, so all RPC instances share it.
    stats = {
        'counters': {
            'errors': {},  # handler name -> {exception class name -> count}
            'completed': collections.defaultdict(int),  # handler name -> count
            'sessions': 0,
            'messages': collections.defaultdict(int),  # message type -> count
        },
        'active': {
            'sessions': 0,
            'jobs': collections.defaultdict(int)  # handler name -> running jobs
        }
    }

    class Worker(object):
        """Message loop for one connection; tracks the jobs started on it."""

        IDLE_TIMEOUT = 60
        RECEIVE_BUFFER = 16 * 1024
        PINGPONG_SLEEP_SECONDS = 0

        def __init__(self, conn, rpc):
            self.rpc = rpc
            self.conn = conn
            self.log = conn.log
            self.active_jobs = {}  # active jobs by uid for this connection
            # Neither timer is armed here: the idle timer is armed by DROP once
            # no jobs remain, the stuck timer is (re)armed around socket reads.
            self.idle_timer = gevent.Timeout(self.IDLE_TIMEOUT)
            self.conn_stuck_timer = gevent.Timeout(RPC.SOCKET_TIMEOUT)

            self.log.info('New connection (worker spawned)')

        def loop(self):
            """Read and process messages until the connection is finished.

            Ends when a read batch completes with no active jobs, when nothing
            was received while jobs are active, or on an exception (kill,
            idle/stuck timeout, socket error). Afterwards both timers are
            cancelled, all active jobs are destroyed and the worker is stopped.
            """
            ex = None  # exception that terminated the loop, if any
            ex_desc = None  # specific reason text for job.destroy()
            try:
                while True:
                    self.conn_stuck_timer.start()

                    iteration = None
                    for iteration, message in enumerate(self.conn.read(self.RECEIVE_BUFFER)):
                        # Re-arm the stuck timer for every received message.
                        self.conn_stuck_timer.cancel()
                        self.conn_stuck_timer.start()
                        self.process(message)

                    self.conn_stuck_timer.cancel()

                    if not self.active_jobs:
                        break
                    if iteration is None:
                        self.log.warning("No data received from socket - connection closed by remote side.")
                        break

            except gevent.GreenletExit as e:
                ex = e
                self.log.info('Connection greenlet killed: %s', ex)
                self.conn.fatal('RPC worker killed: %s', str(ex))

            except (Exception, gevent.Timeout) as e:
                ex = e
                if ex is self.idle_timer:
                    self.log.info('Connection was idle for %d seconds, closing', self.IDLE_TIMEOUT)
                    ex_desc = 'Idle for %d seconds, closing connection' % (self.IDLE_TIMEOUT, )
                elif ex is self.conn_stuck_timer:
                    self.log.info('Connection timed out (%d seconds)', RPC.SOCKET_TIMEOUT)
                    ex_desc = 'Connection timed out'
                elif getattr(ex, 'errno', None) == errno.ECONNRESET:
                    self.log.info('Connection closed (ECONNRESET)')
                    ex_desc = 'Connection closed (ECONNRESET)'
                else:
                    self.log.warning('Unhandled exception: %s', formatException())
                    # Pass the text as an argument: it may contain '%'.
                    self.conn.error('%s', ex)

            finally:
                self.idle_timer.cancel()
                # The loop may be left (kill/error) while the stuck timer is
                # still armed; disarm it so it cannot fire later.
                self.conn_stuck_timer.cancel()

                for job in list(self.active_jobs.values()):
                    self.rpc.stats['active']['jobs'][job.name] -= 1

                    if ex_desc is not None:
                        reason = ex_desc
                    elif ex is not None:
                        reason = 'Error: %s' % (str(ex), )
                    else:
                        reason = 'EOF'

                    job.destroy(reason=reason)
                self.active_jobs.clear()

                self.stop()
                if ex is None:
                    self.log.debug('No more data to process, connection closed')

        def process(self, message):
            """Handle one decoded message; failures are reported to the peer.

            *message* is a sequence whose first item is the message type:
            ``PING``, ``NAME``, ``CALL``, ``FEEDBACK`` or ``DROP``.
            """
            try:
                message_type = message[0]

                self.rpc.stats['counters']['messages'][message_type] += 1

                if message_type == 'PING':
                    pong_sleep = self.PINGPONG_SLEEP_SECONDS
                    if pong_sleep > 0:
                        gevent.sleep(pong_sleep)
                    self.conn.pong(jid=None if len(message) < 2 else message[1])

                elif message_type == 'NAME':
                    self.log = self.log.getChild(message[1])

                elif message_type == 'CALL':
                    self.idle_timer.cancel()

                    jid = message[1]
                    name, args, kwargs = message[2]
                    handler, handler_opts = self.rpc.handler(name)
                    if not handler:
                        self.conn.error('Handler for %r not registered', name)
                        return

                    if jid in self.active_jobs:
                        self.conn.error('Job #%d already registered.', jid)
                        return

                    def on_finish():
                        self.rpc.stats['counters']['completed'][name] += 1

                    def on_fail(ex):
                        errstat = self.rpc.stats['counters']['errors']
                        if name not in errstat:
                            errstat[name] = collections.defaultdict(int)
                        errstat[name][ex.__class__.__name__] += 1

                    try:
                        job = RPCJob(
                            self.log, jid, self.conn, name, handler,
                            args, kwargs, on_finish, on_fail,
                            silent=handler_opts.get('silent', False)
                        )
                    except Exception as ex:
                        self.log.error('Unhandled exception during creating job object: %s', formatException())
                        self.conn.error('Failed to create job: %s', str(ex))
                        return

                    self.rpc.stats['active']['jobs'][name] += 1
                    self.active_jobs[job.id] = job

                    try:
                        job.register()
                        job.start()
                    except Exception as ex:
                        self.log.error('Unhandled exception during starting job: %s', formatException())
                        self.conn.error('Failed to add job: %s', str(ex))
                        self.active_jobs.pop(job.id)
                        self.rpc.stats['active']['jobs'][name] -= 1
                        job.destroy(reason='Error: %s' % (str(ex), ))
                        return

                elif message_type == 'FEEDBACK':
                    self.idle_timer.cancel()
                    jid = message[1]
                    value = message[2]
                    job = self.active_jobs.get(jid)
                    if job is None:
                        self.conn.error('Job #%d not registered.', jid)
                        return
                    job.feedback = value

                elif message_type == 'DROP':
                    jid = message[1]
                    job = self.active_jobs.pop(jid, None)
                    if job is not None:
                        self.rpc.stats['active']['jobs'][job.name] -= 1
                        job.destroy(reason='DROP requested')

                    if not self.active_jobs:
                        # Re-arm the idle timer. Starting an already pending
                        # gevent Timeout raises, so cancel it first (a second
                        # DROP with no jobs left used to fail here).
                        self.idle_timer.cancel()
                        self.idle_timer.start()
                else:
                    self.log.error('Got invalid message: %r', message)
                    self.conn.error('Not supported message %r', message)

            except Exception as ex:
                self.log.error('Failed to handle message %r: %s', message, formatException())
                self.conn.error('Failed to handle message: %s', str(ex))

        def stop(self, grn=None):
            """Unregister from the RPC and close the connection (idempotent).

            Also used as a greenlet link callback, hence the unused *grn*.
            """
            if not self.conn:
                return
            self.log.debug('Connection drop (worker stopped)')
            self.rpc.drop_connection(self.conn.id)
            self.conn.close()
            self.conn = None

    def __init__(self, log):
        self.log = log.getChild('rpc')
        self.log.debug('Initializing')

        self.__workers = {}  # sid -> (greenlet, RPC.Worker)

        self.__handlers = {}  # name -> (decorator name, function, options)

    def mount(self, func, name=None, typ='simple', silent=False):
        """Register *func* under *name* (default: ``func.__name__``).

        *typ* names the decorator classmethod used to wrap it: ``simple``,
        ``generator``, ``dupgenerator`` or ``full``.
        """
        self.__handlers[name if name else func.__name__] = (typ, func, {'silent': silent})

    # Decorators {{{
    @staticmethod
    def _split_decorator_arg(name):
        """Tell bare ``@deco`` usage from ``@deco(name)``.

        Returns ``(func, name)``; *func* is set only when the argument is the
        decorated callable. Testing ``callable`` (instead of ``str``) keeps
        unicode/bytes names from being mistaken for the function.
        """
        if callable(name):
            return name, None
        return None, name

    @classmethod
    def simple(cls, name=None):
        """Wrap a plain function: the ``job`` kwarg is dropped before the call."""
        func, name = cls._split_decorator_arg(name)

        def _decorator(func):
            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                kwargs.pop('job')
                return func(*args, **kwargs)
            return _wrapper
        return _decorator if not func else _decorator(func)

    @classmethod
    def generator(cls, name=None):
        """Wrap a generator: every yielded value is sent as job state; the
        generator's return value becomes the result."""
        func, name = cls._split_decorator_arg(name)

        def _decorator(func):
            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                job = kwargs.pop('job')
                try:
                    it = func(*args, **kwargs)
                    while True:
                        job.state(next(it))
                except StopIteration as ex:
                    return None if not ex.args else ex.args[0]
            return _wrapper
        return _decorator if not func else _decorator(func)

    @classmethod
    def dupgenerator(cls, name=None):
        """Wrap a two-way generator: each yielded value is sent as job state,
        then the wrapper blocks for peer feedback and send()s it when not None."""
        func, name = cls._split_decorator_arg(name)

        def _decorator(func):
            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                job = kwargs.pop('job')
                generator = func(*args, **kwargs)
                for value in generator:
                    job.state(value)
                    feedback = job.feedback
                    if feedback is not None:
                        try:
                            # NOTE(review): the value yielded in reply to send()
                            # is discarded, not reported via job.state(); confirm
                            # existing handlers rely on this.
                            generator.send(feedback)
                        except StopIteration:
                            break
            return _wrapper
        return _decorator if not func else _decorator(func)

    @classmethod
    def full(cls, name=None):
        """Wrap a function that receives the job object as first positional argument."""
        func, name = cls._split_decorator_arg(name)

        def _decorator(func):
            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                job = kwargs.pop('job')
                return func(job, *args, **kwargs)
            return _wrapper
        return _decorator if not func else _decorator(func)
    # Decorators }}}

    def start(self):
        return self

    def stop(self):
        """Kill every connection worker greenlet; returns ``self``."""
        for grn, _worker in list(self.__workers.values()):
            grn.kill(gevent.GreenletExit('RPC stopping'))

        return self

    def join(self):
        """Wait for all connection worker greenlets to finish."""
        # __workers values are (greenlet, worker) tuples; joinall needs greenlets.
        gevent.joinall([grn for grn, _worker in list(self.__workers.values())])

    def get_connection_handler(self):
        return Connection.Handler(self.log, self)

    def handler(self, name):
        """Return ``(wrapped_handler, options)`` for *name*, or ``(None, None)``."""
        typ, meth, opts = self.__handlers.get(name, (None, None, {}))
        if not meth:
            return None, None

        _wrapper = getattr(self, typ)(name)(meth)

        return _wrapper, opts

    def add_connection(self, sock, sid):
        """Configure *sock* and spawn a worker greenlet serving it under *sid*."""
        self.stats['counters']['sessions'] += 1
        self.stats['active']['sessions'] += 1

        if sock.peer != b'':  # nodelay not supported on unix sockets
            sock.nodelay = self.SOCKET_NODELAY

        sock.send_buffer = self.SOCKET_SEND_BUFFER
        sock.receive_buffer = self.SOCKET_RECEIVE_BUFFER

        conn = Connection(self.log, sock, sid)
        worker = RPC.Worker(conn, self)

        grn = gevent.spawn(worker.loop)
        self.__workers[sid] = (grn, worker)
        grn.link(worker.stop)

        # Link any unhandled error, so they will try to report error.
        # The parts are passed as arguments so '%' in the text is harmless.
        grn.link_exception(lambda g: conn.fatal(
            '(connection worker fatal) %s: %s',
            g.exception.__class__.__name__,
            str(g.exception),
        ))

    def drop_connection(self, sid):
        """Forget the worker for *sid* and decrement the active session count."""
        self.stats['active']['sessions'] -= 1
        self.__workers.pop(sid, None)

    @property
    def counters(self):
        """Active job count per connection, keyed by printable session id."""
        # RPC has no UID_BITS of its own; session ids are Server-generated.
        return {
            sid2str(sid, Server.UID_BITS): len(worker.active_jobs)
            for sid, (_grn, worker) in six.iteritems(self.__workers)
        }


class Server(object):
    """Listening socket that dispatches accepted connections by a 4-byte magic.

    Each accepted client must send a 4-byte magic within
    ``MAGIC_RECEIVE_TIMEOUT`` seconds; the handler registered for that magic is
    then called with the wrapped socket and a fresh random session id.
    Listens on TCP (``host``/``port``) or on a unix socket (``unix``); a unix
    name starting with ``'\\0'`` is treated as abstract (no file is created,
    chmod-ed or unlinked).
    """
    MAGIC_RECEIVE_TIMEOUT = 10
    UID_BITS = 32
    UID_GENERATION_TRIES = 10
    UID_GENERATION_RETRY_SLEEP = 0.1  # NOTE(review): not used in this class
    # Magic sent to probe whether another instance still serves a unix socket.
    # Must be bytes: magics read from sockets are bytes (see
    # register_connection_handler) and socket.sendall() needs bytes on Python 3.
    CHK_MAGIC = b'CHK '

    def __init__(self, log, backlog, max_conns, host=None, port=None, unix=None):
        self.log = log.getChild('server')
        self.log.debug('Initializing')

        # Session ids handed out so far; only used to avoid collisions.
        # NOTE(review): ids of successfully handled connections are never
        # removed in this class, so the dict grows for the server's lifetime.
        self.__sessions = {}
        self.__connection_handlers = {
            self.CHK_MAGIC: self.__handle_chk,
        }
        self.__worker_grn = None

        self._s_host = host
        self._s_port = port
        self._s_unix = unix
        self._s_backlog = backlog
        self._max_conns = max_conns  # NOTE(review): not used in this class

    def __handle_chk(self, sock, sid):
        """Answer a liveness probe from another instance by closing the socket."""
        sock.close()
        self.log.debug('Got CHK packet, closing connection')

    def register_connection_handler(self, handler):
        """Route connections whose 4-byte magic equals ``handler.get_magic()`` to *handler*."""
        assert isinstance(handler, Connection.Handler)

        magic = handler.get_magic()
        assert isinstance(magic, six.binary_type) and len(magic) == 4

        self.__connection_handlers[magic] = handler.handle

    def start(self):
        """Bind, listen and spawn the accept loop; returns ``self``.

        Raises ``SystemExit(2)`` if another process still accepts connections
        on the configured unix socket path.
        """
        assert self.__worker_grn is None
        log = self.log.getChild('start')

        unix = self._s_unix
        host = self._s_host
        port = int(self._s_port) if self._s_port is not None else None
        backlog = int(self._s_backlog)

        if unix:
            assert port is None and host is None
        unix_abstract = bool(unix) and unix.startswith('\0')

        family = socket.AF_UNIX if unix else socket.AF_INET

        self.__sock = gevent.socket.socket(family, socket.SOCK_STREAM)
        self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        if not unix:
            self.__sock.bind((host, port))
        else:
            if not unix_abstract:
                self.__prepare_unix_path(unix, log)

            self.__sock.bind(unix)

            if not unix_abstract:
                os.chmod(unix, 0o666)

        self.__sock.listen(backlog)

        if not unix:
            log.info('Listening on: %s:%d (backlog: %d)', host, port, backlog)
        else:
            log.info(
                'Listening on: unix:%s (%sbacklog: %d)',
                unix.lstrip('\0'),  # strip the abstract-namespace NUL for display
                'abstract, ' if unix_abstract else '',
                backlog
            )

        self.__worker_grn = gevent.spawn(self.__worker_loop)
        return self

    def __prepare_unix_path(self, unix, log):
        """Make a filesystem unix socket path bindable.

        Creates the parent directory when missing and removes a leftover socket
        file, unless another process still accepts connections on it (then
        raises ``SystemExit(2)``). A connect timeout is re-raised.
        """
        sock_dir = os.path.dirname(unix)
        # dirname() is '' for a bare file name; os.makedirs('') would raise.
        if sock_dir and not os.path.exists(sock_dir):
            os.makedirs(sock_dir)

        if not os.path.exists(unix):
            return

        probe = gevent.socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

        # We can wait a lot if other process will not accept connection
        # (e.g. if it is stopped via SIGTSTP).
        probe.settimeout(10)
        try:
            probe.connect(unix)
            probe.sendall(self.CHK_MAGIC)
        except socket.timeout:
            raise
        except Exception:
            pass  # nobody serves this path: the socket file is stale
        else:
            log.critical('We connected to socket \'%s\' -- probably another instance running' % (unix, ))
            raise SystemExit(2)
        finally:
            probe.close()

        os.unlink(unix)

    @property
    def port(self):
        """Port the listening socket is bound to."""
        return self.__sock.getsockname()[1]

    def stop(self):
        """Stop the accept loop, close the listening socket and remove the
        filesystem unix socket file (if any). Returns ``self``."""
        if self.__worker_grn is not None:
            self.__worker_grn.kill(gevent.GreenletExit)
            try:
                self.__sock.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass  # e.g. a listening socket that is not connected
            self.__sock.close()

            unix = self._s_unix
            # Only filesystem sockets have a file to remove.
            if unix and not unix.startswith('\0'):
                try:
                    os.unlink(unix)
                except OSError:
                    pass  # already removed
        return self

    def join(self):
        """Wait for the accept loop greenlet to finish."""
        self.__worker_grn.join()

    def __generate_sid(self):
        """Return an unused random ``UID_BITS``-bit session id, or ``None``."""
        for _ in range(self.UID_GENERATION_TRIES):
            sid = random.getrandbits(self.UID_BITS)
            if sid not in self.__sessions:
                return sid
        return None

    def __worker(self, conn, addr):
        """Per-connection greenlet: read the magic and dispatch to its handler."""
        log = self.log.getChild('worker')
        if addr != '':
            log.debug('New connection from %r', addr)
        else:
            log.debug('New connection (unix socket)')

        sock = Socket(conn, gevent_mode=True)

        try:
            magic = sock.read(4, timeout=self.MAGIC_RECEIVE_TIMEOUT)
        except gevent.Timeout:
            log.error('Magic receive timeout for client %r', sock.addr)
            sock.close()
            return
        except EOF:
            log.error('Got EOF while receiving magic from %r', sock.addr)
            sock.close()
            return

        handler = self.__connection_handlers.get(magic)
        if handler is None:
            log.error('No handlers for magic %r for client %r', magic, sock.addr)
            sock.close()
            return

        sid = self.__generate_sid()
        if sid is None:
            log.error('Unable to generate unique session ID for %r', sock.addr)
            sock.close()
            return
        self.__sessions[sid] = sid

        try:
            handler(sock, sid)
        except Exception:
            log.error(
                'Got unhandled error while running connection handler %s: %s, client %r',
                handler,
                formatException(),
                sock.addr
            )
            sock.close()
            # The handler did not take the session over: release its id.
            self.__sessions.pop(sid, None)

    def __worker_loop(self):
        """Accept loop: spawn a __worker greenlet per connection until killed."""
        log = self.log.getChild('worker')
        log.info('Started')

        while 1:
            try:
                conn, addr = self.__sock.accept()
                gevent.spawn(self.__worker, conn, addr)
            except gevent.GreenletExit:
                log.info('Received stop signal')
                break
            except Exception:
                log.error('Unhandled exception: %s', formatException())
                gevent.sleep(0.1)  # avoid busy loops
