import unittest
from unittest import mock

import pytest
from airflow import AirflowException

from twitch_airflow_components.operators.sagemaker_model_monitoring_baseline import SageMakerMonitoringBaselineOperator

# Model-monitor processing resources used by the baseline job under test.
_MODEL_MONITOR_CONFIG = {
    "role": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
    "instance_count": 2,
    "instance_type": "ml.p2.xlarge",
    "volume_size_in_gb": 20,
}

# Baseline dataset / output settings. The "{{ S3Uri }}" values look like
# Jinja-style template placeholders; presumably they stay unrendered here
# because the tests call execute() directly — NOTE(review): confirm.
_BASELINE_CONFIG = {
    "baseline_dataset": "{{ S3Uri }}",
    "dataset_format": {
        "csv": {"header": True, "output_columns_position": "START"},
    },
    "output_s3_uri": "{{ S3Uri }}",
}

# Full operator config passed as ``config=`` to the operator in every test.
baseline_params = {
    "ModelMonitorConfig": _MODEL_MONITOR_CONFIG,
    "BaselineConfig": _BASELINE_CONFIG,
}


def _baseline_job_response(status_code):
    """Build a fake ``create_baseline_job`` return value with the given HTTP status."""
    return {
        "ProcessingJobName": "baseline-suggestion-job",
        "ProcessingJobArn": "testarn",
        "ResponseMetadata": {"HTTPStatusCode": status_code},
    }


class TestSageMakerMonitoringBaselineOperator(unittest.TestCase):
    """Tests for ``SageMakerMonitoringBaselineOperator.execute``.

    ``create_baseline_job`` is patched on the operator class in every test, so
    only the operator's handling of the canned response is exercised.
    """

    @mock.patch.object(
        SageMakerMonitoringBaselineOperator,
        "create_baseline_job",
        return_value=_baseline_job_response(200),
    )
    def test_execute(self, mock_baselining):
        """A 200 response succeeds; the job is requested once and awaited."""
        sagemaker = SageMakerMonitoringBaselineOperator(
            task_id="test_sagemaker_operator", config=baseline_params
        )
        sagemaker.execute(None)

        mock_baselining.assert_called_once_with(
            baseline_params, wait_for_completion=True
        )

    @mock.patch.object(
        SageMakerMonitoringBaselineOperator,
        "create_baseline_job",
        return_value=_baseline_job_response(404),
    )
    def test_execute_with_failure(self, mock_baselining):
        """A non-200 response from the baseline job raises ``AirflowException``."""
        sagemaker = SageMakerMonitoringBaselineOperator(
            task_id="test_sagemaker_operator", config=baseline_params
        )
        with pytest.raises(AirflowException):
            sagemaker.execute(None)

        # Make sure the exception came from inspecting the (failed) job
        # response, not from some error raised before the job was requested.
        mock_baselining.assert_called_with(
            baseline_params, wait_for_completion=True
        )
