import sagemaker
from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
from sagemaker.model_monitor import DefaultModelMonitor


class SageMakerMonitoringBaselineOperator(SageMakerBaseOperator):
    """Run a SageMaker Model Monitor baselining job via ``DefaultModelMonitor``.

    ``execute`` builds a ``sagemaker.model_monitor.DefaultModelMonitor`` from
    ``config["ModelMonitorConfig"]`` and calls its ``suggest_baseline`` method
    with ``config["BaselineConfig"]``. It then returns the job description
    wrapped as ``{"MonitorBaselining": <description>}``.

    :param config: Operator configuration. It must contain
        ``"ModelMonitorConfig"``, a mapping of keyword arguments for
        ``DefaultModelMonitor``. It must also contain ``"BaselineConfig"``, a
        mapping of keyword arguments for ``suggest_baseline``. An optional
        ``"RoleArn"`` entry is rewritten by ``expand_role``.
    :param wait_for_completion: Forwarded as ``wait`` to ``suggest_baseline``.
    :param action_if_job_exists: Must be ``"increment"`` or ``"fail"``. Any
        other value raises ``AirflowException`` at construction time.
    """

    def __init__(
        self,
        *,
        config: dict,
        wait_for_completion: bool = True,
        action_if_job_exists: str = "increment",
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        if action_if_job_exists not in ("increment", "fail"):
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )
        # NOTE(review): this value is validated and stored but not read anywhere
        # in this class. Confirm whether the base operator consults it;
        # otherwise it currently has no effect on job naming.
        self.action_if_job_exists = action_if_job_exists
        self.wait_for_completion = wait_for_completion

    def expand_role(self) -> None:
        """Rewrite ``config["RoleArn"]`` using ``AwsBaseHook.expand_role``.

        Does nothing when the config has no ``"RoleArn"`` key.
        Presumably invoked from the base class's ``preprocess_config``;
        confirm against the base operator.
        """
        if "RoleArn" in self.config:
            hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
            self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"])

    def execute(self, context) -> dict:
        """Preprocess the config, run the baselining job, and return its description.

        :param context: Airflow task context (unused here).
        :return: ``{"MonitorBaselining": <job description dict>}``.
        :raises AirflowException: If the description does not report HTTP 200,
            or if the config is missing a required section.
        """
        self.preprocess_config()
        self.log.info("Suggesting SageMaker Model Baseline.")
        response = self.create_baseline_job(
            self.config, wait_for_completion=self.wait_for_completion
        )
        # Tolerant lookups: a response without ResponseMetadata is reported as a
        # baselining failure with the full payload, not as a bare KeyError.
        metadata = response.get("ResponseMetadata") or {}
        if metadata.get("HTTPStatusCode") != 200:
            raise AirflowException(
                f"SageMaker Model Monitor baselining failed: {response}"
            )
        return {"MonitorBaselining": response}

    def create_baseline_job(self, config: dict, wait_for_completion: bool = True) -> dict:
        """Create a ``DefaultModelMonitor`` and run ``suggest_baseline`` on it.

        :param config: Mapping containing ``"ModelMonitorConfig"`` (keyword
            arguments for ``DefaultModelMonitor``) and ``"BaselineConfig"``
            (keyword arguments for ``suggest_baseline``).
        :param wait_for_completion: Forwarded as ``wait`` to ``suggest_baseline``.
        :return: The result of ``describe()`` on the object returned by
            ``suggest_baseline``.
        :raises AirflowException: If either required config section is missing.
        """
        # Validate before creating any AWS session so a malformed config fails
        # fast with an actionable message.
        missing = [
            key for key in ("ModelMonitorConfig", "BaselineConfig") if key not in config
        ]
        if missing:
            raise AirflowException(
                "SageMaker Model Monitor baselining config is missing required "
                f"key(s): {', '.join(missing)}."
            )

        sagemaker_session = sagemaker.session.Session(self.hook.get_session())

        self.log.info("creating default monitor")
        default_monitor = DefaultModelMonitor(
            **config["ModelMonitorConfig"], sagemaker_session=sagemaker_session
        )

        self.log.info("suggesting baseline")
        response = default_monitor.suggest_baseline(
            **config["BaselineConfig"], wait=wait_for_completion
        )
        return response.describe()
