from itertools import chain

from paysys.sre.tools.monitorings.lib.util.aggregators import empty_kwargs
from paysys.sre.tools.monitorings.lib.util.helpers import (
    merge,
    unreach_skip,
    gen_children,
    ttl,
    flaps,
    check,
    gen_children_from_tuples,
    unreach_checks,
)
from paysys.sre.tools.monitorings.lib.checks.graphite import graphite_check
from paysys.sre.tools.monitorings.lib.checks.doc import doc_link
from paysys.sre.tools.monitorings.lib.util import conductor

from collections import namedtuple


# One group of shards together with the children its checks are built for.
#   children      - passed on unchanged to the check generators
#   shards        - number of shards in the group
#   start_shard   - number of the first shard in the group
#   config_server - truthy when "-config" services should be checked as well
ShardGroup = namedtuple(
    "ShardGroup",
    ("children", "shards", "start_shard", "config_server"),
)


def get_simple_shard_group(children, shards):
    """Build one ``ShardGroup`` per child, all sharing the same shard count.

    :param children: a single child or a list of children; any value that is
        not a ``list`` is treated as exactly one child.
    :param shards: number of shards, stored unchanged in every group.
    :return: list of ``ShardGroup`` with ``start_shard=1`` and
        ``config_server=True``.
    """
    if not isinstance(children, list):
        children = [children]
    # config_server is a real boolean (not the string "True") so that it has
    # the same type as the ``config_server=True`` default used by ``__mongo``.
    return [
        ShardGroup(child, shards, start_shard=1, config_server=True)
        for child in children
    ]


def mongo_with_single_shardgroup(
    children, cluster, shards, host=None, split_by_dc=None
):
    """Generate mongo checks for children that all use one shard layout.

    Wraps every child into a simple shard group (see
    ``get_simple_shard_group``) and delegates to ``mongo_with_shardgroups``.
    ``balancer`` is not forwarded, so the callee's default applies.

    :param children: a single child or a list of children.
    :param cluster: cluster name forwarded to ``mongo_with_shardgroups``.
    :param shards: number of shards in every generated group.
    :param host: forwarded as ``host``.
    :param split_by_dc: forwarded as ``split_by_dc``.
    :return: whatever ``mongo_with_shardgroups`` returns.
    """
    shard_groups = get_simple_shard_group(children, shards)
    return mongo_with_shardgroups(
        shard_groups, cluster, host=host, split_by_dc=split_by_dc
    )


def __prepare_mongo_check(
    dc, children, shard_group, cluster, balancer, group_type="CGROUP"
):
    """Call ``__mongo`` with the shard settings taken from ``shard_group``.

    :param dc: datacenter value forwarded as ``dc``.
    :param children: children to build the checks for (passed separately, so
        it may differ from ``shard_group.children``).
    :param shard_group: source of ``shards``, ``start_shard`` and
        ``config_server``.
    :param cluster: cluster name.
    :param balancer: forwarded as ``balancer``.
    :param group_type: forwarded as ``group_type``.
    :return: the dict produced by ``__mongo``.
    """
    options = dict(
        balancer=balancer,
        config_server=shard_group.config_server,
        start_shard=shard_group.start_shard,
        dc=dc,
        group_type=group_type,
    )
    return __mongo(children, cluster, shard_group.shards, **options)


def __merge_result(result, data):
    """Merge the checks in ``data`` into ``result`` in place.

    For a key already present in ``result`` the ``"children"`` value of the
    new entry is appended (``+=``) to the existing one; every other field of
    the new entry is ignored. For a new key the entry from ``data`` is stored
    as-is (not copied).

    :param result: dict of check name -> check dict; mutated.
    :param data: dict of check name -> check dict to merge in.
    :return: the same ``result`` object, for convenience.
    """
    for k, v in data.items():
        # Test membership on the dict itself: on Python 2 ``result.keys()``
        # builds a list and scans it linearly for every key.
        if k in result:
            result[k]["children"] += v["children"]
        else:
            result[k] = v
    return result


