#!/usr/bin/python

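"""
A tool to manage a sharded MongoDB installation: update replica set configuration,
check and enable collection sharding, restore shards from backup, step down instances
and create a local test cluster.
"""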

from __future__ import print_function
from __future__ import absolute_import

import os
import sys
import json
import time
import pprint
import logging
import argparse
import datetime as dt
import operator as op
import textwrap as tw
import itertools as it
import subprocess as sp
import distutils.util as dtutil
import distutils.spawn as dtspawn
from multiprocessing import pool as mp

try:
    from sandbox.scripts.mongo_tool import mongo_setup
except ImportError:
    import mongo_setup

try:
    from urllib import urlretrieve
    from urllib2 import urlopen, HTTPError
except ImportError:
    from urllib.request import urlopen, urlretrieve
    from urllib.error import HTTPError
    raw_input = input


# Template of a standalone Python script which applies replica set configurations
# generated by `reconfigure --script`. It is rendered with `str.format(data_lines=...)`,
# so every literal brace of the generated code must be doubled (`{{}}`); a single `{}`
# would be treated as a replacement field and rendering would fail with IndexError.
# The generated script calls plain `rs.reconfig(config)` (no `force` option).
SCRIPT_TMPL = tw.dedent("""
    #!/usr/bin/python

    import json
    import collections
    import subprocess as sp

    Line = collections.namedtuple("Line", ("name", "port", "config"))
    DATA = (
        {data_lines}
    )

    for l in DATA:
        print("Reconfiguring shard '{{}}'".format(l.name))
        sp.call([
            "mongo", "--port", str(l.port), "--quiet", "admin", "--eval",
            "printjson(rs.reconfig(" + json.dumps(l.config) + "))"
        ])
""")

# Bucket which holds MongoDB backups together with their metadata files.
S3_BACKUP_BUCKET_URL = "https://s3.mds.yandex.net/sandbox-backup"
# Name of a single backup; several backups of the same date differ by `iteration`.
MONGO_BACKUP_NAME_TEMPLATE = "mongo_backup-{date}-{iteration}"

# Sub-command names accepted on the command line.
RECONFIGURE_COMMAND = "reconfigure"
CHECK_COLLECTIONS_COMMAND = "check_collections"
COLLECTION_INFO_COMMAND = "collection_info"
ENABLE_SHARDING_COMMAND = "enable_sharding"
RESTORE_SHARDS_COMMAND = "restore_shards"
STEP_DOWN_COMMAND = "step_down"
CREATE_LOCAL_CLUSTER_COMMAND = "create_local_cluster"

# User-facing sharding strategy -> value of the shard key field in a mongo shell document.
SHARDING_STRATEGY = dict(
    ranged="1",
    hashed="'hashed'",
)


class MongoToolException(Exception):
    """Error of this tool; the script entry point logs its message instead of a traceback."""


class CollectionStat(object):
    """
    Statistics of a single collection in the sandbox database.

    The output of `stats()` is fetched once, in the constructor; indexes and the shard
    distribution are queried anew on every property access.
    """

    def __init__(self, setup, collection_name):
        self.setup = setup
        self.collection = collection_name
        self.stats = run_mongo_command(self._shell_ref() + ".stats()", setup.sandbox_db)

    def _shell_ref(self):
        """Mongo shell expression which refers to this collection."""
        return "db.getCollection('{}')".format(self.collection)

    @property
    def is_sharded(self):
        """True if `stats()` reports the collection as sharded."""
        sharded_flag = self.stats["sharded"]
        return sharded_flag == 1

    @property
    def size(self):
        """Documents size as reported by `stats()`."""
        stats = self.stats
        return stats["size"]

    @property
    def count(self):
        """Number of documents as reported by `stats()`."""
        stats = self.stats
        return stats["count"]

    @property
    def indexes(self):
        """Parsed `getIndexes()` output (the `ok` field is not checked)."""
        query = self._shell_ref() + ".getIndexes()"
        return run_mongo_command(query, self.setup.sandbox_db, check_errors=False)

    @property
    def shards_distribution(self):
        """Raw text printed by `getShardDistribution()`."""
        query = self._shell_ref() + ".getShardDistribution()"
        return run_mongo_command(query, self.setup.sandbox_db, parse_json=False)


def run_mongo_command(command, db="admin", port="22222", parse_json=True, check_errors=True, attempts=1,
                      retry_delay=10):
    # type: (str, str, str or int, bool, bool, int, float) -> dict or str
    """
    Evaluate an expression with the `mongo` shell on localhost and return its result.

    :param command: mongo shell (JavaScript) expression to evaluate
    :param db: database to connect to
    :param port: port of the local mongod/mongos instance
    :param parse_json: wrap the expression into `tostrictjson()` and return the parsed JSON;
        otherwise return the raw shell output as text
    :param check_errors: with `parse_json`, raise unless the result's `ok` field equals 1
    :param attempts: how many times to run the shell while it exits with a non-zero code
    :param retry_delay: seconds to wait between failed attempts
    :return: parsed JSON result, or the shell output text when `parse_json` is False
    :raises MongoToolException: if every attempt fails or the command reports an error
    """
    logging.debug("Executing mongo command '{}' on port {}, db={}".format(command, port, db))
    eval_expression = "tostrictjson({})".format(command) if parse_json else command
    shell_command = ["mongo", "--port", str(port), db, "--quiet", "--eval", eval_expression]
    output = None
    for attempt in range(1, attempts + 1):
        try:
            output = sp.check_output(shell_command)
            break
        except sp.CalledProcessError as exc:
            logging.debug("Mongo shell exited with code %s (attempt %d of %d)", exc.returncode, attempt, attempts)
            # No point in waiting after the last attempt: the error is raised right away.
            if attempt < attempts:
                time.sleep(retry_delay)
    else:
        raise MongoToolException("Failed to execute command on mongo(s,d), port {}".format(port))
    # Python 3 returns bytes here, Python 2 already returns `str`.
    if not isinstance(output, str):
        output = output.decode("utf-8")
    if not parse_json:
        return output
    parsed_result = json.loads(output)
    if not check_errors or parsed_result.get("ok") == 1:
        return parsed_result
    raise MongoToolException(parsed_result.get("errmsg", "Mongo command failed: {}".format(command)))


def human_sized(size):  # type: (int or float) -> str
    """
    Format a byte count with a binary (1024-based) unit prefix, e.g. 1536 -> "1.5Kb".

    Sizes of 1024 Zb and more are expressed in "Yb".
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    position = 0
    while position < len(prefixes) and not abs(size) < 1024.0:
        size /= 1024.0
        position += 1
    if position < len(prefixes):
        return "{:3.1f}{}b".format(size, prefixes[position])
    return "{:.1f}{}b".format(size, "Y")


def data_shards(setup):
    """
    Yield `(replica_set_name, port, shard_index)` for every data shard of the host line
    returned by `setup.find_data_shard_map_by_host()` (presumably the current host).

    Nothing is yielded when that lookup returns a falsy value.
    """
    host_line = setup.find_data_shard_map_by_host()
    if host_line:
        for index in host_line.shards_index_range:
            yield "sandbox{}".format(index), host_line.port_by_shard_index(index), index


def config_shards(setup):
    """
    Yield `("cs_repl_set", 27019, None)` if `setup.find_config_shard_map_by_host()` finds a
    config shard line (presumably for the current host); otherwise yield nothing.
    """
    if setup.find_config_shard_map_by_host():
        yield "cs_repl_set", 27019, None


def get_data_shards_configs(args):
    """
    Yield `(replica_set_name, port, rs_config)` for every data shard of this host.

    `rs_config` is a replica set configuration document for `rs.reconfig()`. Its members are
    the lines of `args.setup.data_shards_map` that hold the shard and are not in the
    `args.exclude` DC. Members with a falsy priority are hidden (priority 0), and members whose
    priority is None do not vote.
    """
    def member(map_line, shard_index):
        # Replica set member document for one line of the data shard map.
        priority = map_line.priority
        return {
            "_id": map_line.no,
            "host": "{}.search.yandex.net:{}".format(map_line.id, map_line.port_by_shard_index(shard_index)),
            "arbiterOnly": False,
            "buildIndexes": True,
            "hidden": not priority,
            "priority": priority or 0,
            "tags": {},
            "slaveDelay": 0,
            "votes": int(priority is not None),
        }

    for shard_name, port, shard_index in data_shards(args.setup):
        members = [
            member(map_line, shard_index)
            for map_line in args.setup.data_shards_map
            if map_line.dc != args.exclude and shard_index in map_line.shards_index_range
        ]
        rs_config = {
            "_id": shard_name,
            "version": 42,  # not actually used: mongodb replaces it on reconfig
            "protocolVersion": 1,
            "members": members,
        }
        yield shard_name, port, rs_config


def get_config_shards_configs(args):
    """
    Yield `(replica_set_name, port, rs_config)` for the config server replica set of this host.

    Members are the lines of `args.setup.config_shards_map` outside the `args.exclude` DC,
    addressed by their alias when it is set and by their id otherwise. Members with a falsy
    priority are hidden (priority 0), and members whose priority is None do not vote.
    """
    def member(map_line, port):
        # Replica set member document for one line of the config shard map.
        priority = map_line.priority
        return {
            "_id": map_line.no,
            "host": "{}.search.yandex.net:{}".format(map_line.alias or map_line.id, port),
            "arbiterOnly": False,
            "buildIndexes": True,
            "hidden": not priority,
            "priority": priority or 0,
            "tags": {},
            "slaveDelay": 0,
            "votes": int(priority is not None),
        }

    for shard_name, port, _ in config_shards(args.setup):
        members = [
            member(map_line, port)
            for map_line in args.setup.config_shards_map
            if map_line.dc != args.exclude
        ]
        rs_config = {
            "_id": shard_name,
            "version": 42,  # not actually used: mongodb replaces it on reconfig
            "protocolVersion": 1,
            "configsvr": True,
            "members": members,
        }
        yield shard_name, port, rs_config


def _reconfigure(args):
    """
    Command handler for `reconfigure`: apply, show or save replica set configurations of this host.

    * `--script`: write a standalone Python script (from SCRIPT_TMPL) which applies the
      configurations. That script calls plain `rs.reconfig()`, so `--force` is not applied there.
    * `--execute`: run `rs.reconfig()` through the mongo shell right away.
    * neither: only log the mongo shell commands which would be executed.
    """
    script_lines = []
    for name, port, config in it.chain(get_data_shards_configs(args), get_config_shards_configs(args)):
        if args.script:
            # `!r` quotes the replica set name, so the generated line is valid Python.
            script_lines.append("Line({!r}, {}, {}),".format(name, port, pprint.pformat(config)))
            continue
        logging.info("Reconfiguring shard '{}'".format(name))
        (sp.call if args.execute else logging.info)([
            "mongo", "--port", str(port), "--quiet", "admin",
            "--eval",
            "printjson(rs.reconfig({}, {{force: {}}}))".format(json.dumps(config), str(args.force).lower())
        ])
    if not args.script:
        return
    # Strip only the extension: "./mongo_tool.py" -> "./mongo_tool" (splitting on every dot
    # would yield "" for such a path).
    script_filename = "{}_{}_{}.py".format(
        os.path.splitext(sys.argv[0])[0], args.installation, args.setup.get_host_id()
    )
    with open(script_filename, "w") as fh:
        fh.write(SCRIPT_TMPL.format(data_lines="\n        ".join(script_lines)).lstrip())
    os.chmod(script_filename, 0o755)
    logging.info("Created script {!r}".format(script_filename))


def _check_sharded_sandbox_collections(args):
    """Command handler for `check_collections`; returns the not sharded collections."""
    setup = args.setup
    return check_sharded_sandbox_collections(setup)


def check_sharded_sandbox_collections(setup):
    """
    Log every collection of the sandbox database which is not sharded.

    Collections larger than the chunk size (`config.settings` document "chunksize", value in
    megabytes, 60 when not set) are logged as errors, and smaller ones as warnings.

    :return: list of `CollectionStat` for the not sharded collections, largest first
    """
    sandbox_db = setup.sandbox_db
    logging.info("Checking '%s' collections sharding setup...", sandbox_db)
    collection_names = run_mongo_command("db.getCollectionNames()", sandbox_db, check_errors=False)
    # The shell result is handled both as an object keyed by index and as a plain list.
    if isinstance(collection_names, dict):
        collection_names = list(collection_names.values())
    chunk_size_doc = run_mongo_command(
        "db.settings.findOne({'_id': 'chunksize'})", "config", check_errors=False
    ) or {}
    # The stored value may come back as a float (mongo shell numbers are doubles), and a
    # float can't be bit-shifted.
    chunk_size = int(chunk_size_doc.get("value", 60)) << 20
    not_sharded_collections = []
    for collection_name in collection_names:
        collection_stat = CollectionStat(setup, collection_name)
        if not collection_stat.is_sharded:
            not_sharded_collections.append(collection_stat)
    not_sharded_collections.sort(key=op.attrgetter("size"), reverse=True)
    for collection_stat in not_sharded_collections:
        logging.log(
            logging.ERROR if collection_stat.size > chunk_size else logging.WARNING,
            "Collection '%s' not sharded (size: %s) ", collection_stat.collection, human_sized(collection_stat.size)
        )
    if not not_sharded_collections:
        logging.info("All large collections in %s are sharded", sandbox_db)
    return not_sharded_collections


def _collection_info(args):
    """
    Command handler for `collection_info`: log size, document count, indexes and (for sharded
    collections) the shard distribution of one sandbox collection.
    """
    collection_names = run_mongo_command("db.getCollectionNames()", args.setup.sandbox_db, check_errors=False)
    # When the shell result is an object keyed by index, `in` would test the indices rather
    # than the names, so compare against the values.
    known_names = collection_names.values() if isinstance(collection_names, dict) else collection_names
    if args.collection not in known_names:
        logging.error("Collection '%s.%s' not exists", args.setup.sandbox_db, args.collection)
        return
    collection_stats = CollectionStat(args.setup, args.collection)
    logging.info("### Collection '%s' info ###", args.collection)
    logging.info("Documents size: %s", human_sized(collection_stats.size))
    logging.info("Documents count: %d", collection_stats.count)
    logging.info("Indexes:\n%s", json.dumps(collection_stats.indexes, indent=4))
    if collection_stats.is_sharded:
        logging.info("Shards distribution:\n%s", collection_stats.shards_distribution)
    else:
        logging.info("Collection is not sharded")


def _ask_confirmation(prompt):
    """Ask a yes/no question until `strtobool` understands the answer; return True for "yes"."""
    answer = raw_input(prompt)
    while True:
        try:
            return bool(dtutil.strtobool(answer))
        except ValueError:
            answer = raw_input("Wrong input, try again [(y)es/(n)o]: ")


def _enable_sharding(args):
    """
    Command handler for `enable_sharding`. After interactive confirmation, it ensures an index
    on the shard key and shards the sandbox collection by it.

    `args.strategy` must already be mapped through SHARDING_STRATEGY (done in `handle_args`).

    :return: 0 on success, when the collection is already sharded or the user declined; 1 on failure
    """
    if CollectionStat(args.setup, args.collection).is_sharded:
        print("Collection '{}' is already sharded".format(args.collection))
        return 0
    # Shard key document in mongo shell syntax, e.g. {'_id': 'hashed'}
    shard_key = "{{'{}': {}}}".format(args.key, args.strategy)
    print("You are going to enable sharding on collection '{}.{}' with shard key {}".format(
        args.setup.sandbox_db, args.collection, shard_key
    ))
    if not _ask_confirmation("This operation can't be undone. Continue [(y)es/(n)o]: "):
        return 0

    print("Ensuring index ({})... ".format(shard_key), end="")
    sys.stdout.flush()  # show progress before the possibly long mongo call
    try:
        run_mongo_command(
            "db.getCollection('{}').ensureIndex({})".format(args.collection, shard_key),
            args.setup.sandbox_db
        )
    except MongoToolException as exception:
        print("FAILED\n{}".format(str(exception)))
        return 1
    print("OK")

    print("Sharding collection {}... ".format(args.collection), end="")
    sys.stdout.flush()
    try:
        run_mongo_command(
            "sh.shardCollection('{}.{}', {})".format(args.setup.sandbox_db, args.collection, shard_key)
        )
    except MongoToolException as exception:
        print("FAILED\n{}".format(str(exception)))
        return 1
    print("OK")
    return 0


def get_target_shards(setup, shards_whitelist=None):
    """
    Map names of the shards hosted on this host (data shards and config shard) to their ports.

    :param setup: installation setup
    :param shards_whitelist: optional iterable of shard names to restrict the result to;
        None or empty means all shards of the host
    :return: dict shard name -> port
    :raises MongoToolException: if some whitelisted shards are not hosted on this host
    """
    wanted = set(shards_whitelist or ())
    target_shards = {
        shard_name: port
        for shard_name, port, _ in it.chain(data_shards(setup), config_shards(setup))
        if not wanted or shard_name in wanted
    }
    remaining_shards = wanted.difference(target_shards)
    if remaining_shards:
        raise MongoToolException(
            "Specified shards either not exists or not appropriate to this host: [{}]".format(
                ", ".join(sorted(remaining_shards))
            )
        )
    return target_shards


def find_latest_backup_iteration(date):
    """
    Return the highest backup iteration available in S3 for `date`.

    Metadata URLs are probed for iterations 0, 1, 2, ... until the bucket answers with an HTTP error.

    :raises MongoToolException: if not even iteration 0 exists
    """
    probe = 0
    try:
        while True:
            backup_name = MONGO_BACKUP_NAME_TEMPLATE.format(date=date, iteration=probe)
            urlopen("{0}/{1}/{1}".format(S3_BACKUP_BUCKET_URL, backup_name), timeout=10).close()
            probe += 1
    except HTTPError:
        pass
    latest = probe - 1
    if latest < 0:
        raise MongoToolException("Failed to find backup for date {}".format(date))
    return latest


def download_shard_from_metadata(shard_metadata):
    """
    Download one backup file described by an entry of the backup's S3 metadata.

    :param shard_metadata: entry with "key" (object key in the bucket), "path" (local
        destination) and "size" (in bytes, used for logging)
    """
    path = shard_metadata["path"]
    logging.info("Downloading file %s (size: %s) from S3", path, human_sized(shard_metadata["size"]))
    # Download under a temporary name. `download_latest_backups` skips files which already
    # exist, so an interrupted download must not leave a truncated file under the final name.
    partial_path = path + ".part"
    try:
        urlretrieve("{}/{}".format(S3_BACKUP_BUCKET_URL, shard_metadata["key"]), partial_path)
    except Exception:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.rename(partial_path, path)
    logging.info("File %s is ready", path)


def download_latest_backups(date, shards, target_dir=None):
    """
    Download the given shards from the latest S3 backup made on `date`.

    Files listed in the backup metadata are downloaded (4 in parallel) unless they already
    exist locally. The files go into a directory named after the backup, relative to the
    current working directory. That directory is then renamed to `target_dir` when given.

    :param date: backup date in YYYYMMDD format
    :param shards: shard names to download; membership is tested with `in`
        (`restore_shards` passes a dict keyed by shard name)
    :param target_dir: optional directory to move the downloaded backup to
    :return: directory with the downloaded backup
    :raises: the first error of any single file download
    """
    iteration = find_latest_backup_iteration(date)
    mongo_backup_name = MONGO_BACKUP_NAME_TEMPLATE.format(date=date, iteration=iteration)
    metadata_url = "{0}/{1}/{1}".format(S3_BACKUP_BUCKET_URL, mongo_backup_name)
    connection = urlopen(metadata_url, timeout=10)
    try:
        raw_metadata = connection.read()
    finally:
        connection.close()
    if not isinstance(raw_metadata, str):
        raw_metadata = raw_metadata.decode("utf-8")
    backup_s3_metadata = json.loads(raw_metadata)
    if not os.path.exists(mongo_backup_name):
        os.makedirs(mongo_backup_name)

    thread_pool = mp.ThreadPool(4)
    pending = []
    try:
        for shard_metadata in backup_s3_metadata:
            if os.path.splitext(os.path.basename(shard_metadata["path"]))[0] not in shards:
                continue
            if os.path.exists(shard_metadata["path"]):
                logging.info("File {} already exists, skip downloading".format(shard_metadata["path"]))
            else:
                pending.append(thread_pool.apply_async(download_shard_from_metadata, (shard_metadata,)))
    finally:
        thread_pool.close()
        thread_pool.join()
    # `apply_async` keeps worker exceptions until `get()` is called; re-raise them here
    # instead of reporting success for a failed download.
    for result in pending:
        result.get()
    logging.info("All required shard backups downloaded")

    if target_dir and os.path.abspath(target_dir) != os.path.abspath(mongo_backup_name):
        os.rename(os.path.abspath(mongo_backup_name), os.path.abspath(target_dir))
    else:
        target_dir = mongo_backup_name
    return target_dir


def get_uncompress_command():
    """
    Return the decompression command prefix, preferring `pixz` over `xz`.

    The caller appends the compressed file path and reads the decompressed stream from the
    process stdout.

    :raises MongoToolException: if neither `pixz` nor `xz` is installed
    """
    candidates = (
        ("pixz", ["pixz", "-d", "-p14", "-i"]),
        ("xz", ["xz", "--decompress", "-c"]),
    )
    for executable, command in candidates:
        if not dtspawn.find_executable(executable):
            continue
        logging.info("Using '{}' for backup decompression".format(executable))
        return command
    raise MongoToolException("No '.xz' decompressors found, 'xz' or 'pixz' required")


def _restore_shards(args):
    """Command handler for `restore_shards`; delegates to `restore_shards`."""
    return restore_shards(
        args.setup, args.date,
        directory=args.directory, namespaces=args.namespaces, shards=args.shards,
    )


def restore_shards(setup, date, directory=None, namespaces=None, shards=None):
    """
    Restore shards of this host from compressed `mongorestore` archives.

    :param setup: installation setup
    :param date: backup date (YYYYMMDD), used only when the backup has to be downloaded
    :param directory: directory with `<shard_name>.pixz` archives; the latest backup of `date`
        is downloaded when it is not given or does not exist
    :param namespaces: restore only these namespaces (`--nsInclude`); without them the whole
        archive is restored with `--oplogReplay`
    :param shards: restore only these shards (all shards of this host by default)
    :raises MongoToolException: if the restore of any shard fails
    """
    target_shards = get_target_shards(setup, set(shards or []))
    logging.info("Target shards: {}".format(", ".join(sorted(target_shards))))
    if not directory or not os.path.exists(directory):
        directory = download_latest_backups(date, target_shards, directory)
    uncompress_command = get_uncompress_command()

    def restore_single_shard(shard_name):
        """Pipe one decompressed archive into `mongorestore`; return True if both processes succeed."""
        port = target_shards[shard_name]
        archive_path = "{}.pixz".format(os.path.join(directory, shard_name))
        log_filename = os.path.join(directory, shard_name + "_restore.log")
        logging.info("Start restore shard {} on port {} (logfile: {})".format(shard_name, port, log_filename))
        mongorestore_command = ["mongorestore", "--archive", "--port", str(port)]
        if namespaces:
            for namespace in namespaces:
                mongorestore_command += ["--nsInclude", namespace]
        else:
            mongorestore_command += ["--oplogReplay"]

        uncompress_process = sp.Popen(uncompress_command + [archive_path], stdout=sp.PIPE)
        try:
            with open(log_filename, "w") as logfile:
                restore_process = sp.Popen(
                    mongorestore_command, stdin=uncompress_process.stdout, stdout=logfile, stderr=sp.STDOUT
                )
        except Exception:
            # mongorestore did not start, so nobody reads the pipe: stop the decompressor.
            uncompress_process.kill()
            uncompress_process.wait()
            raise
        finally:
            # Only the children need the pipe. Closing our end lets the decompressor get
            # SIGPIPE instead of blocking forever if mongorestore exits early.
            uncompress_process.stdout.close()
        restore_process.wait()
        uncompress_process.wait()
        if uncompress_process.returncode != 0:
            logging.error("Decompression of {} has failed with exit code {}".format(
                archive_path, uncompress_process.returncode
            ))
        if restore_process.returncode != 0:
            logging.error("Restore of {} has failed! See logs in {}".format(shard_name, log_filename))
        return restore_process.returncode == 0 and uncompress_process.returncode == 0

    pool = mp.ThreadPool(4)
    try:
        # Collect every result rather than stopping at the first failure, so that no restore
        # is still running in the background when the error is reported.
        restore_results = list(pool.imap_unordered(restore_single_shard, sorted(target_shards)))
    finally:
        pool.close()
        pool.join()
    if not all(restore_results):
        raise MongoToolException("Failed to restore mongo shards. See logs for more info")
    logging.info("All target shards restored")


def _step_down(args):
    """
    Command handler for `step_down`. For each selected shard of this host (all by default),
    it runs `rs.stepDown(duration)` and then restarts the shard's `mongod` service.

    :raises MongoToolException: if not run as root (required by `service ... restart`), or
        if a selected shard is not hosted on this host
    """
    if os.geteuid() != 0:
        raise MongoToolException("You need root access to execute step_down. Please, try again with 'sudo'")
    target_shards = get_target_shards(args.setup, set(args.shards or []))
    logging.info("Target shards: {}".format(", ".join(sorted(target_shards))))
    for shard_name in sorted(target_shards):
        port = target_shards[shard_name]
        logging.info("Step down shard {} on port {}".format(shard_name, port))
        try:
            run_mongo_command(
                "rs.stepDown({})".format(args.duration), port=port,
                parse_json=False, check_errors=False
            )
        except MongoToolException:
            # Deliberately best effort: the instance is restarted regardless.
            # NOTE(review): presumably the shell commonly fails here because stepDown drops
            # client connections — confirm.
            pass
        service_name = "mongod_{}".format(port if shard_name != "cs_repl_set" else "cfg")
        if sp.call(["service", service_name, "restart"]) != 0:
            logging.error("Failed to restart service %s", service_name)


def _create_local_cluster(args):
    """Command handler for `create_local_cluster`; returns the started processes."""
    data_dir, log_dir = args.datadir, args.logdir
    return create_local_cluster(data_dir, log_dir)


def create_local_cluster(data_dir, log_dir):
    """
    Start a local sharded cluster made of the config server replica set, a `mongos` router
    and the data shard replica sets.

    :param data_dir: directory for instances' data (a subdirectory per replica set)
    :param log_dir: directory for instances' log files
    :return: list of started processes: config server, mongos, then data shards
    :raises: any startup error; all processes started so far are killed first
    """
    logging.info("Creating local mongodb cluster")
    instances = []
    try:
        instances.append(_create_cfg_instance(data_dir, log_dir))
        instances.append(_create_mongos_instance(log_dir))
        instances.extend(_create_data_instances(data_dir, log_dir))
    except BaseException:
        # Not only mongo errors: a missing binary, an unwritable directory or Ctrl-C during
        # the startup retries must not leave a half-started cluster behind either.
        for instance in instances:
            instance.kill()
        raise
    return instances


def _create_cfg_instance(data_dir, log_dir):
    """
    Start the config server replica set "cs_repl_set" on port 27019 and initiate it.

    :return: the started `mongod` process
    :raises MongoToolException: if the replica set can't be initiated (the process is killed)
    """
    repl_name = "cs_repl_set"
    port = "27019"
    logging.info("Creating config shard %s on port %s", repl_name, port)
    data_path = os.path.join(data_dir, repl_name)
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    mongod_command = [
        "mongod", "--port", port, "--configsvr", "--replSet", repl_name,
        "--ipv6", "--bind_ip", "127.0.0.1",
        "--dbpath", data_path,
        "--logpath", os.path.join(log_dir, "{}.log".format(repl_name)),
    ]
    process = sp.Popen(mongod_command)
    try:
        # Retried while mongod is still starting up.
        run_mongo_command("rs.initiate()", port=port, attempts=10)
    except MongoToolException:
        logging.error("Failed to init config shard %s on port %s", repl_name, port)
        process.kill()
        raise
    return process


def _create_mongos_instance(log_dir):
    """
    Start a `mongos` router on port 22222 which uses the local config server replica set.

    :return: the started `mongos` process
    :raises MongoToolException: if mongos does not answer (the process is killed)
    """
    logging.info("Creating mongos process on port 22222")
    log_path = os.path.join(log_dir, "mongos.log")
    mongos_instance = sp.Popen([
        "mongos", "--port", "22222", "--configdb", "cs_repl_set/127.0.0.1:27019",
        "--ipv6", "--bind_ip", "127.0.0.1", "--logpath", log_path
    ])
    try:
        # Wait until mongos accepts connections.
        run_mongo_command("db.version()", parse_json=False, attempts=10)
    except MongoToolException:
        # Like the other `_create_*` helpers: the caller never gets this process on failure,
        # so it must be stopped here.
        logging.error("Failed to start mongos on port %s", 22222)
        mongos_instance.kill()
        raise
    return mongos_instance


def _create_data_instances(data_dir, log_dir):
    """
    Start `mongo_setup.SHARDS` data replica sets "sandbox<N>" on ports "370NN" (N from 1).
    Each one is initiated and added to the cluster through mongos.

    :return: list of the started `mongod` processes
    :raises MongoToolException: if a shard can't be initiated or added; every process started
        by this function is killed first
    """
    started = []
    for shard_no in range(1, mongo_setup.SHARDS + 1):
        repl_name = "sandbox{}".format(shard_no)
        port = "370{:02}".format(shard_no)
        logging.info("Creating shard {} on port {}".format(repl_name, port))
        data_path = os.path.join(data_dir, repl_name)
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        log_path = os.path.join(log_dir, "{}.log".format(repl_name))
        started.append(sp.Popen([
            "mongod", "--port", port, "--shardsvr", "--replSet", repl_name,
            "--ipv6", "--bind_ip", "127.0.0.1", "--wiredTigerCacheSizeGB", "2",
            "--dbpath", data_path, "--logpath", log_path,
        ]))
        try:
            run_mongo_command("rs.initiate()", port=port, attempts=10)
            run_mongo_command("sh.addShard('{}/127.0.0.1:{}')".format(repl_name, port), attempts=10)
        except MongoToolException:
            logging.error("Failed to init data shard %s", repl_name)
            for process in started:
                process.kill()
            raise
    return started


def handle_args(argv=None):
    """
    Parse and validate command line arguments.

    The returned namespace also has `setup` set to `mongo_setup.SETUP_MAP[installation]`. For
    `enable_sharding`, `strategy` is replaced by its shard key value from SHARDING_STRATEGY.
    Exits with status 1 if the command can't be run for the chosen installation on this host.

    :param argv: arguments to parse (`sys.argv[1:]` when None)
    :return: `argparse.Namespace` with the parsed arguments
    """
    def valid_date(arg):
        """argparse type: accept a YYYYMMDD date and return it unchanged."""
        try:
            dt.datetime.strptime(arg, "%Y%m%d")
        except ValueError:
            msg = "Not a valid date format: '{}'".format(arg)
            raise argparse.ArgumentTypeError(msg)
        return arg

    def existing_directory(arg):
        """argparse type: accept a path to an existing directory."""
        if not os.path.exists(arg):
            msg = "Directory '{}' not found".format(arg)
            raise argparse.ArgumentTypeError(msg)
        elif not os.path.isdir(arg):
            msg = "'{}' should be directory".format(arg)
            raise argparse.ArgumentTypeError(msg)
        return arg

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=sys.modules[__name__].__doc__.strip()
    )
    parser.add_argument(
        "installation",
        metavar="INSTALLATION", type=str, help="installation type (preprod/production/testing)",
        choices=sorted(mongo_setup.SETUP_MAP)
    )
    commands = parser.add_subparsers(dest="command", title="Commands")
    # Python 3 argparse makes sub-commands optional. Without one, `args.command` would be None
    # and dispatching on it would fail with KeyError instead of printing usage.
    commands.required = True
    # Reconfigure
    reconfigure_parser = commands.add_parser(
        RECONFIGURE_COMMAND, help="generate new replica sets configuration"
    )
    result_options = reconfigure_parser.add_mutually_exclusive_group()
    result_options.add_argument(
        "-e", "--execute", default=False, action="store_true", help="execute resulting configuration immediately"
    )
    result_options.add_argument(
        "-s", "--script", default=False, action="store_true", help="generate script which apply configuration"
    )
    reconfigure_parser.add_argument(
        "--exclude", metavar="EXCLUDE_DC", type=str, help="exclude dc from configuration",
        choices=mongo_setup.DCS
    )
    reconfigure_parser.add_argument(
        "-f", "--force", default=False, action="store_true",
        help="force configuration execution (execute on SECONDARY instance)"
    )
    # Check collections
    commands.add_parser(CHECK_COLLECTIONS_COMMAND, help="check sharding status on sandbox collections")
    # Get collection info
    info_parser = commands.add_parser(
        COLLECTION_INFO_COMMAND, help="get collection info: size, object count, indexes and sharding"
    )
    info_parser.add_argument("collection", metavar="COLLECTION", help="name of sandbox collection")
    # Enable sharding
    sharding_parser = commands.add_parser(ENABLE_SHARDING_COMMAND, help="enable sharding on sandbox collection")
    sharding_parser.add_argument("collection", metavar="COLLECTION", help="name of sandbox collection to shard")
    sharding_parser.add_argument("key", metavar="SHARD_KEY", help="indexed field to be used as shard key (e.g. '_id')")
    sharding_parser.add_argument(
        "strategy", metavar="STRATEGY", help="sharding strategy ('ranged' or 'hashed')", choices=SHARDING_STRATEGY
    )
    # Restore shards
    restore_parser = commands.add_parser(
        RESTORE_SHARDS_COMMAND, help="restore mongo instances data from backup (for PRIMARY only)"
    )
    source_options = restore_parser.add_mutually_exclusive_group()
    source_options.add_argument(
        "--date", metavar="YYYYMMDD", type=valid_date, default=dt.datetime.today().strftime("%Y%m%d"),
        help="download backup by date (default: today)")
    source_options.add_argument(
        "--directory", metavar="DIRECTORY", type=existing_directory, help="use local backup from directory"
    )
    restore_parser.add_argument("--namespaces", nargs="+", metavar="NAMESPACE", help="restore only selected namespace")
    restore_parser.add_argument("--shards", nargs="+", metavar="SHARD", help="restore only selected shards")
    # Step down
    step_down_parser = commands.add_parser(
        STEP_DOWN_COMMAND, help="step down instances (for PRIMARY only, root access required)"
    )
    step_down_parser.add_argument("--shards", nargs="+", metavar="SHARD", help="step down only selected shards")
    step_down_parser.add_argument("--duration", type=int, metavar="SECONDS", default=600, help="step down duration")
    # Create local cluster
    local_cluster_parser = commands.add_parser(
        CREATE_LOCAL_CLUSTER_COMMAND, help="create sharded cluster on localhost (for testing only)"
    )
    local_cluster_parser.add_argument(
        "-d", "--datadir", metavar="DIRECTORY", default="mongo_data", help="directory to store mongo data"
    )
    local_cluster_parser.add_argument(
        "-l", "--logdir", metavar="DIRECTORY", default="mongo_logs", help="directory to store mongo logs"
    )

    args = parser.parse_args(argv)
    args.setup = mongo_setup.SETUP_MAP[args.installation]
    # `sys.exit` is used instead of the `exit` helper, which the `site` module only provides
    # for interactive use.
    if args.command in (RECONFIGURE_COMMAND, RESTORE_SHARDS_COMMAND, STEP_DOWN_COMMAND):
        # Check host is in shard map
        host_id = args.setup.get_host_id()
        host_data_map = args.setup.find_data_shard_map_by_host(host_id)
        host_config_map = args.setup.find_config_shard_map_by_host(host_id)
        if not host_data_map and not host_config_map:
            print("There are no data or config shards on host '{}'".format(host_id))
            print("Please, run this script on another host or check shard map in mongo_setup.py")
            sys.exit(1)
    elif args.command == ENABLE_SHARDING_COMMAND:
        args.strategy = SHARDING_STRATEGY[args.strategy]
    elif args.command == CREATE_LOCAL_CLUSTER_COMMAND and args.installation != "testing":
        print("Can't create local cluster with non-testing setup '{}'".format(args.installation))
        sys.exit(1)
    return args


def init_logging():
    """Print log records of level INFO and above to stderr as bare messages."""
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(logging.Formatter('%(message)s'))
    stderr_handler.setLevel(logging.DEBUG)
    root = logging.getLogger()
    root.addHandler(stderr_handler)
    root.setLevel(logging.INFO)


if __name__ == "__main__":
    init_logging()
    cli_args = handle_args()
    command_handlers = {
        RECONFIGURE_COMMAND: _reconfigure,
        CHECK_COLLECTIONS_COMMAND: _check_sharded_sandbox_collections,
        COLLECTION_INFO_COMMAND: _collection_info,
        ENABLE_SHARDING_COMMAND: _enable_sharding,
        RESTORE_SHARDS_COMMAND: _restore_shards,
        STEP_DOWN_COMMAND: _step_down,
        CREATE_LOCAL_CLUSTER_COMMAND: _create_local_cluster,
    }
    try:
        handler_result = command_handlers[cli_args.command](cli_args)
    except MongoToolException as error:
        logging.error(str(error))
        sys.exit(1)
    # Handlers signal failure by raising MongoToolException or by returning a non-zero int
    # (e.g. `_enable_sharding`). Other return values (lists, None) mean success.
    sys.exit(handler_result if isinstance(handler_result, int) else 0)
