from paysys.sre.tools.monitorings.configs.trust.base import utils
from paysys.sre.tools.monitorings.lib.checks import base
from paysys.sre.tools.monitorings.lib.util import yc
from paysys.sre.tools.monitorings.lib.util.aggregators import (
    empty_kwargs,
    more_than_limit_is_problem,
)
from paysys.sre.tools.monitorings.lib.util.helpers import (
    check,
    gen_children_from_tuples,
    gen_unreach,
    merge,
    solomon_check,
)
from paysys.sre.tools.monitorings.lib.util.solomon import (
    selectors_to_string,
    solomon_expression_custom,
)

# Subcluster kinds. Functions in this module compare these values against a
# host record's "type" field to pick the hosts of one subcluster.
SUBCLUSTER_MONGOD = "MONGOD"
SUBCLUSTER_MONGOS = "MONGOS"
SUBCLUSTER_MONGOCFG = "MONGOCFG"


def get_mongodb_hosts(cluster_id):
    """Return the ``"hosts"`` entry of the host listing for ``cluster_id``.

    The listing is requested through the default ``yc`` client.
    """
    client = yc.get_default_client()
    listing = client.list_mongodb_cluster_hosts(cluster_id)
    return listing["hosts"]


def mdb_metric(mdb_cluster, sensor, **kwargs):
    """Render a selector string for ``sensor`` of the given MDB cluster.

    The selector always starts from the ``internal-mdb`` project, the
    ``mdb_<mdb_cluster>`` cluster and the ``mdb`` service. Any extra keyword
    arguments are added as further labels and override the defaults when the
    label names coincide.
    """
    defaults = {
        "project": "internal-mdb",
        "cluster": "mdb_{}".format(mdb_cluster),
        "service": "mdb",
        "sensor": sensor,
    }
    # dict(mapping, **extra) copies the defaults and then applies the extras.
    return selectors_to_string(dict(defaults, **kwargs))


def solomon_multialert_check(
    name,
    hosts,
    solomon_project,
    solomon_program,
    window_secs,
    skip_by_unreachable=True,
):
    """Build a Solomon alert grouped by the ``host`` label plus host subchecks.

    Args:
        name: service name used for the alert, the aggregate and every subcheck.
        hosts: iterable of host names; iterated twice, so it should be a list.
        solomon_project: project id handed to the alert expression.
        solomon_program: text of the alert program.
        window_secs: evaluation window passed to the alert expression.
        skip_by_unreachable: when true, every host subcheck additionally
            receives ``gen_unreach(["<host>:UNREACHABLE"])``.

    Returns:
        The merge of the Solomon check and the aggregate check carrying the
        per-host subchecks.
    """
    host_children = gen_children_from_tuples([host, name, "HOST"] for host in hosts)
    annotations = {
        "description": "{{expression.description}}",
        "juggler_host": "{{labels.host}}",
        "juggler_service": name,
    }
    expression = solomon_expression_custom(
        project_id=solomon_project,
        program_str=solomon_program,
        annotations=annotations,
        group_by_labels=["host"],
        channels=[utils.juggler_event_channel],
        juggler_children=merge({"replace": True}, host_children),
        window_secs=window_secs,
    )

    def host_subcheck(host):
        # Every subcheck starts with an empty children list.
        parts = [{"children": []}]
        if skip_by_unreachable:
            parts.append(gen_unreach(["{}:UNREACHABLE".format(host)]))
        return check(name, *parts)

    per_host = {host: host_subcheck(host) for host in hosts}

    alert_part = solomon_check(name, expression)
    aggregate_part = check(
        name, {"subchecks": per_host, "tags": ["subchecks"]}, empty_kwargs
    )
    return merge(alert_part, aggregate_part)


