from dataclasses import dataclass
from typing import Any, Dict, Optional

from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import SageMakerEndpointConfigOperator
from airflow.utils.task_group import TaskGroup

from conductor.internal.dag_utils import deep_update
from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper
from conductor.utils.naming import timestamp_name


@dataclass()
class SageMakerEndpointOutput:
    """Values exposed to downstream tasks by the SageMaker endpoint task group."""

    # Endpoint config name. May be a Jinja-templated string (e.g. an XCom pull)
    # that only resolves to the concrete name when a task renders it at runtime.
    endpoint_config_name: str


class ConfiguredSageMakerEndpointOperator(IConfiguredOperator):
    """Build a TaskGroup that creates a SageMaker endpoint config and deploys an endpoint from it."""

    def generate_tasks(
        self,
        *,
        task_id: str,
        model_name: str,
        endpoint_name: str,
        initial_instance_count: int = 1,
        instance_type: str = "ml.t2.medium",
        config: Optional[Dict[str, Any]] = None,
        aws_conn_id: Optional[str] = None,
    ) -> TaskWrapper[TaskGroup, SageMakerEndpointOutput]:
        """Create a ``config`` -> ``endpoint`` task chain inside a TaskGroup.

        Args:
            task_id: Group id of the created TaskGroup. Child tasks are named
                ``config`` and ``endpoint`` inside that group.
            model_name: SageMaker model served by the single ``"main"`` production variant.
            endpoint_name: Name of the endpoint to deploy. Also used as the base for the
                generated endpoint config name.
            initial_instance_count: ``InitialInstanceCount`` of the production variant.
            instance_type: ``InstanceType`` of the production variant.
            config: Optional overrides applied onto the generated endpoint config
                via ``deep_update``.
            aws_conn_id: Airflow connection id passed to both SageMaker operators.
                Defaults to ``None``, the value previously hard-coded here.

        Returns:
            A TaskWrapper around the TaskGroup. Its output's ``endpoint_config_name``
            is a Jinja template string that pulls the config name from the
            ``config`` task's XCom when rendered. It is not a literal name.
        """
        # NOTE(review): 63 is presumably the maximum length passed to
        # timestamp_name -- confirm against conductor.utils.naming.
        endpoint_config_name = timestamp_name(endpoint_name, 63)
        base_config: Dict[str, Any] = {
            "EndpointConfigName": endpoint_config_name,
            "ProductionVariants": [
                {
                    "VariantName": "main",
                    "ModelName": model_name,
                    "InitialInstanceCount": initial_instance_count,
                    "InstanceType": instance_type,
                }
            ],
        }
        if config is not None:
            # Mutates the local base_config in place. NOTE(review): how deep_update
            # treats lists such as ProductionVariants (replace vs. merge) is
            # defined in conductor.internal.dag_utils -- confirm before relying on it.
            deep_update(base_config, config)

        with TaskGroup(
            task_id, tooltip="Endpoint Config + Endpoint Deploy", dag=self.dag
        ) as task_group:
            endpoint_config = SageMakerEndpointConfigOperator(
                config=base_config,
                aws_conn_id=aws_conn_id,
                task_id="config",
                dag=self.dag,
            )
            # Read the config name back from the config task's XCom instead of
            # reusing the local endpoint_config_name. Then the endpoint is deployed
            # from whatever name was actually created, including any
            # EndpointConfigName override supplied through `config`.
            # endpoint_config.task_id already carries the TaskGroup prefix.
            endpoint_config_output_name = (
                "{{ti.xcom_pull(task_ids='"
                + endpoint_config.task_id
                + "')['EndpointConfig']['EndpointConfigName']}}"
            )
            endpoint = SageMakerEndpointOperator(
                config={
                    "EndpointName": endpoint_name,
                    "EndpointConfigName": endpoint_config_output_name,
                },
                aws_conn_id=aws_conn_id,
                task_id="endpoint",
                dag=self.dag,
            )
            endpoint_config >> endpoint

        return TaskWrapper(
            task_group,
            SageMakerEndpointOutput(endpoint_config_name=endpoint_config_output_name),
        )
