Skip to content

Commit 3138e99

Browse files
fix: IOC image version select issue (#3021)
Co-authored-by: Navin Soni <[email protected]>
1 parent 80c3c13 commit 3138e99

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/sagemaker/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,9 @@ def _compilation_job_config(
628628
"Framework": framework.upper(),
629629
}
630630

631+
multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
631632
if (
632-
framework.lower() == "pytorch"
633-
or framework.lower() == "tensorflow"
633+
framework.lower() in multiple_version_supported_framework_list
634634
and target_instance_type is not None
635635
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
636636
and framework_version is not None

tests/unit/sagemaker/model/test_neo.py

+23
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,29 @@ def test_compile_with_pytorch_neo_in_ml_inf(session):
330330
)
331331

332332

333+
@patch("sagemaker.session.Session")
334+
def test_compile_with_tensorflow_neo_in_ml_inf(session):
335+
session.return_value.boto_region_name = REGION
336+
337+
model = _create_model()
338+
model.compile(
339+
target_instance_family="ml_inf",
340+
input_shape={"data": [1, 3, 1024, 1024]},
341+
output_path="s3://output",
342+
role="role",
343+
framework="tensorflow",
344+
framework_version="1.15",
345+
job_name="compile-model",
346+
)
347+
348+
assert (
349+
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-inference-tensorflow:1.15-cpu-py3".format(
350+
NEO_REGION_ACCOUNT, REGION
351+
)
352+
!= model.image_uri
353+
)
354+
355+
333356
def test_compile_validates_framework_version(sagemaker_session):
334357
sagemaker_session.wait_for_compilation_job = Mock(
335358
return_value={

0 commit comments

Comments
 (0)