from conductor.types.model import Model


class ModelTestClass(Model):
    """Empty ``Model`` subclass passed as ``model_cls`` in this module's test.

    The test's expected environment hard-codes this class's name
    (``MODEL_CLS_NAME``) and its module path (``MODEL_CLS_MODULE``), so
    renaming or moving the class requires updating those expected values.
    """


def test_sagemaker_training(c):
    """Verify the config and output template of ``SageMakerTrainingOperator``.

    Builds the operator with ``task_id="training"`` and ``ModelTestClass``,
    then checks that ``task.config`` equals the expected training-job
    dictionary and that ``outputs.s3_url`` equals the expected XCom-pull
    template string.

    Args:
        c: Object exposing ``operators.SageMakerTrainingOperator``
            (presumably a pytest fixture defined outside this file --
            TODO confirm).
    """
    operator = c.operators.SageMakerTrainingOperator(
        task_id="training",
        model_cls=ModelTestClass,
    )

    # Each section of the expected config is spelled out separately so a
    # failing comparison is easier to trace back to one section.
    expected_resources = {
        "InstanceCount": 1,
        "InstanceType": "ml.t3.medium",
        "VolumeSizeInGB": 1,
    }
    expected_algorithm = {
        "TrainingImage": "123456789012.dkr.ecr.us-west-2.amazonaws.com/test-project.test-env:test-branch-test-commit-hash",
        "TrainingInputMode": "File",
    }
    expected_output = {
        "S3OutputPath": "s3://test-project.test-env.123456789012/test-branch/test-dag/{{run_id}}/training"
    }
    expected_environment = {
        "MODEL_CLS_MODULE": "tests.operators.test_sagemaker_training",
        "MODEL_CLS_NAME": "ModelTestClass",
        "CONDUCTOR_ENV": "test-env",
        "CONDUCTOR_GIT_BRANCH": "test-branch",
        "CONDUCTOR_COMMIT_HASH": "test-commit-hash",
        "AWS_DEFAULT_REGION": "us-west-2",
    }
    expected_config = {
        "TrainingJobName": "test-project-training-test-dag-{{(macros.time.time() * 1000) | int}}",
        "ResourceConfig": expected_resources,
        "AlgorithmSpecification": expected_algorithm,
        "RoleArn": "arn:aws:iam::123456789012:role/0EKirPILJEXjpsva-sm-execution-role-us-west-2",
        "OutputDataConfig": expected_output,
        "Environment": expected_environment,
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},  # One day.
        "EnableNetworkIsolation": False,
    }
    assert operator.task.config == expected_config

    # The S3 URL output is a template that pulls the model artifact location
    # from the "training" task's XCom value.
    expected_s3_url = "{{ti.xcom_pull(task_ids='training')['Training']['ModelArtifacts']['S3ModelArtifacts']}}"
    assert operator.outputs.s3_url == expected_s3_url
