import juggler_sdk

from paysys.sre.tools.monitorings.lib.util.aggregators import empty_kwargs
from paysys.sre.tools.monitorings.lib.util.helpers import (
    check,
    nodata_skip,
    merge,
    solomon_check,
)
from paysys.sre.tools.monitorings.lib.util.solomon import (
    create_solomon_alert_id,
    solomon_expression_custom,
)


def _check_solomon_string(value, what):
    """Ensure ``value`` can be embedded verbatim inside a Solomon string literal.

    The alert program below places such values between double quotes without
    escaping them. A quote, backslash or line break would produce a broken (or
    silently different) program, so reject it at config-generation time.

    Returns ``value`` unchanged; raises ValueError if it is not embeddable.
    """
    text = "{}".format(value)
    if any(ch in text for ch in ('"', "\\", "\n", "\r")):
        raise ValueError(
            "{} {!r} cannot be embedded in a Solomon string literal".format(what, value)
        )
    return value


def solomon_multialert_by_topic(
    name,
    host,
    solomon_project,
    solomon_sensor,
    solomon_cluster,
    description_prefix,
    default_threshold,
    topic_thresholds=None,
    window_secs=60,
    metric_project="trust",
):
    """Build a per-topic Solomon multi-alert together with its Juggler check.

    The generated Solomon program reads ``solomon_sensor`` from the scheduler
    service and groups by the ``topic`` label. For each topic it averages the
    series over the window. It raises CRIT when the average is strictly above
    the crit threshold and WARN when it is strictly above the warn threshold.

    Args:
        name: check name; also used with ``host`` to build the Solomon alert id.
        host: host part of the Solomon alert id.
        solomon_project: project the alert itself is created in (``project_id``).
        solomon_sensor: value of the ``sensor`` label in the metric selector.
        solomon_cluster: value of the ``cluster`` label in the metric selector.
        description_prefix: text placed before the topic name in the alert
            description.
        default_threshold: object with ``warn`` and ``crit`` attributes, used
            for every topic without an override.
        topic_thresholds: optional mapping ``topic -> object with warn/crit``
            that overrides the default for specific topics.
        window_secs: evaluation window forwarded to ``solomon_expression_custom``.
        metric_project: value of the ``project`` label in the metric selector.
            It defaults to ``"trust"``, which was previously hard-coded.

    Returns:
        The result of ``merge`` over the Solomon check and a Juggler check
        configured with ``nodata_skip``.

    Raises:
        ValueError: if a topic name or ``description_prefix`` contains a
            double quote, backslash or line break.
    """
    _check_solomon_string(description_prefix, "Description prefix")

    metric = (
        "series_avg({{"
        'project="{project}",cluster="{cluster}",service="scheduler",'
        'host!="cluster",sensor="{sensor}"'
        "}})"
    ).format(
        project=metric_project,
        cluster=solomon_cluster,
        sensor=solomon_sensor,
    )

    # Defaults come first; each override re-binds the variable only when the
    # current group's topic matches, so non-matching topics keep the default.
    threshold_parts = [
        "let warn_threshold = {};\n".format(default_threshold.warn),
        "let crit_threshold = {};\n".format(default_threshold.crit),
    ]
    if topic_thresholds:
        threshold_parts.append(
            "//////////////////////////////\n" "// Begin threshold overrides\n"
        )
        # Sorted so the generated program is stable across runs / dict orders.
        for topic in sorted(topic_thresholds):
            _check_solomon_string(topic, "Topic")
            override = topic_thresholds[topic]
            threshold_parts.append(
                (
                    'let warn_threshold = topic == "{topic}" ? {warn} : warn_threshold;\n'
                    'let crit_threshold = topic == "{topic}" ? {crit} : crit_threshold;\n'
                ).format(topic=topic, warn=override.warn, crit=override.crit)
            )
        threshold_parts.append(
            "// End threshold overrides\n" "///////////////////////////\n"
        )
    thresholds = "".join(threshold_parts)

    # The first `description` presumably serves as the text for the no-data
    # case; it is re-bound below once data is known to exist.
    solomon_program = (
        "let count = {metric};\n"
        'let description = "No data";\n'
        "no_data_if(count(count) == 0 || size(count) == 0);\n"
        "let avg_count = avg(count);\n"
        'let topic = get_label(count, "topic");\n'
        "{thresholds}"
        'let description = "{description_prefix} " + topic + " - " + avg_count'
        ' + " (WARN - " + warn_threshold + ", CRIT - " + crit_threshold + ")";\n'
        "alarm_if(avg_count > crit_threshold);\n"
        "warn_if(avg_count > warn_threshold);\n"
    ).format(
        metric=metric,
        thresholds=thresholds,
        description_prefix=description_prefix,
    )

    juggler_child = juggler_sdk.Child(
        create_solomon_alert_id(host, name),
        "all",
        "all",
        "MONITORING_MULTIALERT",
    )
    solomon_expr = solomon_expression_custom(
        project_id=solomon_project,
        program_str=solomon_program,
        # Template placeholder, presumably resolved by Solomon to the program's
        # `description` variable.
        annotations={"description": "{{expression.description}}"},
        group_by_labels=["topic"],
        juggler_children={"replace": True, "children": [juggler_child]},
        window_secs=window_secs,
    )
    return merge(
        solomon_check(name, solomon_expr),
        check(name, empty_kwargs, nodata_skip),
    )


