import argparse
import logging
import msgpack
import os
import psutil
import Queue
import random
import resource
import signal
import simplejson
import socket
import sys
import threading
import time
import traceback

from setproctitle import setproctitle

from api.logger import SkynetLoggingHandler

from core.daemon import logger


# Base process title. main() appends ' (RROF)' when profiling is enabled, and
# the router's stats thread extends it with live counters via setproctitle().
BASE_PROCTITLE = 'skybone-coord-router'
# Seconds to accumulate per-shard jobs before pushing them to the shard queues.
BUFFER_JOBS = 0.5


def pwatch():
    """Poll once a second and hard-exit (status 1) once the parent pid becomes 1.

    A parent pid of 1 means this process was re-parented to init, i.e. the
    original parent has gone away.
    """
    while os.getppid() != 1:
        time.sleep(1)
    os._exit(1)


def setup_logger(logprefix, logpath=None):
    """Attach a DEBUG-level handler to the root logger and return the router logger.

    Note: every call adds another handler to the root logger, so call it once.

    :param logprefix: name of the returned child logger (``router.<logprefix>``)
    :param logpath: directory for a rotating ``skybone-coord-router.log`` file;
        when None, records go to SkynetLoggingHandler instead
    :return: the ``router.<logprefix>`` child of the root logger
    """
    log = logging.getLogger('')
    log.setLevel(logging.DEBUG)
    if logpath is not None:
        # logging.handlers is a submodule that `import logging` does not load,
        # so import it explicitly instead of relying on another module having
        # done so first.
        from logging.handlers import RotatingFileHandler

        handler = RotatingFileHandler(
            os.path.join(logpath, 'skybone-coord-router.log'),
            maxBytes=int(0.75 * (1 << 30)),  # rotate at 768 MiB
            backupCount=2,
        )
        handler.setFormatter(logger.get_formatter('full'))
    else:
        handler = SkynetLoggingHandler(app='skybone-coord', filename='skybone-coord-router.log')
    handler.setLevel(logging.DEBUG)
    log.addHandler(handler)
    return log.getChild('router').getChild(logprefix)


class Transaction(object):
    """A client request in flight, waiting for shard responses.

    ``ip``/``port`` are the client address the final answer goes back to.
    ``wait`` is the number of shard responses still expected (it is
    decremented per response). ``deadline`` is the absolute ``time.time()``
    value after which the record may be expired.
    """
    __slots__ = ('ip', 'port', 'wait', 'deadline')

    def __init__(self, ip, port, wait, deadline):
        self.ip, self.port = ip, port
        self.wait, self.deadline = wait, deadline


def avg(v):
    """Return the arithmetic mean of the iterable ``v``, or 0 when it is empty.

    The iterable is materialized first, so generators are accepted. With
    integer inputs on Python 2 the division is integer (floor) division.
    """
    values = list(v)
    return sum(values) / len(values) if values else 0


