#!/usr/bin/env python2

import argparse
import hashlib
import multiprocessing
import Queue
import random
import socket
import sys
import threading
import time

import gevent
import gevent.event
import gevent.queue
import gevent.socket
import msgpack

# Per-process transaction counter.  ``generate_tid`` advances it on every
# call, although the tid it hands out is drawn at random.
CUR_TID = 0


def generate_tid():
    """Return a new transaction id and advance the ``CUR_TID`` counter.

    The returned value is a random non-negative integer below 2 ** 32; the
    counter is only bumped, it is not part of the result.
    """
    global CUR_TID
    CUR_TID = CUR_TID + 1

    tid_space = 2 ** 32
    return int(random.random() * tid_space - 1)


# Maps the transaction id (tid) of a request that is waiting for a reply to
# the queue on which ``receiver`` delivers that reply.
TID_QUEUES = {}


def peer_runner(log, uid, sock, stats, host, port, resid, announces_count=100):
    """Simulate one peer: connect to the server, then announce repeatedly.

    Every request is a msgpack-encoded tuple sent through ``sock`` to
    ``(host, port)``.  The reply is picked up from the queue registered in
    ``TID_QUEUES`` under the request's transaction id (tid).

    A request that times out is retransmitted with the *same* tid.  That keeps
    exactly one ``TID_QUEUES`` entry per in-flight transaction (no entries are
    left behind by timed-out attempts) while a reply that arrives late still
    completes the transaction.

    Args:
        log: Per-peer logging callable; currently unused.
        uid: Peer identifier placed in every packet.
        sock: Datagram socket used for sending.
        stats: Dict of counters; ``connects``, ``connects_lost``,
            ``announces`` and ``announces_lost`` are incremented here.
        host: Server address.
        port: Server port.
        resid: Resource id placed in every announce packet.
        announces_count: Number of successful announces after which the peer
            stops.  At least one announce is always made.
    """
    connect_tid = None    # tid of the connect that is still unanswered, if any
    connect_deadline = 0  # time.time() value after which a new connect is due

    result_queue = gevent.queue.Queue(maxsize=1024)

    # Sleep for a random fraction of a second before the first packet
    # (presumably to spread the start-up of many peers -- confirm).
    gevent.sleep(random.random() * 1)

    announces_made = 0

    while True:
        if time.time() > connect_deadline:
            if connect_tid is None:
                connect_tid = generate_tid()
                # A fresh queue per transaction: a duplicate reply to an
                # earlier request can then never be read as this reply.
                result_queue = gevent.queue.Queue(maxsize=1024)

            TID_QUEUES[connect_tid] = result_queue

            pkt = msgpack.dumps((
                4079332072,  # initial cid
                0,           # action connect
                connect_tid,
                uid,
                [
                    '127.0.0.1', '::1',
                    '127.0.0.2', '::2',
                ],
                12345,
                ['peer_types'],  # extensions
            ))

            sock.sendto(pkt, (host, port))
            try:
                response = result_queue.get(timeout=1)
            except gevent.queue.Empty:
                stats['connects_lost'] += 1
                # Retry on the next pass, keeping connect_tid registered so a
                # late reply to this attempt is still accepted.
                continue

            stats['connects'] += 1
            # The second field of the reply is added to the current time to
            # obtain the moment the next connect is due.
            connect_deadline = time.time() + response[1]
            TID_QUEUES.pop(connect_tid, None)
            connect_tid = None

        tid = generate_tid()
        result_queue = gevent.queue.Queue(maxsize=1024)

        # The announce packet does not change between retransmissions.
        pkt = msgpack.dumps((
            4079332072,  # initial cid
            10,          # action announce
            tid,
            uid,
            resid,
            1,           # announce downloading  (1 - dl, 2 - seed, 3 - stop)
            0,           # network - auto
        ))

        while True:
            TID_QUEUES[tid] = result_queue
            sock.sendto(pkt, (host, port))
            try:
                result_queue.get(timeout=5)
            except gevent.queue.Empty:
                stats['announces_lost'] += 1
            else:
                break

        stats['announces'] += 1
        TID_QUEUES.pop(tid, None)

        announces_made += 1
        if announces_made >= announces_count:
            break


def receiver(sock, stats):
    """Read replies from ``sock`` forever and hand them to the waiting peer.

    Each datagram is msgpack-decoded; its first element is the transaction id
    and the remaining elements are put on the queue registered for that id in
    ``TID_QUEUES``.  ``stats['responses']`` counts every datagram read and
    ``stats['responses_no_tid']`` counts those whose id is not registered.
    """
    while True:
        payload, _sender = sock.recvfrom(2048)
        stats['responses'] += 1

        message = msgpack.loads(payload)
        tid = message[0]
        response = message[1:]

        waiter = TID_QUEUES.get(tid)
        if waiter is None:
            # Nobody is waiting for this transaction.
            stats['responses_no_tid'] += 1
            continue

        waiter.put(response)