def mongo_with_shardgroups(
    shard_groups, cluster, balancer=True, host=None, split_by_dc=None
):
    """Generate mongo checks for several shard groups and merge them.

    Checks with the same name coming from different groups (or datacenters)
    are combined by ``__merge_result``, i.e. their ``"children"`` are
    concatenated.

    :param shard_groups: iterable of ``ShardGroup``.
    :param cluster: cluster name used in the service names.
    :param balancer: forwarded to the check generator.
    :param host: prefix of the ``"<host><dc>:UNREACHABLE"`` service; only
        used when ``split_by_dc`` is truthy.
    :param split_by_dc: when truthy, the children of each group are split by
        datacenter via ``conductor.split_hosts_by_dc`` and checks are built
        per datacenter with ``group_type="HOST"``.
    :return: dict of check name -> check dict.
    """
    merged = {}
    for group in shard_groups:
        if not split_by_dc:
            # Single set of checks for the whole group, empty dc suffix.
            __merge_result(
                merged,
                __prepare_mongo_check("", group.children, group, cluster, balancer),
            )
            continue

        hosts_by_dc = conductor.split_hosts_by_dc(group.children)
        for dc, dc_hosts in hosts_by_dc.items():
            per_dc = __prepare_mongo_check(
                dc, list(dc_hosts), group, cluster, balancer, group_type="HOST"
            )
            for check_args in per_dc.values():
                # NOTE(review): ``host`` defaults to None, which makes the
                # service below start with the literal "None" - confirm that
                # callers always pass ``host`` together with ``split_by_dc``.
                check_args.update(
                    unreach_checks(["{0}{1}:UNREACHABLE".format(host, dc)])
                )
            __merge_result(merged, per_dc)

    return merged


def __gen_children(children, cluster, config, group_type):
    """Build the children section for one kind of mongo check.

    :param children: children passed through to ``gen_children``.
    :param cluster: cluster name used in the service names.
    :param config: pair ``(check_suffix, shard_ids)``; one service named
        ``"mongo_<cluster>-shard<shard_id>_<check_suffix>"`` is produced for
        every item of ``shard_ids``.
    :param group_type: passed through to ``gen_children``.
    :return: whatever ``gen_children`` returns for these services.
    """
    service_names = []
    for shard_id in config[1]:
        service_names.append(
            "mongo_{0}-shard{1}_{2}".format(cluster, shard_id, config[0])
        )
    return gen_children(children, service_names, group_type=group_type)


def __mongo(
    children,
    cluster,
    shards,
    balancer=True,
    config_server=True,
    start_shard=1,
    dc=None,
    group_type=None,
):
    """Build the standard set of mongo checks for one group of children.

    Produces ``mongo_connect``, ``mongo_rs_state``, ``mongo_rs_lag`` and
    ``mongo_statistics`` checks, plus ``mongo_balancer_state`` when
    ``balancer`` is truthy. Each check covers shards
    ``start_shard .. start_shard + shards - 1``; connect/statistics also
    cover ``-mongos``, and every check except the balancer one covers
    ``-config`` when ``config_server`` is truthy.

    :param children: children passed to ``__gen_children``.
    :param cluster: cluster name used in the service names.
    :param shards: number of shards.
    :param balancer: add the balancer state check when truthy.
    :param config_server: include the ``-config`` services when truthy.
    :param start_shard: number of the first shard.
    :param dc: optional datacenter name; appended to every check name with
        ``.`` replaced by ``_``. ``None`` or ``""`` adds no suffix.
    :param group_type: passed to ``__gen_children``.
    :return: dict of check name -> check dict.
    """
    # range() yields the same list as xrange() here and does not tie the
    # function to a Python 2-only builtin.
    _shards = list(range(start_shard, start_shard + shards))

    cfg = ["-config"] if config_server else []
    cfg_mongos = cfg + ["-mongos"]

    data = {
        "mongo_connect": ("connect", cfg_mongos + _shards),
        "mongo_rs_state": ("replset_state", cfg + _shards),
        "mongo_rs_lag": ("replication_lag", cfg + _shards),
        "mongo_statistics": ("statistics", cfg_mongos + _shards),
    }
    if balancer:
        data["mongo_balancer_state"] = ("balancer_state", ["-mongos"])

    checks = {
        k: __gen_children(children, cluster, v, group_type) for k, v in data.items()
    }

    # Add flap to mongo_statistics, because it is really not an important check
    checks["mongo_statistics"] = merge(
        checks["mongo_statistics"], flaps(120, 600), ttl(7200, 60),
    )

    # ``dc`` defaults to None; calling .replace() on it unconditionally would
    # raise AttributeError, so fall back to an empty suffix.
    suffix = dc.replace(".", "_") if dc else ""

    return {
        "{0}{1}".format(k, suffix): merge(
            v, unreach_skip, doc_link("commonmongodb")
        )
        for k, v in checks.items()
    }


