#!/usr/bin/env pypy3

import argparse
from collections import defaultdict, namedtuple
import datetime
import itertools
import msgpack
import socket
import sys
import time

from scapy.all import *

# Reverse-DNS cache: address string -> host name (or the address itself when
# the lookup failed), so each address is resolved at most once.
gethostbyaddr_cache = {}
def gethostbyaddr(ip):
    """Return the host name for *ip*, consulting the cache first.

    ``dict.setdefault(ip, lookup(ip))`` would evaluate the lookup before the
    cache is consulted, so the resolver is only called here on a cache miss.

    Args:
        ip: address string to reverse-resolve.

    Returns:
        The primary host name reported by ``socket.gethostbyaddr``; if the
        lookup fails, *ip* itself (the failure is cached as well so a
        non-resolvable address is not retried on every packet).
    """
    try:
        return gethostbyaddr_cache[ip]
    except KeyError:
        pass

    try:
        name = socket.gethostbyaddr(ip)[0]
    except OSError:
        # socket.herror / socket.gaierror are OSError subclasses; a missing
        # PTR record should not abort the caller.
        name = ip

    gethostbyaddr_cache[ip] = name
    return name

# Optional hash to single out. Requests whose hash equals this value are
# additionally tallied per shard and per source address in process_packet().
# Hashes are compared as bytes, hence the b'' literal in the example below.
interesting_hash = None
# interesting_hash = b'1fd9b35a7413e6ebed9d3ac71f97261c58a7add3'

class RequestCnt:
    """Bundle of per-category request counters.

    Every counter starts at zero. Two bundles can be added, which yields a
    new bundle holding the field-wise sums and leaves both operands untouched.
    """

    # Counter names, in the order they appear in the string form.
    _FIELDS = ('total', 'announce', 'leech', 'seed', 'stop', 'info', 'invalid')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, 0)

    def __add__(self, other):
        """Return a fresh RequestCnt with each counter summed."""
        combined = RequestCnt()
        for field in self._FIELDS:
            setattr(combined, field, getattr(self, field) + getattr(other, field))
        return combined

    def __str__(self):
        """Render as ``RequestCnt(total=..., announce=..., ...)``."""
        rendered = ', '.join(
            '{}={}'.format(field, getattr(self, field)) for field in self._FIELDS
        )
        return 'RequestCnt({})'.format(rendered)

# Connection-id constants. Of the three, only initial_cid3 is referenced by
# the code in this file: parse_packet() uses it to pick the announce layout.
initial_cid = 1736820976
initial_cid2 = 1964596860
initial_cid3 = 4079332072
# Action codes, compared against the second field of a decoded packet.
act_connect = 0
act_announce = 10

# Per-hash request counters, keyed by the hash returned from parse_packet().
hash_cnt = defaultdict(RequestCnt)
# Per-source-address counters; only filled for requests matching interesting_hash.
src_ip_cnt = defaultdict(RequestCnt)
# Number of requests seen per shard, over the whole run.
shard_use_cnt = defaultdict(int)
# Same, restricted to requests matching interesting_hash.
shard_interesting_use_cnt = defaultdict(int)
# "Latest" variants are cleared after every progress report in process_packet().
shard_use_cnt_latest = defaultdict(int)
shard_interesting_use_cnt_latest = defaultdict(int)
# Set of (hash, source address) pairs for which a leech request was seen.
unique_leeches = set()
# Packets handed to parse_packet(), and those it could not classify.
n_packets = 0
n_unknown_packets = 0

# Wall-clock time at which main() started; used for the packets/sec figure.
start_time = None
# True when main() was started without a dump file (live sniffing).
interactive_mode = False
# When stdout is not a terminal, output() echoes stdout-bound lines to stderr too.
duplicate_stdout_to_stderr = not sys.stdout.isatty()

def perc(a, b):
    """Return *a* as a whole-number percentage of *b*, rounded down.

    Args:
        a: the part.
        b: the whole.

    Returns:
        ``(100 * a) // b``, or 0 when *b* is zero (e.g. an empty dump with no
        packets) instead of raising ``ZeroDivisionError``.
    """
    if b == 0:
        return 0
    return (100 * a) // b

def output(*args, **kwargs):
    """print() wrapper that can mirror stdout-bound lines to stderr.

    The arguments are forwarded to print() unchanged. If the caller gave no
    explicit ``file`` and duplicate_stdout_to_stderr is set, the same line is
    printed a second time to sys.stderr.
    """
    print(*args, **kwargs)

    has_explicit_target = kwargs.get('file') is not None
    if has_explicit_target or not duplicate_stdout_to_stderr:
        return

    print(*args, **dict(kwargs, file=sys.stderr))

