import io

from conductor.types.model import Model


class MockModel(Model):
    """Minimal ``Model`` subclass used as a stand-in by the tests in this file."""

    def __init__(self):
        """Defer entirely to the base ``Model`` initializer."""
        super().__init__()

    def train(self) -> None:
        """Do nothing; the mock has no training step."""
        return None

    def predict(self, data: io.StringIO) -> io.StringIO:
        """Return ``data`` unchanged."""
        return data


def test_sagemaker_model(c):
    """Check the config and ``model_name`` output of ``SageMakerModelOperator``.

    ``c`` is supplied by the test harness (presumably a fixture exposing the
    project's operators under ``c.operators`` -- confirm against the conftest).
    """
    operator = c.operators.SageMakerModelOperator(
        task_id="model",
        model_name="test-model",
        model_cls=MockModel,
        model_s3_path="s3://some/path/model.tar.gz",
    )

    # Build the expected config piece by piece, innermost mapping first.
    expected_environment = {
        "ENV": "test-env",
        "GIT_BRANCH": "test-branch",
        "AWS_DEFAULT_REGION": "us-west-2",
        "MODEL_CLS_MODULE": MockModel.__module__,
        "MODEL_CLS_NAME": MockModel.__name__,
    }
    expected_container = {
        "Image": "123456789012.dkr.ecr.us-west-2.amazonaws.com/test-project.test-env:test-branch-test-commit-hash",
        "ModelDataUrl": "s3://some/path/model.tar.gz",
        "Environment": expected_environment,
    }
    expected_config = {
        # The model name is compared as an unrendered template string.
        "ModelName": "test-model-test-env-{{(macros.time.time() * 1000) | int}}",
        "PrimaryContainer": expected_container,
        "ExecutionRoleArn": "arn:aws:iam::123456789012:role/0EKirPILJEXjpsva-sm-execution-role-us-west-2",
        "EnableNetworkIsolation": False,
    }
    assert operator.task.config == expected_config

    # The operator's output is likewise checked in its unrendered template form.
    expected_model_name_output = (
        "{{ti.xcom_pull(task_ids='model')['Model']['ModelName']}}"
    )
    assert operator.outputs.model_name == expected_model_name_output
