#!/usr/bin/env python

"""
Detect and remove extra (duplicated) data across all database shards.
The script lists the document identifiers stored on every database shard and searches for duplicates.
If duplicates are found, it can (with --purge) remove the redundant copies, leaving one copy of each document.
"""

import argparse
import threading

import pymongo


# Name of the database examined on every shard.
DATABASE = 'sandbox'

# Storage hosts, in order. Each host runs the same number of shards on
# consecutive ports starting at _BASE_PORT + 1.
_SHARD_HOSTS = ('sandbox-storage8', 'sandbox-storage9')
_SHARDS_PER_HOST = 10
_BASE_PORT = 37000

# (shard name, "host:port") pairs in a fixed order:
# sandbox1..sandbox10 on sandbox-storage8:37001..37010, then
# sandbox11..sandbox20 on sandbox-storage9:37001..37010.
# `range` (not `xrange`) keeps the module importable on Python 2 and 3.
SHARDS = [
    ('sandbox{}'.format(h * _SHARDS_PER_HOST + i), '{}:{}'.format(host, _BASE_PORT + i))
    for h, host in enumerate(_SHARD_HOSTS)
    for i in range(1, _SHARDS_PER_HOST + 1)
]


def fetcher(s, coll, ids):
    """Read every ``_id`` from ``coll`` and publish them as ``ids[s]``.

    Used as a thread target: the result is handed back by storing a set of
    ``_id`` values in the shared ``ids`` dict under the shard name ``s``.

    Args:
        s: shard name, used for the progress message and as the result key.
        coll: collection handle; only ``find(filter, projection)`` is called.
        ids: dict that receives the result.
    """
    print("Fetching data for shard '{}'...".format(s))
    cursor = coll.find({}, {'_id': True})  # project only the _id field
    shard_ids = {doc['_id'] for doc in cursor}
    ids[s] = shard_ids


def _fetch_all_ids(colls):
    """Collect the set of ``_id`` values of every shard, one thread per shard.

    Args:
        colls: dict mapping each shard name in SHARDS to its collection handle.

    Returns:
        dict mapping shard name to the set of ``_id`` values found there.

    Raises:
        RuntimeError: if some thread did not store a result (the thread's own
            traceback is printed to stderr by the threading module).
    """
    ids = {}
    threads = [threading.Thread(target=fetcher, args=(s, colls[s], ids)) for s, _ in SHARDS]
    # Explicit loops rather than map(): on Python 3 map() is lazy, so
    # map(Thread.start, threads) would silently start nothing.
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    missing = [s for s, _ in SHARDS if s not in ids]
    if missing:
        raise RuntimeError("Failed to fetch IDs from shard(s): {}".format(', '.join(missing)))
    return ids


def _find_extra(ids, order):
    """Return, per shard, the IDs that are redundant because a later shard has them.

    An ``_id`` present on several shards is kept on the *last* of them in
    ``order`` and reported as extra on every earlier one, so deleting all
    extras leaves exactly one copy of each document.

    Args:
        ids: dict mapping shard name to its set of ``_id`` values.
        order: shard names in processing order.

    Returns:
        dict mapping shard name to the set of its extra ``_id`` values
        (always a subset of ``ids[name]``).
    """
    extra = {}
    later = set()  # union of the IDs of all shards after the current one
    for s in reversed(order):
        extra[s] = ids[s] & later
        later |= ids[s]
    return extra


def _purge(coll, extra, chunk_size=100000):
    """Delete from ``coll`` every document whose ``_id`` is in ``extra``.

    Deletes are issued as ``$in`` queries of at most ``chunk_size`` IDs each,
    keeping every query bounded in size, and progress is printed per chunk.
    """
    # Collection.remove() was deprecated in pymongo 3.0 and removed in 4.0;
    # use delete_many() when the driver has it. The check is done on the class
    # because Collection.__getattr__ turns unknown names into sub-collections.
    if hasattr(type(coll), 'delete_many'):
        delete = coll.delete_many
    else:
        delete = coll.remove

    pending = list(extra)
    total = len(pending)
    # Index slicing instead of repeatedly deleting the list head (O(n) each).
    for start in range(0, total, chunk_size):
        chunk = pending[start:start + chunk_size]
        delete({'_id': {'$in': chunk}})
        print("Dropped a chunk of {} elements, {} left.".format(len(chunk), total - start - len(chunk)))


def main(args):
    """Report, and optionally purge, documents duplicated across shards.

    For collection ``args.collection`` of ``DATABASE`` on every shard in
    ``SHARDS``: fetch all ``_id`` values, print per-shard and overall counts,
    then for each shard print how many of its IDs also exist on a later shard
    ("extra") and how many it keeps ("unique"). With ``args.purge`` the extra
    documents are deleted; which copy survives is decided purely by SHARDS
    order (the last shard holding an ID keeps it).

    Args:
        args: parsed command line; reads ``args.collection`` (collection name)
            and ``args.purge`` (whether to delete extras).

    Raises:
        RuntimeError: if the IDs of some shard could not be fetched.
    """
    print("Initializing connections to database '{}'...".format(DATABASE))
    clients = []
    try:
        conn = {}
        for s, a in SHARDS:
            h, p = a.split(':')
            client = pymongo.MongoClient(h, int(p))
            clients.append(client)  # remembered so it is closed in `finally`
            conn[s] = client[DATABASE]

        print("Fetching data for collection '{}'...".format(args.collection))
        ids = _fetch_all_ids({s: db[args.collection] for s, db in conn.items()})
        names = [s for s, _ in SHARDS]

        for s in names:
            print("Shard '{}' IDs: {}".format(s, len(ids[s])))

        print("Collecting totals...")
        all_cnt = len(set().union(*(ids[s] for s in names)))
        all_sum = sum(len(ids[s]) for s in names)
        print("All shards unique count: {}, total sum: {}, diff: {}".format(all_cnt, all_sum, all_sum - all_cnt))

        print("Checking intersections...")
        extra = _find_extra(ids, names)
        for s in names:
            print("Intersect with shard {} - extra: {}, unique: {}".format(
                s, len(extra[s]), len(ids[s]) - len(extra[s])))

            if extra[s] and args.purge:
                print("Dropping extra records")
                _purge(conn[s][args.collection], extra[s])
    finally:
        for client in clients:
            client.close()

if __name__ == '__main__':
    def _wide_help_formatter(*fmt_args, **fmt_kwargs):
        # Help formatter that shows argument defaults, wrapped at 120 columns.
        return argparse.ArgumentDefaultsHelpFormatter(*fmt_args, width=120, **fmt_kwargs)

    cli = argparse.ArgumentParser(
        description="Database shards extra data detection and removal tool.",
        formatter_class=_wide_help_formatter,
    )
    # Without --purge the tool only reports; with it, extras are deleted.
    cli.add_argument("--purge", action="store_true", default=False, help="purge detected extra records")
    cli.add_argument("collection", metavar="COLLECTION", help="collection name to be examined")
    parsed = cli.parse_args()
    main(parsed)