def unreachable(subcluster_type, hosts, warn=1, crit=1):
    """Build the ``<subcluster>-UNREACHABLE`` aggregate for one subcluster kind.

    Args:
        subcluster_type: subcluster kind; matched upper-cased against each
            host record's ``"type"`` and lower-cased for the check name.
        hosts: host records with ``"name"`` and ``"type"`` keys.
        warn: first argument of ``more_than_limit_is_problem``.
        crit: second argument of ``more_than_limit_is_problem``.

    Returns:
        A check that has one ``UNREACHABLE`` subcheck per matching host and
        the same hosts as children.
    """
    wanted_type = subcluster_type.upper()
    check_name = "{}-UNREACHABLE".format(subcluster_type.lower())
    host_names = [
        record["name"] for record in hosts if record["type"] == wanted_type
    ]

    def host_subcheck(host_name):
        own_child = gen_children_from_tuples([[host_name, "UNREACHABLE", "HOST"]])
        return merge(base.unreachable, check("UNREACHABLE", own_child))

    per_host = {host_name: host_subcheck(host_name) for host_name in host_names}
    all_children = gen_children_from_tuples(
        [host_name, "UNREACHABLE", "HOST"] for host_name in host_names
    )
    return check(
        check_name,
        {"subchecks": per_host, "tags": ["subchecks"]},
        all_children,
        more_than_limit_is_problem(warn, crit),
    )


def cpu_usage(
    mdb_hosts, mdb_cluster, solomon_project, warn=60, crit=80, window_secs=60
):
    """Build the per-host ``cpu_usage`` check for every host in ``mdb_hosts``.

    The alert program expresses the average of ``/porto/cpu_usage`` as a
    percentage of the average of ``/porto/cpu_limit``. It alarms above
    ``crit`` percent and warns above ``warn`` percent; ``no_data_if`` fires
    when ``count`` of either series is zero.
    """
    host_names = sorted(record["name"] for record in mdb_hosts)
    usage_selector = mdb_metric(mdb_cluster, "/porto/cpu_usage", dc="by_host")
    limit_selector = mdb_metric(mdb_cluster, "/porto/cpu_limit", dc="by_host")

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let usage = series_max({usage});",
        "let limit = series_max({limit});",
        "no_data_if(count(usage) == 0);",
        "no_data_if(count(limit) == 0);",
        "let avg_usage = avg(usage);",
        "let avg_limit = avg(limit);",
        "let usage_percent = avg_usage / avg_limit * 100;",
        'let description = "CPU usage - " + to_fixed(usage_percent,2) + "%";',
        "alarm_if(usage_percent > {threshold_crit});",
        "warn_if(usage_percent > {threshold_warn});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        usage=usage_selector,
        limit=limit_selector,
        threshold_crit=crit,
        threshold_warn=warn,
    )
    return solomon_multialert_check(
        "cpu_usage", host_names, solomon_project, program, window_secs
    )


def unispace(mdb_hosts, mdb_cluster, solomon_project, threshold=80, window_secs=60):
    """Build the per-host ``unispace`` check for ``/var/lib/mongodb``.

    The alert program takes the last used and free byte values, computes
    ``used / (used + free)`` as a percentage and alarms when it exceeds
    ``threshold``; ``no_data_if`` fires when ``count`` of either series is
    zero.
    """
    host_names = sorted(record["name"] for record in mdb_hosts)
    used_selector = mdb_metric(
        mdb_cluster, "disk-used_bytes_/var/lib/mongodb", dc="by_host"
    )
    free_selector = mdb_metric(
        mdb_cluster, "disk-free_bytes_/var/lib/mongodb", dc="by_host"
    )

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let usage = series_max({usage});",
        "let free = series_max({free});",
        "no_data_if(count(usage) == 0);",
        "no_data_if(count(free) == 0);",
        "let last_usage = last(usage);",
        "let last_free = last(free);",
        "let usage_percent = last_usage / (last_usage + last_free) * 100;",
        'let description = "Disk usage - " + to_fixed(usage_percent,2) + "%";',
        "alarm_if(usage_percent > {threshold});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        usage=used_selector,
        free=free_selector,
        threshold=threshold,
    )
    return solomon_multialert_check(
        "unispace", host_names, solomon_project, program, window_secs
    )