def output_shard_dict(d, interesting_d=None, file=None):
    """Print one ``shard K: N`` line per entry of *d*, ordered by shard key.

    Args:
        d: mapping of shard key -> request count.
        interesting_d: optional mapping of shard key -> count of requests for
            interesting_hash; when given, each line also shows that count and
            its share of the shard's total.
        file: stream forwarded to output(); None keeps output()'s default
            behaviour.
    """
    def shard_order(item):
        # Integer shard numbers first (numerically), then non-integer keys
        # such as 'all'; ints and strings cannot be compared directly.
        key = item[0]
        return (not isinstance(key, int), key)

    for k, v in sorted(d.items(), key=shard_order):
        if interesting_d is not None:
            hits = interesting_d.get(k, 0)
            share = perc(hits, v) if v else 0
            detail = ' ({}, {}% by {})'.format(hits, share, interesting_hash)
        else:
            detail = ''
        # `file` belongs to output(), not to str.format(), which would
        # silently ignore the unused keyword argument.
        output('shard {}: {}{}'.format(k, v, detail), file=file)

def _count_announce_state(cnt, state):
    """Record one announce in *cnt*, attributed to leech/seed/stop by *state*."""
    cnt.total = 1
    if state == 1:
        cnt.leech = 1
    elif state == 2:
        cnt.seed = 1
    elif state == 3:
        cnt.stop = 1


def parse_packet(data):
    """Classify one decoded packet and derive its shard and counters.

    The payload is decoded by the caller with ``raw=True``, so string fields
    arrive as bytes; markers are therefore matched in their bytes form (the
    str form is accepted too, in case the payload was decoded differently).

    Args:
        data: decoded packet, indexable; ``data[0]`` is the connection id and
            ``data[1]`` the action.

    Returns:
        Tuple ``(shardnum, res_hash, cur_req_cnt, action)``:
        shardnum is ``hash % 8``, ``'all'`` for a CLEAN request, or None when
        no shard could be determined; res_hash is the hex hash as found in
        the packet (None for unknown packets); cur_req_cnt is a RequestCnt
        describing this single packet.

    Raises:
        ValueError: if a hash field is not valid hexadecimal.

    Side effects:
        Increments n_packets, and n_unknown_packets for unrecognised packets.
    """
    global n_packets, n_unknown_packets
    n_packets += 1

    shardnum = None
    res_hash = None
    cur_req_cnt = RequestCnt()

    cid, action = data[0], data[1]
    if cid == initial_cid3 and action == act_announce:
        res_hash = data[4]
        state = data[5]
        if res_hash in (b'CLEAN', 'CLEAN'):
            shardnum = 'all'
        elif len(res_hash) == 40:
            _count_announce_state(cur_req_cnt, state)
            shardnum = int(res_hash, 16) % 8
        else:
            # Neither a 40-character hex hash nor the CLEAN marker: there is
            # no shard to derive, so flag the request instead of crashing.
            cur_req_cnt.invalid = 1
    elif cid in (b'IN', 'IN') and action in (b'FO', 'FO'):
        # The hash is carried as raw bytes here; hex-encode it (as bytes) so
        # it has the same form as the hashes taken from announce packets.
        raw_hash = data[2]
        if isinstance(raw_hash, str):
            raw_hash = raw_hash.encode('latin-1')
        res_hash = raw_hash.hex().encode('ascii')
        cur_req_cnt.total += 1
        cur_req_cnt.info += 1
        shardnum = int(res_hash, 16) % 8
    elif action == act_announce:
        # Announce with a different connection id uses another field layout.
        state = data[3]
        res_hash = data[4]
        _count_announce_state(cur_req_cnt, state)
        shardnum = int(res_hash, 16) % 8
    else:
        n_unknown_packets += 1

    return shardnum, res_hash, cur_req_cnt, action