def backup(children, group_type='CGROUP'):
    """Build the ``mongo_backup`` check.

    :param children: children passed to ``gen_children`` for the
        ``backup_mongo`` service.
    :param group_type: passed to ``gen_children``.
    :return: dict with the single key ``"mongo_backup"``.
    """
    backup_check = merge(
        ttl(7200, 600),
        gen_children(children, "backup_mongo", group_type),
        unreach_skip,
        doc_link("commonmongodb"),
    )
    return {"mongo_backup": backup_check}


def _single_mongod_oplog_window(children, mongod_instance_port, warn, crit):
    """Build the ``oplog_window_<port>`` check for one mongod port.

    For every host found in the conductor groups listed in ``children`` a
    graphite subcheck on ``one_min.<host>.mongo.<port>.oplog_window`` is
    created (host dots replaced by underscores, ``less=True``); the
    aggregate check lists the same hosts as its children.

    :param children: conductor group names.
    :param mongod_instance_port: port used in the check and metric names.
    :param warn: warning threshold passed to ``graphite_check``.
    :param crit: critical threshold passed to ``graphite_check``.
    :return: whatever ``check`` returns for the aggregate check.
    """
    check_name = "oplog_window_{port}".format(port=mongod_instance_port)
    metric_template = "one_min.{host}.mongo.{port}.oplog_window"

    fqdns = sorted(
        member.fqdn
        for group in children
        for member in conductor.get_hosts_in_group(group)
    )

    per_host = {}
    for fqdn in fqdns:
        metric = metric_template.format(
            host=fqdn.replace(".", "_"), port=mongod_instance_port
        )
        per_host[fqdn] = merge(
            # graphite_check takes crit before warn.
            graphite_check(check_name, metric, crit, warn, "-20min", less=True),
            check(check_name, flaps(120, 300), {"children": []}),
        )

    return check(
        check_name,
        flaps(120, 300),
        {"tags": ["subchecks"]},
        {"subchecks": per_host},
        gen_children_from_tuples([fqdn, check_name, "HOST"] for fqdn in fqdns),
        empty_kwargs,
    )


def mongod_oplog_window(children, mongod_instance_ports, warn, crit):
    """Build oplog window checks for several mongod ports.

    :param children: conductor group names.
    :param mongod_instance_ports: iterable of ports; one check per port.
    :param warn: warning threshold, shared by all ports.
    :param crit: critical threshold, shared by all ports.
    :return: dict with the checks of all ports merged together.
    """
    # monitor oplog window, measures in hours
    merged = {}
    for instance_port in mongod_instance_ports:
        port_checks = _single_mongod_oplog_window(
            children, instance_port, warn, crit
        )
        merged.update(port_checks)
    return merged