def mem_free(mdb_hosts, mdb_cluster, solomon_project, threshold=80, window_secs=60):
    """Build the per-host ``mem-free`` check for every host in ``mdb_hosts``.

    Despite the check name, the alert program measures usage: the average of
    ``/porto/anon_usage`` as a percentage of the average of
    ``/porto/anon_limit``. It alarms when that percentage exceeds
    ``threshold``; ``no_data_if`` fires when ``count`` of either series is
    zero.
    """
    host_names = sorted(record["name"] for record in mdb_hosts)
    usage_selector = mdb_metric(mdb_cluster, "/porto/anon_usage", dc="by_host")
    limit_selector = mdb_metric(mdb_cluster, "/porto/anon_limit", dc="by_host")

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let usage = series_max({usage});",
        "let limit = series_max({limit});",
        "no_data_if(count(usage) == 0);",
        "no_data_if(count(limit) == 0);",
        "let avg_usage = avg(usage);",
        "let avg_limit = avg(limit);",
        "let usage_percent = avg_usage / avg_limit * 100;",
        'let description = "RAM usage - " + to_fixed(usage_percent,2) + "%";',
        "alarm_if(usage_percent > {threshold});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        usage=usage_selector,
        limit=limit_selector,
        threshold=threshold,
    )
    return solomon_multialert_check(
        "mem-free", host_names, solomon_project, program, window_secs
    )


def op_rate(operation, mdb_cluster, solomon_project, thresholds, window_secs=60):
    """Build the cluster-wide ``<operation>-ops_rate`` check.

    Args:
        operation: name substituted into the ``opLatencies`` sensor, the
            alert description and the check name.
        mdb_cluster: MDB cluster id used to build the metric selector.
        solomon_project: project id handed to the alert expression.
        thresholds: object exposing ``warn`` and ``crit`` attributes.
        window_secs: evaluation window passed to the alert expression.

    The alert program sums the rate series selected with ``host="*"`` in the
    mongod subcluster, averages the sum, alarms above ``thresholds.crit`` and
    warns above ``thresholds.warn``.
    """
    check_name = "{}-ops_rate".format(operation)
    sensor = "server_status_admin_opLatencies.{}.ops_rate".format(operation)
    rate_selector = mdb_metric(
        mdb_cluster,
        sensor,
        subcluster_name="mongod_subcluster",
        dc="by_host",
        host="*",
    )

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let rate = series_sum({rate});",
        "no_data_if(count(rate) == 0);",
        "let avg_rate = avg(rate);",
        'let description = "Rate ({operation}) - " + to_fixed(avg_rate,2) + " rps";',
        "alarm_if(avg_rate > {threshold_crit});",
        "warn_if(avg_rate > {threshold_warn});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    rendered_program = template.format(
        rate=rate_selector,
        operation=operation,
        threshold_crit=thresholds.crit,
        threshold_warn=thresholds.warn,
    )

    expression = solomon_expression_custom(
        program_str=rendered_program,
        project_id=solomon_project,
        annotations={"description": "{{expression.description}}"},
        window_secs=window_secs,
        repeat_delay_secs=0,
    )
    return solomon_check(check_name, expression, empty_kwargs)