def queue_size(
    host,
    solomon_project,
    solomon_cluster,
    default_threshold,
    topic_thresholds=None,
    window_secs=60,
):
    """Build the per-topic scheduler queue-size alert.

    Reads the ``queue_size_topic`` sensor and registers the check under the
    name ``queue_size``.

    Args:
        host: host used for the Solomon alert id.
        solomon_project: Solomon project the alert belongs to.
        solomon_cluster: cluster label of the metric selector.
        default_threshold: object with ``warn``/``crit`` limits for all topics.
        topic_thresholds: optional ``topic -> warn/crit`` overrides.
        window_secs: evaluation window in seconds.
    """
    check_name = "queue_size"
    sensor = "queue_size_topic"
    return solomon_multialert_by_topic(
        name=check_name,
        host=host,
        solomon_project=solomon_project,
        solomon_sensor=sensor,
        solomon_cluster=solomon_cluster,
        description_prefix="Queue size for topic",
        default_threshold=default_threshold,
        topic_thresholds=topic_thresholds,
        window_secs=window_secs,
    )


def expired_count(
    host,
    solomon_project,
    solomon_cluster,
    default_threshold,
    topic_thresholds=None,
    window_secs=60,
):
    """Build the per-topic alert on the number of expired scheduler tasks.

    Reads the ``expired_tasks_count_by_topic`` sensor and registers the check
    under the name ``expired_count``.

    Args:
        host: host used for the Solomon alert id.
        solomon_project: Solomon project the alert belongs to.
        solomon_cluster: cluster label of the metric selector.
        default_threshold: object with ``warn``/``crit`` limits for all topics.
        topic_thresholds: optional ``topic -> warn/crit`` overrides.
        window_secs: evaluation window in seconds.
    """
    check_name = "expired_count"
    sensor = "expired_tasks_count_by_topic"
    return solomon_multialert_by_topic(
        name=check_name,
        host=host,
        solomon_project=solomon_project,
        solomon_sensor=sensor,
        solomon_cluster=solomon_cluster,
        description_prefix="Expired count for topic",
        default_threshold=default_threshold,
        topic_thresholds=topic_thresholds,
        window_secs=window_secs,
    )


def _split_default_threshold(thresholds, what):
    """Return ``(default, per_topic_overrides)`` without mutating ``thresholds``.

    ``thresholds`` is a mapping with a mandatory ``"default"`` entry; every
    other key is treated as a topic override.
    """
    if "default" not in thresholds:
        raise KeyError('{} must contain a "default" threshold'.format(what))
    # Work on a shallow copy so the caller's config can be reused (e.g. for
    # several hosts) without losing its "default" entry.
    overrides = dict(thresholds)
    default = overrides.pop("default")
    return default, overrides


def get_checks(
    host, solomon_project, solomon_cluster, queue_thresholds, expired_thresholds
):
    """Build the queue-size and expired-count per-topic checks for a scheduler.

    Args:
        host: host used for the Solomon alert ids.
        solomon_project: Solomon project the alerts belong to.
        solomon_cluster: cluster label of the metric selectors.
        queue_thresholds: mapping with a ``"default"`` warn/crit threshold and
            optional per-topic overrides for the queue-size check.
        expired_thresholds: same shape as ``queue_thresholds``, for the
            expired-count check.

    The input mappings are not modified.

    Raises:
        KeyError: if either mapping lacks a ``"default"`` entry.
    """
    queue_default, queue_overrides = _split_default_threshold(
        queue_thresholds, "queue_thresholds"
    )
    expired_default, expired_overrides = _split_default_threshold(
        expired_thresholds, "expired_thresholds"
    )
    return merge(
        queue_size(
            host,
            solomon_project,
            solomon_cluster,
            queue_default,
            queue_overrides,
        ),
        expired_count(
            host,
            solomon_project,
            solomon_cluster,
            expired_default,
            expired_overrides,
        ),
    )
