from conductor.internal.dag_utils import deep_update


def test_deep_update():
    """Exercise ``deep_update`` on flat, nested and SageMaker-style configs.

    The small cases check that ``deep_update`` modifies its first argument in
    place. The final training-job case checks the value returned by the call.
    """
    # Each entry: (dict to update, overrides, expected contents afterwards).
    in_place_cases = [
        # A new top-level key is added next to the existing one.
        ({"hello1": 1}, {"hello2": 2}, {"hello1": 1, "hello2": 2}),
        # A top-level scalar is replaced.
        ({"hello": "to_override"}, {"hello": "over"}, {"hello": "over"}),
        # Nested dicts are merged: the untouched sibling key survives.
        (
            {"hello": {"value": "to_override", "no_change": 1}},
            {"hello": {"value": "over"}},
            {"hello": {"value": "over", "no_change": 1}},
        ),
        # An empty-dict override replaces a nested scalar.
        (
            {"hello": {"value": "to_override", "no_change": 1}},
            {"hello": {"value": {}}},
            {"hello": {"value": {}, "no_change": 1}},
        ),
        # A scalar override replaces a nested empty dict.
        (
            {"hello": {"value": {}, "no_change": 1}},
            {"hello": {"value": 2}},
            {"hello": {"value": 2, "no_change": 1}},
        ),
    ]
    for target, patch, wanted in in_place_cases:
        deep_update(target, patch)
        assert target == wanted

    # A realistic training-job request. The overrides change one nested field,
    # add a new nested key and replace a top-level scalar; every other entry
    # must come through unchanged.
    training_image = (
        "123456789012.dkr.ecr.us-west-2.amazonaws.com/"
        "test-project.test-env:test-branch-test-commit-hash"
    )
    output_path = (
        "s3://test-project.test-env/test-branch/"
        "test-dag/{{run_id}}/training_custom_config"
    )
    environment = {
        "MODEL_CLS_MODULE": "tests.operators.test_sagemaker_training",
        "MODEL_CLS_NAME": "ModelTestClass",
        "CONDUCTOR_ENV": "test-env",
        "CONDUCTOR_GIT_BRANCH": "test-branch",
        "CONDUCTOR_COMMIT_HASH": "test-commit-hash",
        "AWS_DEFAULT_REGION": "us-west-2",
    }
    metric_definitions = [
        {"Name": "loss", "Regex": " loss: ([0-9.]+)"},
        {"Name": "offensive_loss", "Regex": "offensive_loss: ([0-9.]+)"},
        {"Name": "bias_loss", "Regex": "bias_loss: ([0-9.]+)"},
    ]
    new_role = "arn:aws:iam::123456789012:role/other-test-role"
    one_day_in_seconds = 86400

    # Fresh copies are used everywhere so the input, the overrides and the
    # expectation never share mutable objects.
    job_config = {
        "TrainingJobName": "test-job",
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.t3.medium",
            "VolumeSizeInGB": 1,
        },
        "AlgorithmSpecification": {
            "TrainingImage": training_image,
            "TrainingInputMode": "File",
        },
        "RoleArn": "test-role",
        "OutputDataConfig": {"S3OutputPath": output_path},
        "Environment": dict(environment),
        "StoppingCondition": {"MaxRuntimeInSeconds": one_day_in_seconds},
        "EnableNetworkIsolation": False,
    }
    job_overrides = {
        "ResourceConfig": {"InstanceType": "ml.p3.2xlarge"},
        "AlgorithmSpecification": {
            "MetricDefinitions": [dict(metric) for metric in metric_definitions],
        },
        "RoleArn": new_role,
    }
    expected_config = {
        "TrainingJobName": "test-job",
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.p3.2xlarge",
            "VolumeSizeInGB": 1,
        },
        "AlgorithmSpecification": {
            "TrainingImage": training_image,
            "TrainingInputMode": "File",
            "MetricDefinitions": [dict(metric) for metric in metric_definitions],
        },
        "RoleArn": new_role,
        "OutputDataConfig": {"S3OutputPath": output_path},
        "Environment": dict(environment),
        "StoppingCondition": {"MaxRuntimeInSeconds": one_day_in_seconds},
        "EnableNetworkIsolation": False,
    }

    assert deep_update(job_config, job_overrides) == expected_config


if __name__ == "__main__":
    # Running this module directly executes the test without pytest.
    for _run_test in (test_deep_update,):
        _run_test()
