Skip to content

Commit 08c3a7b

Browse files
change: use recommended inference image uri from Neo API
1 parent 217e8c8 commit 08c3a7b

File tree

3 files changed

+12
-41
lines changed

3 files changed

+12
-41
lines changed

src/sagemaker/model.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -658,12 +658,7 @@ def compile(
658658
if target_instance_family == "ml_eia2":
659659
pass
660660
elif target_instance_family.startswith("ml_"):
661-
self.image_uri = self._compilation_image_uri(
662-
self.sagemaker_session.boto_region_name,
663-
target_instance_family,
664-
framework,
665-
framework_version,
666-
)
661+
self.image_uri = job_status["InferenceImage"]
667662
self._is_compiled_model = True
668663
else:
669664
LOGGER.warning(

tests/unit/sagemaker/model/test_neo.py

+4-27
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DESCRIBE_COMPILATION_JOB_RESPONSE = {
2727
"CompilationJobStatus": "Completed",
2828
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
29+
"InferenceImage": "inference-container-uri",
2930
}
3031

3132

@@ -52,12 +53,7 @@ def test_compile_model_for_inferentia(sagemaker_session):
5253
framework_version="1.15.0",
5354
job_name="compile-model",
5455
)
55-
assert (
56-
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format(
57-
NEO_REGION_ACCOUNT, REGION
58-
)
59-
== model.image_uri
60-
)
56+
assert DESCRIBE_COMPILATION_JOB_RESPONSE["InferenceImage"] == model.image_uri
6157
assert model._is_compiled_model is True
6258

6359

@@ -286,7 +282,7 @@ def test_compile_with_framework_version_15(session):
286282
job_name="compile-model",
287283
)
288284

289-
assert "1.5" in model.image_uri
285+
assert model.image_uri is not None
290286

291287

292288
@patch("sagemaker.session.Session")
@@ -304,26 +300,7 @@ def test_compile_with_framework_version_16(session):
304300
job_name="compile-model",
305301
)
306302

307-
assert "1.6" in model.image_uri
308-
309-
310-
@patch("sagemaker.session.Session")
311-
def test_compile_validates_framework_version(session):
312-
session.return_value.boto_region_name = REGION
313-
314-
model = _create_model()
315-
with pytest.raises(ValueError) as e:
316-
model.compile(
317-
target_instance_family="ml_c4",
318-
input_shape={"data": [1, 3, 1024, 1024]},
319-
output_path="s3://output",
320-
role="role",
321-
framework="pytorch",
322-
framework_version="1.6.1",
323-
job_name="compile-model",
324-
)
325-
326-
assert "Unsupported neo-pytorch version: 1.6.1." in str(e)
303+
assert model.image_uri is not None
327304

328305

329306
@patch("sagemaker.session.Session")

tests/unit/test_mxnet.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868

6969
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
7070

71+
INFERENCE_IMAGE_URI = "inference-uri"
72+
7173

7274
@pytest.fixture()
7375
def sagemaker_session():
@@ -83,7 +85,10 @@ def sagemaker_session():
8385
)
8486

8587
describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
86-
describe_compilation = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}}
88+
describe_compilation = {
89+
"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"},
90+
"InferenceImage": INFERENCE_IMAGE_URI,
91+
}
8792
session.sagemaker_client.create_model_package.side_effect = MODEL_PKG_RESPONSE
8893
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
8994
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
@@ -195,12 +200,6 @@ def _create_compilation_job(input_shape, output_location):
195200
}
196201

197202

198-
def _neo_inference_image(mxnet_version):
199-
return "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-inference-{}:{}-cpu-py3".format(
200-
FRAMEWORK.lower(), mxnet_version
201-
)
202-
203-
204203
@patch("sagemaker.estimator.name_from_base")
205204
@patch("sagemaker.utils.create_tar_file", MagicMock())
206205
def test_create_model(
@@ -422,7 +421,7 @@ def test_mxnet_neo(time, strftime, sagemaker_session, neo_mxnet_version):
422421
actual_compile_model_args = sagemaker_session.method_calls[3][2]
423422
assert expected_compile_model_args == actual_compile_model_args
424423

425-
assert compiled_model.image_uri == _neo_inference_image(neo_mxnet_version)
424+
assert compiled_model.image_uri == INFERENCE_IMAGE_URI
426425

427426
predictor = mx.deploy(1, CPU, use_compiled_model=True)
428427
assert isinstance(predictor, MXNetPredictor)

0 commit comments

Comments
 (0)