diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 317e865ddd..ffa6cf1a84 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -628,9 +628,9 @@ def _compilation_job_config( "Framework": framework.upper(), } + multiple_version_supported_framework_list = ["pytorch", "tensorflow"] if ( - framework.lower() == "pytorch" - or framework.lower() == "tensorflow" + framework.lower() in multiple_version_supported_framework_list and target_instance_type is not None and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None and framework_version is not None diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 2357c771f9..a865a8c111 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -330,6 +330,29 @@ def test_compile_with_pytorch_neo_in_ml_inf(session): ) +@patch("sagemaker.session.Session") +def test_compile_with_tensorflow_neo_in_ml_inf(session): + session.return_value.boto_region_name = REGION + + model = _create_model() + model.compile( + target_instance_family="ml_inf", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + framework_version="1.15", + job_name="compile-model", + ) + + assert ( + "{}.dkr.ecr.{}.amazonaws.com/sagemaker-inference-tensorflow:1.15-cpu-py3".format( + NEO_REGION_ACCOUNT, REGION + ) + != model.image_uri + ) + + def test_compile_validates_framework_version(sagemaker_session): sagemaker_session.wait_for_compilation_job = Mock( return_value={