def connections(
    subcluster_type,
    hosts,
    mdb_cluster,
    solomon_project,
    thresholds,
    window_secs=60,
):
    """Build the per-host ``<subcluster>-connections`` check.

    Args:
        subcluster_type: subcluster kind; matched upper-cased against each
            host record's ``"type"`` and lower-cased for the check name and
            the ``<subcluster>_subcluster`` selector label.
        hosts: host records with ``"name"`` and ``"type"`` keys.
        mdb_cluster: MDB cluster id used to build the metric selector.
        solomon_project: project id handed to the alert expression.
        thresholds: object exposing ``warn`` and ``crit`` attributes.
        window_secs: evaluation window passed to the alert expression.

    The alert program reports OK with "No connections" when ``count`` of the
    series is zero; otherwise it rounds the average connection count, alarms
    above ``thresholds.crit`` and warns above ``thresholds.warn``.
    """
    subcluster = subcluster_type.lower()
    wanted_type = subcluster_type.upper()
    host_names = [
        record["name"] for record in hosts if record["type"] == wanted_type
    ]

    connections_selector = mdb_metric(
        mdb_cluster,
        "server_status_admin_connections.current",
        subcluster_name="{}_subcluster".format(subcluster),
        dc="by_host",
    )

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let connections = series_max({connections});",
        'let description = "No connections";',
        "ok_if(count(connections) == 0);",
        "let avg_connections = round(avg(connections));",
        'let description = "Connection count ({subcluster}) - " + avg_connections;',
        "alarm_if(avg_connections > {threshold_crit});",
        "warn_if(avg_connections > {threshold_warn});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        connections=connections_selector,
        subcluster=subcluster,
        threshold_crit=thresholds.crit,
        threshold_warn=thresholds.warn,
    )
    return solomon_multialert_check(
        "{}-connections".format(subcluster),
        host_names,
        solomon_project,
        program,
        window_secs,
    )


def rs_lag(
    mdb_hosts,
    mdb_cluster,
    solomon_project,
    thresholds,
    is_sharded=False,
    window_secs=60,
):
    """Build the per-host ``mongo_rs_lag`` check for MONGOD hosts.

    Args:
        mdb_hosts: host records with ``"name"`` and ``"type"`` keys.
        mdb_cluster: MDB cluster id used to build the metric selectors.
        solomon_project: project id handed to the alert expression.
        thresholds: object exposing ``warn`` and ``crit`` attributes.
        is_sharded: selects the ``isWritablePrimary`` sensor instead of
            ``ismaster`` for detecting the primary.
        window_secs: evaluation window passed to the alert expression.

    The alert program reports OK when the last primary-flag value is
    positive; otherwise it averages ``replset_status-replicationLag``,
    alarms above ``thresholds.crit`` and warns above ``thresholds.warn``.
    """
    mongod_names = [
        record["name"] for record in mdb_hosts if record["type"] == SUBCLUSTER_MONGOD
    ]
    if is_sharded:
        primary_sensor = "server_status_admin_repl.isWritablePrimary"
    else:
        primary_sensor = "server_status_admin_repl.ismaster"

    primary_selector = mdb_metric(
        mdb_cluster,
        primary_sensor,
        subcluster_name="mongod_subcluster",
        dc="by_host",
    )
    lag_selector = mdb_metric(
        mdb_cluster,
        "replset_status-replicationLag",
        subcluster_name="mongod_subcluster",
        dc="by_host",
    )

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let is_master = last({is_master});",
        'let description = "OK - PRIMARY";',
        "ok_if(is_master > 0);",
        "let rs_lag = series_max({rs_lag});",
        "let avg_rg_lag = avg(rs_lag);",
        'let description = "RS lag - " + to_fixed(avg_rg_lag,2) + " seconds";',
        "alarm_if(avg_rg_lag > {threshold_crit});",
        "warn_if(avg_rg_lag > {threshold_warn});",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        is_master=primary_selector,
        rs_lag=lag_selector,
        threshold_crit=thresholds.crit,
        threshold_warn=thresholds.warn,
    )
    return solomon_multialert_check(
        "mongo_rs_lag", mongod_names, solomon_project, program, window_secs
    )


