from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper
from conductor.types.model import Model
from conductor.utils.naming import timestamp_name


@dataclass
class SageMakerModelOutput:
    """Output values produced alongside a SageMaker model task.

    ``ConfiguredSageMakerModelOperator.generate_tasks`` fills ``model_name``
    with a Jinja template string that pulls ``['Model']['ModelName']`` from the
    model task's XCom, rather than with a literal model name.
    """

    model_name: str

class ConfiguredSageMakerModelOperator(IConfiguredOperator):
    """Factory for a ``SageMakerModelOperator`` configured from project resources."""

    def generate_tasks(
        self,
        *,
        task_id: str,
        model_name: str,
        model_cls: Type[Model],
        model_s3_path: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> TaskWrapper[SageMakerModelOperator, SageMakerModelOutput]:
        """Create the SageMaker model task and a templated handle to its name.

        Args:
            task_id: Airflow task id requested for the created operator.
            model_name: Base model name. The environment name is appended and
                the result is passed through ``timestamp_name`` with ``63`` as
                its second argument (presumably a length cap matching the
                SageMaker model-name limit -- confirm in
                ``conductor.utils.naming``).
            model_cls: Model class whose ``__module__`` and ``__name__`` are
                exported to the container as ``MODEL_CLS_MODULE`` and
                ``MODEL_CLS_NAME``.
            model_s3_path: Value used for the container's ``ModelDataUrl``.
            config: Optional overrides applied on top of the generated config
                with ``deep_update``.

        Returns:
            A ``TaskWrapper`` pairing the operator with a
            ``SageMakerModelOutput`` whose ``model_name`` is a Jinja template
            that reads ``['Model']['ModelName']`` from the operator's XCom.
        """
        resources = self.project_resources
        env_name = resources.env_name
        env_config = resources.env
        # Keep the caller-supplied ``model_name`` intact; derive the deployed
        # name under a separate local.
        resolved_model_name = timestamp_name(f"{model_name}-{env_name}", 63)
        base_config: Dict[str, Any] = {
            "ModelName": resolved_model_name,
            "PrimaryContainer": {
                "Image": resources.ecr_url(),
                "ModelDataUrl": model_s3_path,
                "Environment": {
                    "ENV": env_name,
                    "GIT_BRANCH": resources.branch,
                    "AWS_DEFAULT_REGION": env_config.default_region,
                    "MODEL_CLS_MODULE": model_cls.__module__,
                    "MODEL_CLS_NAME": model_cls.__name__,
                },
            },
            "ExecutionRoleArn": resources.sagemaker_execution_role(),
            "EnableNetworkIsolation": False,
        }
        if env_config.vpc is not None:
            base_config["VpcConfig"] = {
                "SecurityGroupIds": env_config.vpc.security_groups,
                "Subnets": env_config.vpc.subnets,
            }
        if config is not None:
            # The return value is not used, so ``deep_update`` is relied on to
            # modify ``base_config`` in place.
            deep_update(base_config, config)

        operator = SageMakerModelOperator(
            config=base_config, task_id=task_id, aws_conn_id=None, dag=self.dag
        )
        # Read the id back from the operator instead of reusing the raw
        # ``task_id`` argument: Airflow prefixes the id with the enclosing
        # TaskGroup's id by default, and ``xcom_pull`` must target the id the
        # task is actually registered under. Outside a TaskGroup the two are
        # the same string. The name is pulled from XCom (not taken from
        # ``resolved_model_name``) so a ``ModelName`` override in ``config``
        # is reflected as well.
        model_name_output = (
            "{{ti.xcom_pull(task_ids='"
            + operator.task_id
            + "')['Model']['ModelName']}}"
        )
        return TaskWrapper[SageMakerModelOperator, SageMakerModelOutput](
            operator,
            SageMakerModelOutput(model_name_output),
        )
