#!/usr/bin/env python

"""
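# show each shard's profiling level and number of profile entries
./scripts/dev/mongo_profiler.py mongodb://localhost:22222/sandbox status
# drop old profile data, then profile operations slower than 90 s (slowms=90000)
./scripts/dev/mongo_profiler.py mongodb://localhost:22222/sandbox enable -k 90000
# show the 2 slowest profiled operations from each shard (slowest printed last)
./scripts/dev/mongo_profiler.py mongodb://localhost:22222/sandbox tail -k 2
# print the total number of profile entries across all shards
./scripts/dev/mongo_profiler.py mongodb://localhost:22222/sandbox count
# disable profiling and drop the system.profile collections
./scripts/dev/mongo_profiler.py mongodb://localhost:22222/sandbox disable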
"""


import argparse

from pymongo import MongoClient


# Database whose profiler is managed when --db_name is not passed on the command line.
DEFAULT_DB_NAME = 'sandbox'


def _fetch_clients(mongo_path):
    """Return one MongoClient per shard of the cluster reachable at *mongo_path*.

    The shard list is read from the ``config.shards`` collection through the
    connection given by *mongo_path* (expected to be a mongos router).  That
    connection is closed once the list has been read; a fresh client is then
    opened for each shard's ``host`` value.

    :param mongo_path: MongoDB URI or host string accepted by ``MongoClient``.
    :returns: list of ``MongoClient`` objects, one per shard document.
    """
    router = MongoClient(mongo_path)
    try:
        # fully consume the cursor before the connection is closed
        shard_hosts = [shard['host'] for shard in router['config']['shards'].find()]
    finally:
        # the router connection is only needed for the shard lookup
        router.close()
    return [_shard_client(host) for host in shard_hosts]


def _shard_client(host):
    """Build a MongoClient for one ``config.shards`` ``host`` string.

    Replica-set shards are recorded as ``"<rsName>/<h1:port>,<h2:port>"``.
    The ``rsName/`` prefix is not part of any hostname, so it is split off
    and passed as the ``replicaset`` option.  Any other value (e.g. a
    standalone ``"h:port"``) is handed to ``MongoClient`` unchanged.
    """
    rs_name, sep, seeds = host.partition('/')
    if sep and rs_name:
        return MongoClient(seeds, replicaset=rs_name)
    return MongoClient(host)


def enable_profiling(clients, db_name, tt):
    """Set profiling level 1 on database *db_name* for each client.

    :param clients: iterable of MongoClient instances (one per shard when
        produced by ``_fetch_clients``).
    :param db_name: name of the database whose profiler is configured.
    :param tt: slow-operation threshold in milliseconds, forwarded as the
        ``slow_ms`` argument of ``set_profiling_level``.

    After each change the database and its resulting profiling level are printed.
    """
    for database in (client[db_name] for client in clients):
        database.set_profiling_level(1, tt)
        level = database.profiling_level()
        print('%s profiling level is %s' % (database, level))


def disable_profiling(clients, db_name):
    """Switch profiling off on *db_name* for each client and drop its profile data.

    For every client, this:
      1. sets profiling level 0;
      2. prints how many documents ``system.profile`` holds;
      3. drops ``system.profile``;
      4. prints the resulting profiling level.
    """
    for client in clients:
        database = client[db_name]
        # profiling is turned off first: MongoDB will not drop
        # system.profile while the profiler is active
        database.set_profiling_level(0)
        profile = database.system.profile
        entries = profile.find().count()
        print('%s count is %d' % (profile, entries))
        profile.drop()
        print('%s profiling level is %s' % (database, database.profiling_level()))


def profiling_status(clients, db_name):
    """Describe the profiler state of *db_name* on each client.

    :returns: list with one ``(client.nodes, 'profiling level <lvl>',
        'items count <n>')`` tuple per client, in input order.  ``<n>`` is
        the number of documents in that database's ``system.profile``.
    """
    def describe(client):
        database = client[db_name]
        level = database.profiling_level()
        total = database.system.profile.find().count()
        return (client.nodes, 'profiling level %s' % level, 'items count %d' % total)

    return [describe(client) for client in clients]


def tail_profiling_by_ms(clients, db_name, limit):
    """Gather the slowest ``system.profile`` entries of *db_name* from every client.

    From each client, the profile documents are read in descending ``millis``
    order, capped by ``cursor.limit(limit)``.  Per pymongo, a limit of 0 means
    no limit.  Each document is paired with ``client.nodes``.  The combined
    list is returned sorted by ``millis`` ascending, so the slowest operation
    overall comes last.

    :returns: list of ``(nodes, profile_document)`` tuples.
    """
    collected = []
    for client in clients:
        slowest_first = client[db_name].system.profile.find().sort('millis', -1).limit(limit)
        collected += [(client.nodes, doc) for doc in slowest_first]
    return sorted(collected, key=lambda pair: pair[1]['millis'])


def count_profiling(clients, db_name):
    """Return the total number of ``system.profile`` documents of *db_name* over all clients."""
    return sum(client[db_name].system.profile.find().count() for client in clients)


def main(argv=None):
    """Parse the command line and run the requested profiling action.

    Modes (case-insensitive):
      enable  -- disable profiling and drop existing profile data, then enable
                 level 1 with a threshold of ``-k`` ms (default 100)
      disable -- turn profiling off and drop ``system.profile``
      status  -- print each shard's profiling level and entry count
      tail    -- print up to ``-k`` slowest entries per shard (default 1)
      count   -- print the total number of profile entries across shards

    An unrecognised mode prints ``unknown mode`` without connecting anywhere.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    import pprint

    parser = argparse.ArgumentParser(
        description="Enable, disable and inspect MongoDB profiling on every "
                    "shard of a sharded cluster")
    parser.add_argument('mongo_path')
    parser.add_argument('mode', help='enable | disable | status | tail | count')
    parser.add_argument('-k', '--key', default=None,
                        help='enable: slow-op threshold in ms (default 100); '
                             'tail: entries per shard (default 1)')
    parser.add_argument('--db_name', default=DEFAULT_DB_NAME)
    args = parser.parse_args(argv)

    print('"%s" for "%s"' % (args.mongo_path, args.db_name))

    mode = args.mode.lower()
    if mode not in ('enable', 'disable', 'status', 'tail', 'count'):
        print('unknown mode')
        return

    # every known mode operates on the same per-shard client list
    clients = _fetch_clients(args.mongo_path)
    print('amount %d clients' % len(clients))

    if mode == 'enable':
        # args.key is still the raw string here, so '-k 0' means 0 ms;
        # only an omitted or empty -k falls back to the default
        slow_ms = int(args.key) if args.key else 100
        disable_profiling(clients, args.db_name)
        print('OK')
        enable_profiling(clients, args.db_name, slow_ms)
        print('OK')
    elif mode == 'disable':
        disable_profiling(clients, args.db_name)
        print('OK')
    elif mode == 'status':
        status = profiling_status(clients, args.db_name)
        if status:
            pprint.pprint(status)
    elif mode == 'tail':
        limit = int(args.key) if args.key else 1
        items = tail_profiling_by_ms(clients, args.db_name, limit)
        if items:
            pprint.pprint(items)
        else:
            print('empty')
    else:  # count
        count = count_profiling(clients, args.db_name)
        print('total %d items in system.profile collections on shards' % count)


if __name__ == '__main__':
    main()