def rs_state(hosts, mdb_cluster, solomon_project, is_sharded=False, window_secs=60):
    """Build the per-host ``mongo_rs_state`` check for MONGOD hosts.

    Args:
        hosts: host records with ``"name"`` and ``"type"`` keys.
        mdb_cluster: MDB cluster id used to build the metric selectors.
        solomon_project: project id handed to the alert expression.
        is_sharded: selects the ``isWritablePrimary`` sensor instead of
            ``ismaster`` for the primary flag.
        window_secs: evaluation window passed to the alert expression.

    The alert program alarms with "No data" when either flag series is
    missing or empty, alarms with "Unknown state" when the last primary and
    secondary flags do not sum to 1, and otherwise reports OK as PRIMARY or
    SECONDARY.
    """
    mongod_names = [
        record["name"] for record in hosts if record["type"] == SUBCLUSTER_MONGOD
    ]
    if is_sharded:
        primary_sensor = "server_status_admin_repl.isWritablePrimary"
    else:
        primary_sensor = "server_status_admin_repl.ismaster"

    primary_selector = mdb_metric(
        mdb_cluster,
        primary_sensor,
        subcluster_name="mongod_subcluster",
        dc="by_host",
    )
    secondary_selector = mdb_metric(
        mdb_cluster,
        "server_status_admin_repl.secondary",
        subcluster_name="mongod_subcluster",
        dc="by_host",
    )

    # One program statement per entry; each gets its trailing newline below.
    statements = [
        "let is_primary = {is_primary};",
        "let is_secondary = {is_secondary};",
        'let description = "No data";',
        "alarm_if(count(is_primary) == 0 || size(is_primary) == 0);",
        "alarm_if(count(is_secondary) == 0 || size(is_secondary) == 0);",
        "let last_is_primary = last(is_primary);",
        "let last_is_secondary = last(is_secondary);",
        'let description = "Unknown state";',
        "alarm_if(last_is_primary + last_is_secondary != 1);",
        'let description = "OK - PRIMARY";',
        "ok_if(last_is_primary == 1);",
        'let description = "OK - SECONDARY";',
        "ok_if(last_is_secondary == 1);",
    ]
    template = "".join(statement + "\n" for statement in statements)
    program = template.format(
        is_primary=primary_selector,
        is_secondary=secondary_selector,
    )
    return solomon_multialert_check(
        "mongo_rs_state", mongod_names, solomon_project, program, window_secs
    )


def group_hosts_by_shards(mdb_hosts):
    """Group MONGOD host names by their ``"shardName"``.

    Records whose ``"type"`` is not ``SUBCLUSTER_MONGOD`` are ignored.

    Returns:
        A dict mapping each shard name to the list of its host names, in the
        order the hosts appear in ``mdb_hosts``.
    """
    grouped = {}
    mongod_records = (
        record for record in mdb_hosts if record["type"] == SUBCLUSTER_MONGOD
    )
    for record in mongod_records:
        # setdefault creates the shard's list on first sight of the shard.
        grouped.setdefault(record["shardName"], []).append(record["name"])
    return grouped


def master_switch(
    mdb_hosts, mdb_cluster, solomon_project, is_sharded=False, window_secs=60
):
    """Build the ``master-switch`` check covering every shard of the cluster.

    Args:
        mdb_hosts: host records; MONGOD hosts are grouped by ``"shardName"``.
        mdb_cluster: MDB cluster id used to build the metric selectors.
        solomon_project: project id handed to the alert expression.
        is_sharded: selects the ``isWritablePrimary`` sensor instead of
            ``ismaster`` for the primary flag.
        window_secs: evaluation window passed to the alert expression.

    For each shard (in sorted order) the generated program loads the primary
    flag of every host, remembers the host whose last flag value is positive,
    and alarms when the sum of per-host maxima exceeds 1. The program ends by
    setting the description to "OK".
    """
    shards_to_hosts = group_hosts_by_shards(mdb_hosts)

    if is_sharded:
        primary_sensor = "server_status_admin_repl.isWritablePrimary"
    else:
        primary_sensor = "server_status_admin_repl.ismaster"

    def host_variable(host):
        # Host example: vla-mcjtynvfjq7bwcnz.db.yandex.net
        # The first DNS label with dashes replaced names the program variable.
        first_label = host.partition(".")[0]
        return first_label.replace("-", "_")

    host_template = (
        "// Check whether {host_name} is primary or not\n"
        "let {host_var} = {is_primary};\n"
        'let to = last({host_var}) > 0 ? "{host_name}" : to;\n\n'
    )
    shard_template = (
        "// Check whether master was switched or not:\n"
        "// The sum of the maximum values in the interval must be equal to 1 if master was not changed\n"
        "let sum_max_is_primary = {sum_max_is_primary};\n"
        'let description = "Master was switched to " + to + " in shard {shard}.";\n'
        "alarm_if(sum_max_is_primary > 1);\n\n"
    )

    # Program fragments are collected here and concatenated once at the end.
    pieces = []
    for shard in sorted(shards_to_hosts):
        members = sorted(shards_to_hosts[shard])
        pieces.append("// Checks for shard {}\n".format(shard))
        pieces.append('let to = "";\n\n')
        for member in members:
            primary_selector = mdb_metric(
                mdb_cluster,
                primary_sensor,
                subcluster_name="mongod_subcluster",
                dc="by_host",
                host=member,
            )
            pieces.append(
                host_template.format(
                    host_var=host_variable(member),
                    host_name=member,
                    is_primary=primary_selector,
                )
            )
        max_terms = ["max({})".format(host_variable(member)) for member in members]
        pieces.append(
            shard_template.format(
                shard=shard,
                sum_max_is_primary=" + ".join(max_terms),
            )
        )
    pieces.append('let description = "OK";\n')

    expression = solomon_expression_custom(
        program_str="".join(pieces),
        project_id=solomon_project,
        annotations={"description": "{{expression.description}}"},
        window_secs=window_secs,
        repeat_delay_secs=0,
    )
    return solomon_check("master-switch", expression, empty_kwargs)


