from dataclasses import dataclass
from typing import Any, Dict, Optional

from twitch_airflow_components.operators.sagemaker_model_monitoring_baseline import SageMakerMonitoringBaselineOperator

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper


@dataclass()
class SageMakerMonitoringBaselineOutputs:
    """Locations of the two result files produced by a monitoring-baseline task.

    ``ConfiguredSageMakerMonitoringBaselineOperator.generate_tasks`` fills both
    fields with Jinja template strings that resolve the job's S3 output prefix
    from XCom, so the values are only concrete URLs once the template is
    rendered.
    """

    # URL of ``statistics.json`` under the baseline job's S3 output prefix.
    statistics_s3_url: str
    # URL of ``constraints.json`` under the baseline job's S3 output prefix.
    constraints_s3_url: str


class ConfiguredSageMakerMonitoringBaselineOperator(IConfiguredOperator):
    """Factory for SageMaker model-monitoring baseline tasks.

    Builds a ``SageMakerMonitoringBaselineOperator`` whose execution role comes
    from ``self.project_resources`` and whose output location comes from
    ``self.dag_resources``. The caller supplies the baseline dataset and the
    processing-instance sizing.
    """

    def generate_tasks(
        self,
        *,
        task_id: str,
        baseline_dataset_uri: str,
        baseline_dataset_format: Dict[str, Any],
        instance_count: int = 1,
        instance_type: str = "ml.m5.large",
        volume_size_in_gb: int = 30,
        config: Optional[Dict[str, Any]] = None,
    ) -> TaskWrapper[
        SageMakerMonitoringBaselineOperator, SageMakerMonitoringBaselineOutputs
    ]:
        """Create the baseline task and describe where its results will be.

        Args:
            task_id: Task id given to the operator. It is also the path passed
                to ``self.dag_resources.s3_url_for_path`` and the XCom source
                used in the returned URL templates.
            baseline_dataset_uri: Sent as ``baseline_dataset`` in the
                ``BaselineConfig`` section.
            baseline_dataset_format: Sent as ``dataset_format`` in the
                ``BaselineConfig`` section.
            instance_count: ``instance_count`` in ``ModelMonitorConfig``.
            instance_type: ``instance_type`` in ``ModelMonitorConfig``.
            volume_size_in_gb: ``volume_size_in_gb`` in ``ModelMonitorConfig``.
            config: Optional overrides applied to the generated config via
                ``deep_update``. ``None`` keeps the generated config as-is.

        Returns:
            A ``TaskWrapper`` pairing two objects:

            - The operator, constructed with ``dag=self.dag`` and
              ``aws_conn_id=None``.
            - A ``SageMakerMonitoringBaselineOutputs``. Its URLs are Jinja
              template strings pointing at ``statistics.json`` and
              ``constraints.json`` under the job's S3 output URI, as read from
              this task's XCom.
        """
        # Role lookup happens first, then the S3 path, then any overrides.
        monitor_section: Dict[str, Any] = {
            "role": self.project_resources.sagemaker_execution_role(),
            "instance_count": instance_count,
            "instance_type": instance_type,
            "volume_size_in_gb": volume_size_in_gb,
        }
        baseline_section: Dict[str, Any] = {
            "baseline_dataset": baseline_dataset_uri,
            "dataset_format": baseline_dataset_format,
            # Presumably a per-task prefix inside the DAG's S3 area; confirm
            # against dag_resources.s3_url_for_path.
            "output_s3_uri": self.dag_resources.s3_url_for_path([task_id]),
        }
        operator_config: Dict[str, Any] = {
            "ModelMonitorConfig": monitor_section,
            "BaselineConfig": baseline_section,
        }

        # deep_update's return value is discarded. This relies on it merging
        # ``config`` into ``operator_config`` in place. NOTE(review): confirm.
        if config is not None:
            deep_update(operator_config, config)

        # Jinja expression (presumably rendered by Airflow at run time). It reads
        # the first processing output's S3 URI from the XCom value of this task.
        s3_output_template = "".join(
            (
                "{{ti.xcom_pull(task_ids='",
                task_id,
                "')['MonitorBaselining']['ProcessingOutputConfig']['Outputs'][0]['S3Output']['S3Uri']}}",
            )
        )

        wrapper_type = TaskWrapper[
            SageMakerMonitoringBaselineOperator, SageMakerMonitoringBaselineOutputs
        ]
        baseline_operator = SageMakerMonitoringBaselineOperator(
            config=operator_config,
            task_id=task_id,
            aws_conn_id=None,
            dag=self.dag,
        )
        # The file names are appended after the closing ``}}``, so they follow
        # the rendered S3 URI.
        outputs = SageMakerMonitoringBaselineOutputs(
            statistics_s3_url=s3_output_template + "/statistics.json",
            constraints_s3_url=s3_output_template + "/constraints.json",
        )
        return wrapper_type(baseline_operator, outputs)