class SkyboneCoordRouter(object):
    """UDP front-end that fans client requests out to shard worker processes.

    Data flow (all threads are daemons started by start()):

    * udpworker (one per bound socket) reads datagrams and hands them to
      ``_udpqueue`` in batches collected over ~0.1 s.
    * jobworker msgpack-decodes every datagram, classifies it (ping, info,
      connect, announce), records a Transaction keyed by the request tid and
      buffers the request for one shard (by info hash), the next shard in a
      rotation ("any") or all shards. The buffers are pushed to the per-shard
      queues every BUFFER_JOBS seconds.
    * shardworker (one per shard) connects to the shard over an abstract unix
      socket, streams msgpack-packed job batches to it and runs a shardback
      thread. shardback reads the responses and sends one back to the client
      once every shard the request went to has answered.
    * statsthread prints per-second counters to stdout, updates the process
      title and drops transactions whose deadline has passed.
    """

    # Magic connection ids identifying the client protocol version.
    _CID_V1 = 1736820976
    _CID_V2 = 1964596860
    _CID_V3 = 4079332072

    # Request actions. 20 (startup) and 30 (shutdown) exist in the protocol
    # but are not routed here.
    _ACT_CONNECT = 0
    _ACT_ANNOUNCE = 10

    # Announce state code -> suffix of the ``in_rps_state_*`` counter.
    _ANNOUNCE_STATES = {1: 'leech', 2: 'seed', 3: 'stop'}

    def __init__(self, shards, port, bind_v4, bind_v6, logger):
        """
        :param shards: list of shard socket names (abstract unix namespace)
        :param port: UDP port to listen on
        :param bind_v4: IPv4 address to bind, or 'any' to serve IPv4 through
            the dual-stack IPv6 socket
        :param bind_v6: IPv6 address to bind, or 'any'
        :param logger: logger used for diagnostics
        """
        self.shards = shards

        self._stopev = threading.Event()

        self.bind_v4 = bind_v4
        self.bind_v6 = bind_v6

        self._sock4 = None
        self._sock6 = None

        self.port = port

        self.log = logger

        self._udpqueue = Queue.Queue(maxsize=10)  # max 10 unmsgpacked packets (the rest will be in the socket buff)

        self._udp4_thread = None
        self._udp6_thread = None
        self._job_thread = None

        self._shard_threads = []
        self._shard_queues = []
        self._shard_jobs = []
        self._shard_jobs_ts = time.time()
        self._rndshard = 0  # next shard index for "any shard" requests (round-robin)

        # tid => Transaction(ip, port, wait, deadline)
        self._transactions = {}
        self._transactions_clean_ts = 0

        # Kept for backward compatibility. Shard workers use their own Packer
        # because a Packer instance carries internal buffer state.
        self.msg_packer = msgpack.Packer()

        self.stats = {
            'in_drop': 0,
            'in_drop_job': 0,
            'in_rps': 0,
            'in_rps_misc': 0,
            'in_rps_connect': 0,
            'in_rps_connect_v1': 0,
            'in_rps_connect_v2': 0,
            'in_rps_connect_v3': 0,
            'in_rps_announce': 0,
            'in_rps_announce_legacy': 0,
            'in_rps_announce_v3': 0,
            'in_rps_state_leech': 0,
            'in_rps_state_seed': 0,
            'in_rps_state_stop': 0,
            'in_rps_state_unknown': 0,
            'in_rps_info': 0,
            'in_rps_invalid': 0,
            'ou_rps': 0,
            'ou_drop': 0,
        }

        self.statslock = threading.Lock()
        self.translock = threading.Lock()

    # ------------------------------------------------------------------ stats

    def statsthread(self):
        """Once per second: dump and reset counters, refresh proctitle, expire transactions."""
        proc = psutil.Process(os.getpid())

        while not self._stopev.is_set():
            ts = time.time()

            resusage = resource.getrusage(resource.RUSAGE_SELF)
            try:
                rss = proc.memory_info().rss
            except AttributeError:
                # old psutil versions only provide get_memory_info()
                rss = proc.get_memory_info().rss

            shard_queue = avg(q.qsize() for q in self._shard_queues)
            shard_queue_max = avg(q.maxsize for q in self._shard_queues)

            gauges = {
                'rss': rss,
                'cpu_utime': resusage.ru_utime,
                'cpu_stime': resusage.ru_stime,
                'udp_q': self._udpqueue.qsize(),
                'trans': len(self._transactions),
                'shard_queue': shard_queue,
                'shard_queue_max': shard_queue_max,
                'squeue': '%3d/%3d' % (shard_queue, shard_queue_max),
            }

            # Snapshot and reset under the lock, so concurrent savestat()
            # calls are neither lost nor racing with the dict iteration.
            with self.statslock:
                self.stats.update(gauges)
                stats = dict(self.stats)
                for k in self.stats:
                    self.stats[k] = 0

            sys.stdout.write('STATS: %s\n' % (simplejson.dumps(stats), ))
            sys.stdout.flush()

            setproctitle('%s [rps_in:%5d|rps_out:%5d|uqueue:%4d|squeue:%4s]' % (
                BASE_PROCTITLE,
                stats['in_rps'], stats['ou_rps'],
                stats['udp_q'], stats['squeue'],
            ))

            self._clean_transactions(ts)

            time.sleep(1 - ts % 1)

    def _clean_transactions(self, now):
        """Drop transactions whose deadline is <= now; does work at most every 300 s."""
        with self.translock:
            if now - self._transactions_clean_ts < 300:
                return

            expired = [tid for tid, tr in self._transactions.items() if tr.deadline <= now]
            for tid in expired:
                del self._transactions[tid]

            if expired:
                self.log.info('Cleaned %d outdated transactions', len(expired))

            self._transactions_clean_ts = now

    def savestat(self, n, v):
        """Add v to counter n (thread-safe). Unknown counters start from 0."""
        with self.statslock:
            self.stats[n] = self.stats.get(n, 0) + v

    def logerr(self, msg, *args):
        """Write an error to stderr and to the logger; never raises."""
        try:
            if args:
                msg = msg % args

            sys.stderr.write(msg + '\n')
            sys.stderr.flush()
            self.log.error(msg)
        except Exception:
            sys.stderr.write('Unable to deliver error log\n')
            sys.stderr.write(traceback.format_exc())
            sys.stderr.flush()

    # -------------------------------------------------------------------- udp

    def _flush_udp_packets(self, packets):
        """Hand a batch of (data, peer) pairs to the job thread.

        A full queue drops the batch and counts it in ``in_drop``.

        :return: False if the router is stopping (nothing was queued), else True
        """
        if self._stopev.is_set():
            return False
        try:
            self._udpqueue.put_nowait(packets)
        except Queue.Full:
            self.savestat('in_drop', len(packets))
        return True

    def udpworker(self, sock):
        """Receive datagrams from sock and forward them to _udpqueue in ~0.1 s batches."""
        self.log.debug('started udp worker on %s', sock.getsockname())

        # Short timeout so pending batches are flushed and stop is noticed
        # even when no traffic arrives.
        sock.settimeout(0.1)

        packets = []
        packets_ts = time.time()

        while not self._stopev.is_set():
            try:
                data, peer = sock.recvfrom(2048)
            except socket.error:
                if self._stopev.is_set():
                    return
            else:
                if self._stopev.is_set():
                    sock.close()
                    return
                packets.append((data, peer))

            if packets and time.time() - packets_ts > 0.1:
                if not self._flush_udp_packets(packets):
                    return
                packets = []
                packets_ts = time.time()

    # ------------------------------------------------------------------- jobs

    def _add_shard_jobs(self):
        """Move the buffered per-shard job lists into the shard queues."""
        for idx, jobs in enumerate(self._shard_jobs):
            if not jobs:
                continue

            if not self._stopev.is_set():
                try:
                    # Blocks up to 1 s per shard whose queue is full.
                    self._shard_queues[idx].put(jobs, timeout=1)
                except Queue.Full:
                    self.savestat('in_drop_job', len(jobs))

            self._shard_jobs[idx] = []

        self._shard_jobs_ts = time.time()

    def _count_state(self, state):
        """Bump the ``in_rps_state_*`` counter for an announce state code."""
        try:
            name = self._ANNOUNCE_STATES[state]
        except (KeyError, TypeError):  # TypeError: unhashable value from the wire
            name = 'unknown'
        self.savestat('in_rps_state_%s' % (name, ), 1)

    def _route_packet(self, data_raw, peer, shardscnt):
        """Classify one raw datagram and append it to the matching shard job buffer(s).

        Invalid requests are counted in ``in_rps_invalid`` and dropped.

        :param data_raw: msgpack-encoded request as received from the client
        :param peer: address tuple from recvfrom(); [0] is the ip, [1] the port
        :param shardscnt: number of shards
        """
        try:
            data = msgpack.loads(data_raw)
            cid, action = data[0], data[1]
        except Exception as ex:
            self.savestat('in_rps_invalid', 1)
            self.log.warning('Undecodable request from %r: %s: %s', peer, type(ex).__name__, ex)
            return

        self.savestat('in_rps', 1)

        use_any = use_all = False
        use = None

        if cid == 'PI' and action == 'NG':
            use_any = True
            tid = random.randint(0, 2 ** 32 - 1)
            self.savestat('in_rps_misc', 1)

        elif cid == 'IN' and action == 'FO':
            try:
                hashint = int(data[2].encode('hex'), 16)
                tid = data[3]
            except Exception as ex:
                self.log.warning('Invalid INFO request: %r (%s: %s)', data, type(ex).__name__, ex)
                self.savestat('in_rps_invalid', 1)
                return

            use = hashint % shardscnt
            self.savestat('in_rps_info', 1)

        elif cid in (self._CID_V1, self._CID_V2):
            if action != self._ACT_CONNECT:
                self.savestat('in_rps_invalid', 1)
                return

            self.log.info('Legacy CONNECT request: %s', peer[0])
            use_all = True
            try:
                tid = data[2]
            except Exception as ex:
                self.log.warning(
                    'Invalid CONNECT (v2) request: %r (%s: %s)',
                    data, type(ex).__name__, ex
                )
                self.savestat('in_rps_invalid', 1)
                return

            self.savestat('in_rps_connect', 1)
            self.savestat('in_rps_connect_v1' if cid == self._CID_V1 else 'in_rps_connect_v2', 1)

        elif cid == self._CID_V3:
            if action == self._ACT_CONNECT:
                use_all = True
                try:
                    tid = data[2]
                except Exception as ex:
                    self.log.warning(
                        'Invalid CONNECT (v3) request: %r (%s: %s)',
                        data, type(ex).__name__, ex
                    )
                    self.savestat('in_rps_invalid', 1)
                    return

                self.savestat('in_rps_connect', 1)
                self.savestat('in_rps_connect_v3', 1)

            elif action == self._ACT_ANNOUNCE:
                try:
                    tid = data[2]
                    info_hash = data[4]
                    state = data[5]
                    hashint = int(info_hash, 16) if len(info_hash) == 40 else None
                except Exception as ex:
                    self.log.warning(
                        'Invalid ANNOUNCE (v3) request: %r (%s: %s)',
                        data, type(ex).__name__, ex
                    )
                    self.savestat('in_rps_invalid', 1)
                    return

                if info_hash == 'CLEAN':
                    use_all = True
                elif hashint is None:
                    # Neither a 40-char hex hash nor CLEAN: cannot pick a shard.
                    self.log.warning('Invalid ANNOUNCE (v3) request: %r (bad hash)', data)
                    self.savestat('in_rps_invalid', 1)
                    return
                else:
                    use = hashint % shardscnt

                self.savestat('in_rps_announce', 1)
                self.savestat('in_rps_announce_v3', 1)
                self._count_state(state)

            else:
                self.savestat('in_rps_invalid', 1)
                return

        elif action == self._ACT_ANNOUNCE:
            # Legacy proto without cids
            try:
                tid = data[2]
                state = data[3]
                info_hash = data[4]
                hashint = int(info_hash, 16)
            except Exception as ex:
                self.log.warning(
                    'Invalid ANNOUNCE (v2) request: %r (%s: %s)',
                    data, type(ex).__name__, ex
                )
                self.savestat('in_rps_invalid', 1)
                return

            self.log.info('Legacy ANNOUNCE request: %s shared %s', peer[0], info_hash.encode('hex'))

            use = hashint % shardscnt
            self.savestat('in_rps_announce_legacy', 1)
            self._count_state(state)

        else:
            self.savestat('in_rps_invalid', 1)
            self.log.warning('Invalid action %r, ignoring', action)
            return

        with self.translock:
            if tid in self._transactions:
                self.log.warning('Transaction %r already in transactions (come from %r)!', tid, peer)
                tid = random.randint(0, 2 ** 32 - 1)

            self._transactions[tid] = Transaction(
                peer[0], peer[1],
                shardscnt if use_all else 1,  # number of shard answers to wait for
                time.time() + 60
            )

        req = (tid, peer[0], peer[1], data)

        if use_all:
            for jobs in self._shard_jobs:
                jobs.append(req)
        elif use_any:
            self._shard_jobs[self._rndshard].append(req)
            self._rndshard = (self._rndshard + 1) % shardscnt
        else:
            self._shard_jobs[use].append(req)

    def jobworker(self):
        """Route batches from _udpqueue to the shard job buffers until stopped.

        Returns when the stop event is set or the None sentinel put by stop()
        is received.
        """
        shardscnt = len(self.shards)

        while not self._stopev.is_set():
            try:
                try:
                    packets = self._udpqueue.get(timeout=1)
                except Queue.Empty:
                    if self._stopev.is_set():
                        return

                    if time.time() - self._shard_jobs_ts > BUFFER_JOBS:
                        self._add_shard_jobs()

                    continue

                if packets is None:
                    return

                for data_raw, peer in packets:
                    try:
                        self._route_packet(data_raw, peer, shardscnt)
                    except Exception as ex:
                        # Input comes straight from the network: a bad packet
                        # must cost only itself, not stall the whole batch.
                        self.savestat('in_rps_invalid', 1)
                        self.log.warning(
                            'Unable to route request from %r: %s: %s',
                            peer, type(ex).__name__, ex
                        )

                    if time.time() - self._shard_jobs_ts > BUFFER_JOBS:
                        self._add_shard_jobs()

            except Exception as ex:
                self.logerr('Unhandled exception in job worker: %s: %s', type(ex).__name__, ex)
                self.logerr(traceback.format_exc())
                time.sleep(1)

    # ----------------------------------------------------------------- shards

    @staticmethod
    def _shutdown_sock(sock):
        """Best-effort shutdown in both directions; errors such as "not connected" are ignored."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            pass

    def _deliver_response(self, tid, data):
        """Account one shard response for tid and send data to the client when it is the last one expected."""
        with self.translock:
            transaction = self._transactions.get(tid)
            if transaction is None:
                send = False
            else:
                transaction.wait -= 1
                send = transaction.wait == 0
                if send:
                    del self._transactions[tid]

        if transaction is None:
            self.logerr('Cant find transaction %r', tid)

        if send and data:
            sock = self._sock6 if ':' in transaction.ip else self._sock4
            addr = (transaction.ip, transaction.port)
            try:
                sock.sendto(data, addr)
            except Exception as ex:
                self.logerr(
                    'Unable to send data back to %r: %s: %s',
                    addr, type(ex).__name__, ex
                )

            self.savestat('ou_rps', 1)
        elif not data:
            self.savestat('ou_drop', 1)

    def shardback(self, sock):
        """Read msgpack response batches from a shard and deliver them to clients.

        Each unpacked object is iterated as (tid, data) pairs. Returns when
        the shard closes the connection, on a socket error, when the router
        is stopping, or after logging an unexpected error. In every case
        shardworker sees the dead thread and reconnects.
        """
        unpacker = msgpack.Unpacker()

        try:
            while True:
                try:
                    chunk = sock.recv(1024 * 8)
                except socket.error:
                    self._shutdown_sock(sock)
                    return

                if not chunk:
                    return

                if self._stopev.is_set():
                    self._shutdown_sock(sock)
                    return

                unpacker.feed(chunk)

                for responses in unpacker:
                    for tid, data in responses:
                        self._deliver_response(tid, data)
        except Exception as ex:
            self.logerr('Unhandled exception in shardback: %s: %s', type(ex).__name__, ex)
            self._shutdown_sock(sock)

    def _serve_shard(self, sock, queue):
        """Stream job batches from queue into a connected shard socket.

        Returns when the router stops or the shardback reader thread dies.
        Always shuts the socket down and waits for the reader before returning.
        """
        packer = msgpack.Packer()  # per-connection: Packer holds buffer state

        thr = threading.Thread(target=self.shardback, args=(sock, ))
        thr.daemon = True
        thr.start()

        try:
            while not self._stopev.is_set():
                if not thr.is_alive():
                    self.logerr('shardback thread died!')
                    return

                try:
                    # Timeout so a dead reader or a stop is noticed while idle.
                    jobs = queue.get(timeout=1)
                except Queue.Empty:
                    continue

                if self._stopev.is_set():
                    return

                sock.sendall(packer.pack(jobs))
        finally:
            self._shutdown_sock(sock)
            while thr.is_alive():
                thr.join(timeout=1)

    def shardworker(self, idx, queue, sockname):
        """Keep a connection to one shard alive and feed it jobs until stopped.

        :param idx: shard index (unused, kept for the thread-args interface)
        :param queue: queue of job batches for this shard
        :param sockname: abstract unix socket name of the shard
        """
        while not self._stopev.is_set():
            sock = None
            try:
                try:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(1)
                    sock.connect('\0' + sockname)
                except socket.error as ex:
                    self.log.warning('Unable to connect shard %s: %s: %s', sockname, type(ex).__name__, ex)
                    time.sleep(0.1)
                    continue

                sock.settimeout(None)

                # Wait until shard is ready to process packets (it sends 0x42).
                ready = sock.recv(1)
                if ready != '\x42':
                    raise RuntimeError('unexpected shard handshake %r' % (ready, ))

                self.log.info('Connected to shard worker %s', sockname)

                self._serve_shard(sock, queue)

            except Exception as ex:
                self.logerr('Unhandled exception in shard worker %s: %s: %s', sockname, type(ex).__name__, ex)
                time.sleep(1)
            finally:
                # Release the fd on every path (failed connect, reconnect, stop).
                if sock is not None:
                    sock.close()

    # -------------------------------------------------------------- lifecycle

    def start(self):
        """Start the stats, shard, UDP and job threads and bind the UDP sockets.

        On any failure the stop event is set and the exception re-raised.
        """
        udp_recv_buff = 4 * 1024 * 1024
        udp_send_buff = 4 * 1024 * 1024

        try:
            thr = threading.Thread(target=self.statsthread)
            thr.daemon = True
            thr.start()

            for idx, shard in enumerate(self.shards):
                shard_queue = Queue.Queue(maxsize=10)  # 10 batches * BUFFER_JOBS = max 5 seconds of jobs

                thr = threading.Thread(
                    target=self.shardworker,
                    args=(idx, shard_queue, shard)
                )
                thr.daemon = True

                self._shard_queues.append(shard_queue)
                self._shard_threads.append(thr)
                self._shard_jobs.append([])

                thr.start()

            self.log.info('Bind v4:%s v6:%s', self.bind_v4, self.bind_v6)

            self._sock6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            self._sock6.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

            if self.bind_v4 == 'any':
                # Dual-stack: IPv4 clients arrive as ::ffff:a.b.c.d on this socket.
                self._sock6.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)

            self._sock6.bind(('' if self.bind_v6 == 'any' else self.bind_v6, self.port))
            self._sock6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, udp_recv_buff)
            self._sock6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, udp_send_buff)

            self.log.info('Binded (v6) on %s:%d (udp)', self.bind_v6, self.port)

            thr = self._udp6_thread = threading.Thread(target=self.udpworker, args=(self._sock6, ))
            thr.daemon = True
            thr.start()

            if self.bind_v4 != 'any':
                self._sock4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                self._sock4.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                self._sock4.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, udp_recv_buff)
                self._sock4.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, udp_send_buff)
                self._sock4.bind((self.bind_v4, self.port))
                self.log.info('Binded (v4) on %s:%d (udp)', self.bind_v4, self.port)
                thr = self._udp4_thread = threading.Thread(target=self.udpworker, args=(self._sock4, ))
                thr.daemon = True
                thr.start()

            thr = self._job_thread = threading.Thread(target=self.jobworker)
            thr.daemon = True
            thr.start()

        except BaseException:
            self._stopev.set()
            raise

    def stop(self):
        """Signal all threads to stop and wake up those blocked on queues."""
        self._stopev.set()

        for q in self._shard_queues:
            try:
                q.put((None, None), timeout=1)
            except Queue.Full:
                self.log.debug('Unable to put stopflag to shard_queue: timeout')

        try:
            self._udpqueue.put(None, timeout=1)
        except Queue.Full:
            self.log.debug('unable to put stopflag to udpqueue: timeout')

    def join(self, ignore_sigint=False, debug=False):
        """Wait for the UDP and job threads to finish.

        :param ignore_sigint: swallow KeyboardInterrupt while waiting
        :param debug: log which thread is still being waited for
        """
        named_threads = (
            ('udp4_thread', self._udp4_thread),
            ('udp6_thread', self._udp6_thread),
            ('job_thread', self._job_thread),
        )

        for name, thr in named_threads:
            if thr is None:
                continue

            while thr.is_alive():
                if debug:
                    self.log.debug('Waiting %s thread to finish...', name)

                try:
                    thr.join(timeout=1)
                except KeyboardInterrupt:
                    if not ignore_sigint:
                        raise

        self.log.debug('finished all threads')
        self.log.debug('finished all threads')


def main():
    """Entry point: parse CLI args, start the router and block until it stops.

    SIGINT stops the router and waits for its threads. SystemExit is logged
    and re-raised. Any other exception kills the process with os._exit(1).
    With --profile, yappi CPU profiling is enabled and a callgrind dump is
    written to /tmp once the router has stopped.
    """
    global BASE_PROCTITLE
    setproctitle(BASE_PROCTITLE)

    # Exit once we are re-parented to init (pid 1), i.e. the parent died.
    parent_watcher = threading.Thread(target=pwatch)
    parent_watcher.daemon = True
    parent_watcher.start()

    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--logprefix', default=str(os.getpid()))
    arg('--port', type=int, required=True)
    arg('--bind-v4', required=True)
    arg('--bind-v6', required=True)
    arg('--shards', required=True)
    arg('--profile', action='store_true')
    arg('--log-path', default=None, required=False)
    args = parser.parse_args()

    log = setup_logger(args.logprefix, args.log_path)
    log.info('Start with args: %r', vars(args))

    router = SkyboneCoordRouter(
        shards=args.shards.split(','),
        port=args.port,
        bind_v4=args.bind_v4,
        bind_v6=args.bind_v6,
        logger=log,
    )

    profiling = args.profile
    if profiling:
        import yappi
        yappi.set_clock_type('cpu')
        yappi.start(builtins=True)
        log.warning('ENABLED_PROFILING')
        BASE_PROCTITLE += ' (RROF)'
        setproctitle(BASE_PROCTITLE)

    router.start()

    try:
        router.join()
    except KeyboardInterrupt:
        # Ignore further Ctrl-C while shutting down gracefully.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        sys.stderr.write('Caught SIGINT, stopping router\n')
        log.info('Caught SIGINT, stopping router')
        router.stop()
        router.join(ignore_sigint=True, debug=True)
    except SystemExit as ex:
        if not ex.args:
            sys.stderr.write('Got SystemExit exception in main loop\n')
            log.warning('Got SystemExit exception in main loop')
        else:
            sys.stderr.write(str(ex) + '\n')
            log.warning(ex)
        raise
    except BaseException:
        tb = traceback.format_exc()
        sys.stderr.write('Unhandled exception: %s, exit immidiately!\n' % (tb, ))
        log.critical('Unhandled exception: %s, exit immidiately!', tb)
        os._exit(1)

    if profiling:
        yappi.get_func_stats().save('/tmp/router_%d.prof' % (os.getpid(), ), type='callgrind')