def process_packet(packet):
    """Fold one sniffed packet into the global statistics.

    Packets that are not UDP to destination port 2399 are ignored. Otherwise
    the UDP payload is decoded with msgpack, classified by parse_packet() and
    added to the per-hash, per-shard and per-source counters. Whenever the
    number of parsed packets is a non-zero multiple of 50000, a progress
    report is written to stderr and the "latest" shard counters are cleared.
    In interactive mode every announce/connect is also echoed via output().

    Args:
        packet: scapy packet, as passed by ``sniff(prn=...)``.
    """
    if UDP not in packet or packet[UDP].dport != 2399:
        return

    if n_packets % 50000 == 0 and n_packets:
        # `import datetime` binds the module; whether the later star-import
        # rebinds the name to the class is not guaranteed, so import the
        # class explicitly under a private name.
        from datetime import datetime as _datetime
        t = time.time()
        output("Processed {} packets, current pkt timestamp {}, took {:.2f} seconds, {:.2f} pkts/sec".format(
               n_packets,
               # packet.time may be a Decimal-like value; fromtimestamp wants a float.
               _datetime.fromtimestamp(float(packet.time)).strftime('%Y-%m-%d %H:%M:%S.%f'),
               t - start_time, n_packets / (t - start_time)),
               file=sys.stderr)
        output('shard usage (latest):', file=sys.stderr)
        output_shard_dict(shard_use_cnt_latest, shard_interesting_use_cnt_latest, file=sys.stderr)
        shard_use_cnt_latest.clear()
        shard_interesting_use_cnt_latest.clear()

    # raw=True: msgpack strings are returned as bytes, not str.
    data = msgpack.loads(bytes(packet.getlayer(UDP).payload), raw=True)
    shardnum, res_hash, cur_req_cnt, action = parse_packet(data)

    # The CLEAN marker is not a real hash; it arrives as bytes because of
    # raw=True (the str form is accepted as well for safety).
    if res_hash is not None and res_hash not in (b'CLEAN', 'CLEAN'):
        hash_cnt[res_hash] += cur_req_cnt

    if interactive_mode:
        if action == act_announce:
            from datetime import datetime as _datetime
            assert cur_req_cnt.leech + cur_req_cnt.seed + cur_req_cnt.stop <= 1
            mode = None
            if cur_req_cnt.leech:
                mode = 'leech'
            elif cur_req_cnt.seed:
                mode = 'seed'
            elif cur_req_cnt.stop:
                mode = 'stop'
            # NOTE(review): assumes the traffic is IPv6; an IPv4 packet has no
            # IPv6 layer and this lookup would raise -- confirm.
            dst = packet[IPv6].dst
            output('{}: announce {}, hash {}, dst {} ({})'.format(
                _datetime.now(),
                mode,
                res_hash[:8].decode() + '...',
                gethostbyaddr(dst),
                dst
            ))
        elif action == act_connect:
            output('connect')

    if shardnum is not None:
        shard_use_cnt[shardnum] += 1
        shard_use_cnt_latest[shardnum] += 1
        if cur_req_cnt.leech != 0:
            unique_leeches.add((res_hash, packet[IPv6].src))
        if res_hash == interesting_hash:
            shard_interesting_use_cnt[shardnum] += 1
            shard_interesting_use_cnt_latest[shardnum] += 1
            src_ip_cnt[packet[IPv6].src] += cur_req_cnt

def main():
    """Parse the command line, then sniff eth0 live or analyze a dump file.

    Without a file argument the script listens on eth0 and process_packet()
    prints announces as they arrive. With a file, the dump is replayed through
    process_packet() and a summary is printed: shard usage, the 50 busiest
    hashes, the 50 busiest source addresses and overall totals.
    """
    global interactive_mode, start_time

    parser = argparse.ArgumentParser(description='''
Analyze skybone-coord traffic dumps. Dumps can be collected from tracker via
`tcpdump -nK -w dump.pcap -i eth1 udp and dst port 2399`.
If no argument is given, listen on eth0 and display announces in real-time.
''')
    parser.add_argument('file', nargs='?', help='traffic dump file')
    dump_path = parser.parse_args().file

    interactive_mode = dump_path is None
    start_time = time.time()

    if interactive_mode:
        print("Listening on eth0", file=sys.stderr)
        # No capture filter is passed here:
        # sniff filter doesn't work in some PyPy versions
        sniff(iface='eth0', prn=process_packet, store=False)
        return

    output('parsing file', dump_path)
    sniff(offline=dump_path, prn=process_packet, store=False)

    per_hash = list(hash_cnt.values())
    n_ann_total = sum(cnt.total for cnt in per_hash)
    n_ann_seed = sum(cnt.seed for cnt in per_hash)
    n_ann_leech = sum(cnt.leech for cnt in per_hash)
    n_ann_stop = sum(cnt.stop for cnt in per_hash)

    def busiest_first(entry):
        # Negated total gives a descending order while keeping ties stable.
        return -entry[1].total

    output_shard_dict(shard_use_cnt, shard_interesting_use_cnt)

    top_hashes = itertools.islice(sorted(hash_cnt.items(), key=busiest_first), 50)
    for res_hash, cnt in top_hashes:
        share = perc(cnt.total, n_ann_total)
        output('hash {}: {}, {}% of total announces'.format(res_hash, cnt, share))

    top_sources = itertools.islice(sorted(src_ip_cnt.items(), key=busiest_first), 50)
    for src, cnt in top_sources:
        output('src {}: {}'.format(src, cnt))

    n_ips = len(src_ip_cnt)
    if src_ip_cnt:
        requests_from_ips = sum(cnt.total for cnt in src_ip_cnt.values())
        output('Avg requests from single ip:', requests_from_ips / n_ips)
    output('Total ips:', n_ips)

    n_unique_leeches = len(unique_leeches)
    summary = 'Total packets: {}, total announces {} ({}%), seed announces {} ({}%), leech announces {} ({}%), stop announces {} ({}%), {} ({}%) unique leeches, unknown packets: {} ({}%)'.format(
        n_packets,
        n_ann_total, perc(n_ann_total, n_packets),
        n_ann_seed, perc(n_ann_seed, n_packets),
        n_ann_leech, perc(n_ann_leech, n_packets),
        n_ann_stop, perc(n_ann_stop, n_packets),
        n_unique_leeches, perc(n_unique_leeches, n_packets),
        n_unknown_packets, perc(n_unknown_packets, n_packets)
    )
    output(summary)

# Run the analyzer only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
