import datetime

from airflow.models import DAG
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago

import twitch_airflow_components
from twitch_airflow_components.dag_utils import pagerduty_on_failure_callback


class MockPagerdutyHook:
    """Recording stand-in for ``PagerdutyHook`` used by the tests below.

    ``mock_init`` is patched in place of the hook class, so "constructing" the
    hook just records the connection id and hands back this same instance.
    ``create_event`` stores its arguments on the instance for later assertions.
    """

    def mock_init(self, pagerduty_conn_id):
        """Remember ``pagerduty_conn_id`` and return this instance as the hook."""
        self.pagerduty_conn_id = pagerduty_conn_id
        return self

    def create_event(self, summary, severity, source=None):
        """Capture the event fields instead of sending anything to PagerDuty."""
        self.summary, self.severity, self.source = summary, severity, source


def test_pagerduty_default(mocker):
    """Check the PagerDuty event sent by ``pagerduty_on_failure_callback``.

    ``PagerdutyHook`` in ``twitch_airflow_components.dag_utils`` is patched so
    that constructing it returns a single ``MockPagerdutyHook`` instance. After
    the callback runs, the test checks four things on that instance:

    * the connection id matches ``params["pagerduty_conn_id"]``;
    * the summary names the DAG, run id and task;
    * the severity is ``"error"``;
    * the source is the DAG id.
    """
    mock_hook = MockPagerdutyHook()
    mocker.patch.object(
        twitch_airflow_components.dag_utils, "PagerdutyHook", mock_hook.mock_init
    )

    default_args = {
        "owner": "airflow",
        "depends_on_past": False,
        # Fixed date instead of the deprecated ``days_ago`` helper, so the
        # fixture does not depend on when the test runs.
        "start_date": datetime.datetime(2021, 1, 1),
        "email": ["airflow@example.com"],
        "email_on_failure": False,
        "email_on_retry": False,
        "retries": 1,
    }
    dag = DAG(
        "test-dag",
        default_args=default_args,
        schedule_interval="@once",
    )
    task = BashOperator(
        task_id="print_date",
        bash_command="date",
        dag=dag,
    )
    # Hand-built callback context. It is presumably enough for what the
    # callback reads; verify against dag_utils if the callback changes.
    context = {
        "dag": dag,
        "task": task,
        "run_id": "test_run_id",
        "params": {"pagerduty_conn_id": "pagerduty"},
    }
    pagerduty_on_failure_callback(context)

    assert mock_hook.pagerduty_conn_id == "pagerduty"
    assert (
        mock_hook.summary
        == "Execution failure in DAG test-dag run test_run_id at task print_date."
    )
    assert mock_hook.severity == "error"
    assert mock_hook.source == "test-dag"
