import copy
import unittest
from unittest import mock

import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from botocore.exceptions import ClientError
from parameterized import parameterized

from twitch_airflow_components.operators.sagemaker_model_monitoring_schedule import SageMakerMonitoringScheduleOperator

# Shared fixtures for the SageMakerMonitoringScheduleOperator tests below.
schedule_name = "test-schedule-name"
monitoring_job_name = "test-job-name"

# Baseline monitoring-schedule config passed to the operator as ``config``.
# The "{{ ... }}" values are placeholder strings; the tests in this file call
# ``execute`` / ``create_monitoring_job`` directly and pass them through as-is.
create_scheduling_params = {
    "MonitoringScheduleName": schedule_name,
    "MonitoringScheduleConfig": {
        "ScheduleConfig": {"ScheduleExpression": "cron(0 * ? * * *)"},  # runs hourly
        "MonitoringJobDefinition": {
            "BaselineConfig": {
                "BaseliningJobName": "BaseliningJobName",
                "ConstraintsResource": {"S3Uri": "{{ S3Uri }}"},
                "StatisticsResource": {"S3Uri": "{{ S3Uri }}"},
            },
            "MonitoringInputs": [
                {
                    "EndpointInput": {
                        "EndpointName": "EndpointName",
                        "LocalPath": "{{ Local Path }}",
                        "S3InputMode": "File",
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
            ],
            "MonitoringOutputConfig": {
                "MonitoringOutputs": [
                    {
                        "S3Output": {
                            "LocalPath": "{{ Local Path }}",
                            "S3UploadMode": "EndOfJob",
                            "S3Uri": "{{ S3Uri }}",
                        }
                    },
                ],
                "KmsKeyId": "KmsKeyID",
            },
            "MonitoringResources": {
                # InstanceCount / VolumeSizeInGB are the integer fields that
                # test_integer_fields_are_set expects the operator to report.
                "ClusterConfig": {
                    "InstanceCount": 2,
                    "InstanceType": "ml.p2.xlarge",
                    "VolumeSizeInGB": 30,
                    "VolumeKmsKeyId": "{{ kms_key }}",
                }
            },
            "MonitoringAppSpecification": {
                "ContainerArguments": ["container_arg"],
                "ContainerEntrypoint": ["container_entrypoint"],
                "ImageUri": "{{ image_uri }}",
            },
            "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        },
        "MonitoringJobDefinitionName": monitoring_job_name,
        "MonitoringType": "DataQuality",
    },
    # NOTE(review): AWS tags are normally [{"Key": ..., "Value": ...}]; this
    # placeholder shape presumably works only because the SageMaker client is
    # mocked in these tests — confirm before reusing against real AWS.
    "Tags": [{"{{ key }}": "{{ value }}"}],
}

# Same config plus a StoppingCondition (an extra integer field). Deep-copied so
# that adding the key does not mutate ``create_scheduling_params``. Despite the
# "processing" in its name, this is a monitoring-schedule config.
create_processing_params_with_stopping_condition = copy.deepcopy(
    create_scheduling_params
)
create_processing_params_with_stopping_condition["MonitoringScheduleConfig"][
    "MonitoringJobDefinition"
].update(StoppingCondition={"MaxRuntimeInSeconds": 3600})

# Stand-in for the client returned by the patched ``SageMakerHook.get_conn``.
# It is a single module-level object shared by every test that uses it, so its
# recorded calls accumulate across tests unless it is reset.
mock_sm_client = mock.Mock()


class TestSageMakerMonitoringScheduleOperator(unittest.TestCase):
    """Unit tests for ``SageMakerMonitoringScheduleOperator``.

    AWS is never contacted. Each test either patches ``SageMakerHook.get_conn``
    (some with the module-level ``mock_sm_client`` as its return value), or
    patches the operator's own ``create_monitoring_job`` /
    ``describe_monitoring_job`` / ``update_monitoring_job`` methods.
    """

    def setUp(self):
        # Common operator kwargs; each test supplies ``config=...`` itself.
        self.scheduling_config_kwargs = dict(
            task_id="test_sagemaker_operator",
            wait_for_completion=False,
            check_interval=5,
        )
        # ``mock_sm_client`` is a module-level mock shared by several tests.
        # Clear its recorded calls and configured return values / side effects.
        # Otherwise call-count assertions (e.g. ``assert_called_once_with``)
        # depend on which tests happened to run before this one.
        mock_sm_client.reset_mock(return_value=True, side_effect=True)

    @parameterized.expand(
        [
            (
                create_scheduling_params,
                [
                    [
                        "MonitoringScheduleConfig",
                        "MonitoringJobDefinition",
                        "MonitoringResources",
                        "ClusterConfig",
                        "InstanceCount",
                    ],
                    [
                        "MonitoringScheduleConfig",
                        "MonitoringJobDefinition",
                        "MonitoringResources",
                        "ClusterConfig",
                        "VolumeSizeInGB",
                    ],
                ],
            ),
            (
                create_processing_params_with_stopping_condition,
                [
                    [
                        "MonitoringScheduleConfig",
                        "MonitoringJobDefinition",
                        "MonitoringResources",
                        "ClusterConfig",
                        "InstanceCount",
                    ],
                    [
                        "MonitoringScheduleConfig",
                        "MonitoringJobDefinition",
                        "MonitoringResources",
                        "ClusterConfig",
                        "VolumeSizeInGB",
                    ],
                    [
                        "MonitoringScheduleConfig",
                        "MonitoringJobDefinition",
                        "StoppingCondition",
                        "MaxRuntimeInSeconds",
                    ],
                ],
            ),
        ]
    )
    def test_integer_fields_are_set(self, config, expected_fields):
        """``integer_fields`` lists the key paths of the config's integer values,
        including ``StoppingCondition.MaxRuntimeInSeconds`` when present."""
        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=config
        )
        assert sagemaker.integer_fields == expected_fields

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(
        SageMakerMonitoringScheduleOperator,
        "create_monitoring_job",
        return_value={
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        },
    )
    @mock.patch.object(
        SageMakerMonitoringScheduleOperator,
        "describe_monitoring_job",
        return_value={
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        },
    )
    def test_execute(self, mock_describe_job, mock_monitoring, mock_get_conn):
        """``execute`` forwards the config and waiter settings to
        ``create_monitoring_job`` and returns the response under
        ``"MonitorScheduling"``."""
        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=create_scheduling_params
        )
        retval = sagemaker.execute(None)

        mock_monitoring.assert_called_once_with(
            create_scheduling_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

        assert retval == {
            "MonitorScheduling": {
                "MonitoringScheduleArn": "testarn",
                "ResponseMetadata": {"HTTPStatusCode": 200},
            }
        }

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(
        SageMakerMonitoringScheduleOperator,
        "create_monitoring_job",
        return_value={
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 404},
        },
    )
    def test_execute_with_failure(self, mock_monitoring, mock_get_conn):
        """A non-200 ``HTTPStatusCode`` from ``create_monitoring_job`` makes
        ``execute`` raise ``AirflowException``."""
        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=create_scheduling_params
        )
        with pytest.raises(AirflowException):
            sagemaker.execute(None)

    @mock.patch.object(SageMakerHook, "get_conn", return_value=mock_sm_client)
    @mock.patch.object(
        SageMakerMonitoringScheduleOperator,
        "describe_monitoring_job",
        return_value={"MonitoringScheduleStatus": "Scheduled"},
    )
    def test_create_monitoring_job(self, mock_describe_job, mock_get_conn):
        """``create_monitoring_job`` passes the config as keyword arguments to
        the client's ``create_monitoring_schedule``."""
        mock_sm_client.create_monitoring_schedule.return_value = {
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }

        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=create_scheduling_params
        )
        sagemaker.create_monitoring_job(
            config=sagemaker.config,
            wait_for_completion=True,
            check_interval=1,
            max_ingestion_time=1,
        )

        # Relies on setUp having reset the shared client mock.
        mock_sm_client.create_monitoring_schedule.assert_called_once_with(
            **create_scheduling_params
        )

    @mock.patch.object(SageMakerHook, "get_conn", return_value=mock_sm_client)
    @mock.patch.object(SageMakerMonitoringScheduleOperator, "describe_monitoring_job")
    def test_create_monitoring_job_with_failure(
        self, mock_describe_job, mock_get_conn
    ):
        """``create_monitoring_job`` raises ``AirflowException`` when the schedule
        reports ``Failed`` and when it stays ``Pending`` past
        ``max_ingestion_time``.

        NOTE(review): with ``wait_for_completion=True`` and ``check_interval=1``
        the operator's waiter may sleep in real time; consider patching its
        sleep call if this test is slow.
        """
        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=create_scheduling_params
        )
        # Test for schedule failure
        mock_describe_job.return_value = {
            "MonitoringScheduleStatus": "Failed",
            "FailureReason": "Test failure",
        }
        mock_sm_client.create_monitoring_schedule.return_value = {
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        with pytest.raises(AirflowException):
            sagemaker.create_monitoring_job(
                config=sagemaker.config,
                wait_for_completion=True,
                check_interval=1,
                max_ingestion_time=1,
            )

        # test for time out
        mock_describe_job.return_value = {"MonitoringScheduleStatus": "Pending"}
        mock_sm_client.create_monitoring_schedule.return_value = {
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        with pytest.raises(AirflowException):
            sagemaker.create_monitoring_job(
                config=sagemaker.config,
                wait_for_completion=True,
                check_interval=1,
                max_ingestion_time=2,
            )

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerMonitoringScheduleOperator, "create_monitoring_job")
    @mock.patch.object(SageMakerMonitoringScheduleOperator, "update_monitoring_job")
    def test_execute_with_duplicate_monitor_creation(
        self, mock_monitor_update, mock_monitor, mock_get_conn
    ):
        """If ``create_monitoring_job`` raises a ``ValidationException``
        ``ClientError`` (schedule already exists), ``execute`` falls back to
        ``update_monitoring_job`` exactly once."""
        response = {
            "Error": {
                "Code": "ValidationException",
                "Message": "Cannot create already existing monitor.",
            }
        }
        mock_monitor.side_effect = ClientError(
            error_response=response, operation_name="CreateMonitoringJob"
        )
        mock_monitor_update.return_value = {
            "MonitoringScheduleArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        sagemaker = SageMakerMonitoringScheduleOperator(
            **self.scheduling_config_kwargs, config=create_scheduling_params
        )
        sagemaker.execute(None)
        mock_monitor_update.assert_called_once()
