diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4fc0552d64..e04b83a14f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -789,7 +789,7 @@ def multi_version_compilation_supported( } if target_instance_type in NEO_IOC_TARGET_DEVICES: return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"] - if target_instance_type == "ml_inf": + if target_instance_type == "ml_inf1": return framework in multi_version_frameworks_support_mapping["inferentia"] return False diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index a865a8c111..82a7b40afd 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -47,7 +47,7 @@ def test_compile_model_for_inferentia(sagemaker_session): ) model = _create_model(sagemaker_session) model.compile( - target_instance_family="ml_inf", + target_instance_family="ml_inf1", input_shape={"data": [1, 3, 1024, 1024]}, output_path="s3://output", role="role", @@ -313,7 +313,7 @@ def test_compile_with_pytorch_neo_in_ml_inf(session): model = _create_model() model.compile( - target_instance_family="ml_inf", + target_instance_family="ml_inf1", input_shape={"data": [1, 3, 1024, 1024]}, output_path="s3://output", role="role", @@ -336,7 +336,7 @@ def test_compile_with_tensorflow_neo_in_ml_inf(session): model = _create_model() model.compile( - target_instance_family="ml_inf", + target_instance_family="ml_inf1", input_shape={"data": [1, 3, 1024, 1024]}, output_path="s3://output", role="role",