def _single_mongod_scanned_objects(
    children, mongod_instance_port, warn=None, crit=None
):
    """Build the ``scanned_objects_<port>`` check for one mongod port.

    The check watches the per-second rate of the ``scannedObjects`` metric.

    :param children: conductor group names.
    :param mongod_instance_port: port used in the check and metric names.
    :param warn: warning threshold; ``None`` means 40000.
    :param crit: critical threshold; ``None`` means 60000.
    :return: result of ``_single_mongod_scanned_objects_base``.
    """
    warn = 40000 if warn is None else warn
    crit = 60000 if crit is None else crit
    return _single_mongod_scanned_objects_base(
        children,
        mongod_instance_port,
        warn,
        crit,
        "scanned_objects_{port}",
        "perSecond(one_min.{host}.mongo.{port}.metrics.queryExecutor.scannedObjects)",
    )


def _single_mongod_scanned_objects_ratio(
    children, mongod_instance_port, warn=None, crit=None
):
    """Build the ``scanned_objects_ratio_<port>`` check for one mongod port.

    The check watches the per-second rate of scanned objects divided by the
    per-second rate of returned documents.

    :param children: conductor group names.
    :param mongod_instance_port: port used in the check and metric names.
    :param warn: warning threshold; ``None`` means 3.0.
    :param crit: critical threshold; ``None`` means 4.0.
    :return: result of ``_single_mongod_scanned_objects_base``.
    """
    warn = 3.0 if warn is None else warn
    crit = 4.0 if crit is None else crit
    ratio_template = (
        "divideSeries("
        "perSecond(one_min.{host}.mongo.{port}.metrics.queryExecutor.scannedObjects),"
        "perSecond(one_min.{host}.mongo.{port}.metrics.document.returned)"
        ")"
    )
    return _single_mongod_scanned_objects_base(
        children,
        mongod_instance_port,
        warn,
        crit,
        "scanned_objects_ratio_{port}",
        ratio_template,
    )


def _single_mongod_scanned_objects_base(
    children, mongod_instance_port, warn, crit, name, template
):
    """Build a per-host graphite check with an aggregate on top.

    For every host found in the conductor groups listed in ``children`` a
    graphite subcheck is created from ``template``; the aggregate check
    lists the same hosts as its children.

    :param children: conductor group names.
    :param mongod_instance_port: value substituted for ``{port}``.
    :param warn: warning threshold passed to ``graphite_check``.
    :param crit: critical threshold passed to ``graphite_check``.
    :param name: check name template with a ``{port}`` placeholder.
    :param template: graphite expression template with ``{host}`` and
        ``{port}`` placeholders; ``{host}`` gets the fqdn with dots replaced
        by underscores.
    :return: whatever ``check`` returns for the aggregate check.
    """
    check_name = name.format(port=mongod_instance_port)

    fqdns = sorted(
        member.fqdn
        for group in children
        for member in conductor.get_hosts_in_group(group)
    )

    per_host = {}
    for fqdn in fqdns:
        expression = template.format(
            host=fqdn.replace(".", "_"), port=mongod_instance_port
        )
        per_host[fqdn] = merge(
            # graphite_check takes crit before warn.
            graphite_check(check_name, expression, crit, warn, "-20min"),
            check(check_name, flaps(120, 300), {"children": []}),
        )

    return check(
        check_name,
        flaps(120, 300),
        {"tags": ["subchecks"]},
        {"subchecks": per_host},
        gen_children_from_tuples([fqdn, check_name, "HOST"] for fqdn in fqdns),
        empty_kwargs,
    )


def mongod_scanned_objects(
    children, mongod_instance_ports, warn=None, crit=None, ratio=None
):
    """Build scanned-objects checks for several mongod ports.

    :param children: conductor group names.
    :param mongod_instance_ports: iterable of ports; one check per port.
    :param warn: warning threshold; ``None`` selects the default of the
        chosen check.
    :param crit: critical threshold; ``None`` selects the default of the
        chosen check.
    :param ratio: when truthy, build the scanned/returned ratio check
        instead of the plain scanned-objects rate check.
    :return: dict with the checks of all ports merged together.
    """
    if ratio:
        build_for_port = _single_mongod_scanned_objects_ratio
    else:
        build_for_port = _single_mongod_scanned_objects

    merged = {}
    for instance_port in mongod_instance_ports:
        merged.update(build_for_port(children, instance_port, warn, crit))
    return merged