def get_checks(
    mdb_cluster,
    solomon_project,
    op_reads_thresholds=utils.thresholds(50, 100),
    op_writes_thresholds=utils.thresholds(20, 40),
    op_commands_thresholds=utils.thresholds(50, 100),
    mongod_conn_thresholds=utils.thresholds(100, 200),
    mongos_conn_thresholds=utils.thresholds(300, 500),
    rs_lag_thresholds=utils.thresholds(10, 30),
    is_sharded=False,
    mdb_hosts=None,
):
    """Assemble the full set of checks for one MongoDB cluster.

    Args:
        mdb_cluster: MDB cluster id.
        solomon_project: project id handed to every alert expression.
        op_reads_thresholds: thresholds for the ``reads`` operation rate.
        op_writes_thresholds: thresholds for the ``writes`` operation rate.
        op_commands_thresholds: thresholds for the ``commands`` operation rate.
        mongod_conn_thresholds: connection thresholds for MONGOD hosts.
        mongos_conn_thresholds: connection thresholds for MONGOS hosts.
        rs_lag_thresholds: thresholds for the replication-lag check.
        is_sharded: forwarded to the replica-set and master-switch checks.
        mdb_hosts: host records; fetched with ``get_mongodb_hosts`` when None.

    Returns:
        The merge of the unreachable, resource, operation-rate, replica-set,
        master-switch and connection checks, in that order.
    """
    hosts = get_mongodb_hosts(mdb_cluster) if mdb_hosts is None else mdb_hosts

    parts = [
        unreachable(kind, hosts)
        for kind in (SUBCLUSTER_MONGOD, SUBCLUSTER_MONGOS, SUBCLUSTER_MONGOCFG)
    ]
    parts.extend(
        build(hosts, mdb_cluster, solomon_project)
        for build in (cpu_usage, unispace, mem_free)
    )

    rate_limits = (
        ("reads", op_reads_thresholds),
        ("writes", op_writes_thresholds),
        ("commands", op_commands_thresholds),
    )
    parts.extend(
        op_rate(operation, mdb_cluster, solomon_project, limits)
        for operation, limits in rate_limits
    )

    parts.append(
        rs_lag(hosts, mdb_cluster, solomon_project, rs_lag_thresholds, is_sharded)
    )
    parts.append(rs_state(hosts, mdb_cluster, solomon_project, is_sharded))
    parts.append(master_switch(hosts, mdb_cluster, solomon_project, is_sharded))

    connection_limits = (
        (SUBCLUSTER_MONGOD, mongod_conn_thresholds),
        (SUBCLUSTER_MONGOS, mongos_conn_thresholds),
    )
    parts.extend(
        connections(kind, hosts, mdb_cluster, solomon_project, limits)
        for kind, limits in connection_limits
    )
    return merge(*parts)
