#!/usr/bin/python
# -*- coding: utf-8 -*-
"""

Скрипт для выполнения команд на всех шардах

    Может принимать команду как через pipe:
    echo 'db.user_data.count()' | mpfs-sharded-db-executer.py

    так и через параметр -c:
    mpfs-sharded-db-executer.py -c 'db.user_data.count()'
    108
    1382

    Умеет собирать результаты со всех шардов в одно значение (ключ -a):
    mpfs-sharded-db-executer.py -c 'db.user_data.count()' -a
    1490

"""

import sys
import gevent

from optparse import OptionParser, Option

import mpfs.engine.process
# Set up the MPFS process environment before the remaining mpfs imports below.
# NOTE(review): presumably the metastorage import depends on this setup having
# run first — confirm before reordering these lines.
mpfs.engine.process.setup_anyone_script()

from mpfs.metastorage.mongo.cursor import MPFSCommonCursor

usage = "usage: %prog -h"

option_list = (
    Option(
        '-c', '--command',
        action='store',
        dest='command',
        type='string',
        help='command to execute'
    ),
    Option(
        '-d', '--database',
        action='store',
        dest='db_name',
        type='string',
        help='database name',
        default=None,
    ),
    Option(
        '-a', '--aggregate',
        action='store_true',
        dest='aggregate',
        help='aggregate results from all shards to one value',
        default=False,
    ),

)

parser = OptionParser(usage, option_list=option_list)
(options, args) = parser.parse_args()


def aggregate_results(results):
    """Merge the per-shard results into a single value.

    * every item is a real number (e.g. ``count()`` results) -> their sum;
    * every item is a list (e.g. materialized cursors) -> one concatenated list;
    * anything else -> ``results`` returned unchanged. "Anything else" covers
      an empty input and a mix of types, such as a number next to an
      ``"Error executing command: ..."`` string. Per-shard errors stay visible
      instead of crashing the aggregation or being spliced in char by char.

    :param results: list with one result per shard.
    :returns: the aggregated value, or ``results`` itself when it cannot be
        aggregated.
    """
    import numbers

    if not results:
        # No shards / no results: nothing to merge (avoids IndexError).
        return results
    # Check every item, not only the first one: a failed shard yields an
    # error string that must not be added to numbers or extended into lists.
    if all(isinstance(r, numbers.Real) for r in results):
        # numbers.Real also covers Python 2 ``long`` and floats, which a bare
        # ``int`` check would leave un-aggregated.
        return sum(results)
    if all(isinstance(r, list) for r in results):
        return [item for r in results for item in r]
    return results


def run(command, db_name, aggregate):
    """Evaluate ``command`` on every non-'sys' shard and print the results.

    Each shard is queried in its own gevent greenlet. While ``command`` is
    evaluated, the local name ``db`` is bound to the ``db_name`` database of
    that shard's connection (``conn``), e.g. ``db.user_data.count()``.

    NOTE(review): ``command`` is passed to ``eval`` as-is. It is meant to be
    supplied by the operator running this script; never feed it untrusted
    input.

    :param command: Python expression to evaluate on each shard.
    :param db_name: database to bind as ``db``. When ``None``, it is taken
        from the second dot-separated token of ``command``
        (``db.<name>...``).
    :param aggregate: if true, merge the per-shard results with
        ``aggregate_results`` and print a single value; otherwise print one
        line per shard.
    :raises ValueError: if ``db_name`` is ``None`` and ``command`` contains
        no ``.`` to infer it from.
    """
    if db_name is None:
        parts = command.split('.')
        if len(parts) < 2:
            raise ValueError(
                "Cannot infer database name from command %r; "
                "pass it explicitly with -d" % command)
        db_name = parts[1]

    dbctl = mpfs.engine.process.dbctl()
    # Skip shards whose id contains 'sys'.
    shards = [info.get('_id')
              for info in dbctl.mapper.rspool.get_all_shards_info()
              if 'sys' not in info.get('_id')]

    def get_data_from_shard(shard_id, command, db_name):
        # ``conn`` and ``db`` must keep these names: ``eval`` resolves the
        # names used in ``command`` against this function's locals.
        try:
            conn = dbctl.mapper.rspool.get_connection_for_rs_name(shard_id)
            db = conn[db_name]
            result = eval(command)
            if isinstance(result, MPFSCommonCursor):
                # Materialize the cursor so it can be printed/aggregated.
                result = list(result)
            return result
        except Exception as e:
            # Best effort: report this shard's failure as its result instead
            # of aborting the other shards.
            return "Error executing command: %s" % e

    threads = [gevent.spawn(get_data_from_shard, shard_id, command, db_name)
               for shard_id in shards]
    gevent.joinall(threads)
    results = [thread.value for thread in threads]
    if aggregate:
        print(aggregate_results(results))
    else:
        for r in results:
            print(r)

if __name__ == "__main__":
    # An explicit -c wins; stdin is consulted only when -c is absent and
    # something is piped in. The previous `if isatty()` form ignored -c
    # whenever stdin was not a terminal (cron, ssh without -t, CI, < /dev/null)
    # and fell through to printing help.
    command = options.command
    if not command and not sys.stdin.isatty():
        command = sys.stdin.read().strip()
    if command:
        run(command, options.db_name, options.aggregate)
    else:
        parser.print_help()
        sys.exit(0)