def child(codename, host, port, peers, resid, stats_queue):
    """Run one benchmark worker: ``peers`` simulated peers sharing one socket.

    Spawns a receiver greenlet, a stats-reporting greenlet and one
    ``peer_runner`` greenlet per peer, then waits for the peers to finish.
    Whether the peers finish normally or the wait is interrupted with
    ``KeyboardInterrupt``, the last stats delta is pushed to ``stats_queue``
    and the helper greenlets and the socket are released.

    Args:
        codename: Worker name as a decimal string; used as log prefix and to
            seed ``CUR_TID``.
        host: Server address passed to every peer.
        port: Server port passed to every peer.
        peers: Number of peer greenlets to start.
        resid: Resource id passed to every peer.
        stats_queue: Queue that receives a dict of counter deltas roughly
            every 0.1 seconds.
    """
    global CUR_TID
    CUR_TID = int(codename) * 10**6

    sock = gevent.socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    # V6ONLY off, enlarged receive/send buffers, bound to an ephemeral port.
    sock.setsockopt(gevent.socket.IPPROTO_IPV6, gevent.socket.IPV6_V6ONLY, 0)
    sock.setsockopt(gevent.socket.SOL_SOCKET, gevent.socket.SO_RCVBUF, 1 * 1024 * 1024)
    sock.setsockopt(gevent.socket.SOL_SOCKET, gevent.socket.SO_SNDBUF, 16 * 1024 * 1024)
    sock.bind(('', 0))

    def log(msg, *args):
        print('[' + codename + ']  ' + msg % args)
        sys.stdout.flush()

    def peerlog(peername, msg, *args):
        print('[' + codename + ']  [' + peername + ']  ' + msg % args)
        sys.stdout.flush()

    # Counters shared by the receiver and all peers of this worker.
    stats = {
        'connects': 0,
        'connects_lost': 0,
        'announces': 0,
        'announces_lost': 0,
        'responses': 0,
        'responses_no_tid': 0,
    }

    def runner(n):
        # Start peer number n with its own log prefix and uid.
        return gevent.spawn(
            peer_runner,
            lambda msg, *args: peerlog('%05d' % (n, ), msg, *args),
            'bench_%05d' % (n, ),
            sock,
            stats,
            host, port, resid
        )

    def _stats(finev, resqueue):
        # Every 0.1s push the change of each counter since the previous
        # snapshot; the iteration in progress when finev is set still reports.
        while not finev.isSet():
            pstats = stats.copy()
            gevent.sleep(0.1)

            resqueue.put(dict(
                (key, value - pstats.get(key, 0))
                for key, value in stats.iteritems()
            ))

    receiver_grn = gevent.spawn(receiver, sock, stats)

    stats_fin = gevent.event.Event()
    stats_grn = gevent.spawn(_stats, stats_fin, stats_queue)

    grns = [runner(n) for n in range(peers)]

    log('started %d workers', len(grns))

    try:
        gevent.joinall(grns)
    except KeyboardInterrupt:
        log('Got KeyboardInterrupt, exiting...')
    finally:
        # Stop any peers that are still running so the counters are final,
        # flush the last stats delta, then release the receiver and socket.
        gevent.killall(grns)
        stats_fin.set()
        stats_grn.join()
        receiver_grn.kill()
        sock.close()

def main():
    """Parse the command line and run the benchmark.

    Without ``--child`` the requested number of worker processes is started,
    each running ``child`` with an equal share of the peers.  With ``--child
    NAME`` a single worker is run in the current process.  In both modes a
    background thread sums the stats blocks the workers push to a shared
    queue and prints one summary line per second.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', required=True)
    parser.add_argument('--port', type=int, default=2399)
    parser.add_argument('--procs', type=int, default=2)
    parser.add_argument('--peers', type=int, default=5000)
    parser.add_argument('--child')

    args = parser.parse_args()

    if args.procs < 1:
        # Also guards the division below.
        parser.error('--procs must be at least 1')

    ip6 = gevent.socket.getaddrinfo(args.host, 0, socket.AF_INET6, socket.SOCK_DGRAM)[0][4][0]

    print('resolved %s to %s' % (args.host, ip6))

    args.host = ip6

    peers_per_proc = args.peers // args.procs
    resid = hashlib.sha1('benchmark').hexdigest()

    stats_queue = multiprocessing.Queue()

    def stats_printer(stopev):
        # Running totals of every counter reported by the workers.
        totals = {}

        while True:
            # Sample the stop flag *before* draining: once it is seen set,
            # one complete drain still follows, so blocks queued before the
            # flag was raised are not dropped.
            stopping = stopev.isSet()
            previous = totals.copy()

            while True:
                try:
                    stats_block = stats_queue.get_nowait()
                except Queue.Empty:
                    break

                for key, value in stats_block.iteritems():
                    totals[key] = totals.get(key, 0) + value

            diff = dict(
                (key, value - previous.get(key, 0))
                for key, value in totals.iteritems()
            )

            if diff:
                print('resp %5d/s (%5d/s no tid)    connects %5d|%5d    announce %5d|%5d    tconn:%d   tann:%d' % (
                    diff.get('responses', 0),
                    diff.get('responses_no_tid', 0),
                    diff.get('connects', 0),
                    diff.get('connects_lost', 0),
                    diff.get('announces', 0),
                    diff.get('announces_lost', 0),
                    totals.get('connects', 0),
                    totals.get('announces', 0)
                ))

            if stopping:
                break

            time.sleep(1)

    stats_fin = threading.Event()
    stats_thr = threading.Thread(target=stats_printer, args=(stats_fin, ))
    stats_thr.daemon = True
    stats_thr.start()

    if args.child:
        # child() requires the stats queue as its last argument.
        child(args.child, args.host, args.port, peers_per_proc, resid, stats_queue)
    else:
        procs = [
            multiprocessing.Process(
                target=child,
                args=(str(i), args.host, args.port, peers_per_proc, resid, stats_queue)
            )
            for i in range(args.procs)
        ]

        for proc in procs:
            proc.start()

        try:
            for proc in procs:
                proc.join()
        except KeyboardInterrupt:
            for proc in procs:
                proc.terminate()
            for proc in procs:
                proc.join()

    # Let the printer thread emit its final summary in either mode.
    stats_fin.set()
    stats_thr.join()


# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()
