Skip to content

Commit 64d49cb

Browse files
change: use recommended inference image uri from Neo API
1 parent 217e8c8 commit 64d49cb

File tree

3 files changed

+46
-49
lines changed

3 files changed

+46
-49
lines changed

src/sagemaker/model.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -658,12 +658,7 @@ def compile(
658658
if target_instance_family == "ml_eia2":
659659
pass
660660
elif target_instance_family.startswith("ml_"):
661-
self.image_uri = self._compilation_image_uri(
662-
self.sagemaker_session.boto_region_name,
663-
target_instance_family,
664-
framework,
665-
framework_version,
666-
)
661+
self.image_uri = job_status.get("InferenceImage", None)
667662
self._is_compiled_model = True
668663
else:
669664
LOGGER.warning(

tests/unit/sagemaker/model/test_neo.py

+38-35
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
MODEL_DATA = "s3://bucket/model.tar.gz"
2121
MODEL_IMAGE = "mi"
2222

23+
IMAGE_URI = "inference-container-uri"
24+
2325
REGION = "us-west-2"
2426

2527
NEO_REGION_ACCOUNT = "301217895009"
2628
DESCRIBE_COMPILATION_JOB_RESPONSE = {
2729
"CompilationJobStatus": "Completed",
2830
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
31+
"InferenceImage": IMAGE_URI,
2932
}
3033

3134

@@ -52,12 +55,7 @@ def test_compile_model_for_inferentia(sagemaker_session):
5255
framework_version="1.15.0",
5356
job_name="compile-model",
5457
)
55-
assert (
56-
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format(
57-
NEO_REGION_ACCOUNT, REGION
58-
)
59-
== model.image_uri
60-
)
58+
assert DESCRIBE_COMPILATION_JOB_RESPONSE["InferenceImage"] == model.image_uri
6159
assert model._is_compiled_model is True
6260

6361

@@ -271,11 +269,12 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem
271269
assert model.endpoint_name.startswith("{}-ml-c4".format(model_name))
272270

273271

274-
@patch("sagemaker.session.Session")
275-
def test_compile_with_framework_version_15(session):
276-
session.return_value.boto_region_name = REGION
272+
def test_compile_with_framework_version_15(sagemaker_session):
273+
sagemaker_session.wait_for_compilation_job = Mock(
274+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
275+
)
277276

278-
model = _create_model()
277+
model = _create_model(sagemaker_session)
279278
model.compile(
280279
target_instance_family="ml_c4",
281280
input_shape={"data": [1, 3, 1024, 1024]},
@@ -286,14 +285,15 @@ def test_compile_with_framework_version_15(session):
286285
job_name="compile-model",
287286
)
288287

289-
assert "1.5" in model.image_uri
288+
assert IMAGE_URI == model.image_uri
290289

291290

292-
@patch("sagemaker.session.Session")
293-
def test_compile_with_framework_version_16(session):
294-
session.return_value.boto_region_name = REGION
291+
def test_compile_with_framework_version_16(sagemaker_session):
292+
sagemaker_session.wait_for_compilation_job = Mock(
293+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
294+
)
295295

296-
model = _create_model()
296+
model = _create_model(sagemaker_session)
297297
model.compile(
298298
target_instance_family="ml_c4",
299299
input_shape={"data": [1, 3, 1024, 1024]},
@@ -304,26 +304,7 @@ def test_compile_with_framework_version_16(session):
304304
job_name="compile-model",
305305
)
306306

307-
assert "1.6" in model.image_uri
308-
309-
310-
@patch("sagemaker.session.Session")
311-
def test_compile_validates_framework_version(session):
312-
session.return_value.boto_region_name = REGION
313-
314-
model = _create_model()
315-
with pytest.raises(ValueError) as e:
316-
model.compile(
317-
target_instance_family="ml_c4",
318-
input_shape={"data": [1, 3, 1024, 1024]},
319-
output_path="s3://output",
320-
role="role",
321-
framework="pytorch",
322-
framework_version="1.6.1",
323-
job_name="compile-model",
324-
)
325-
326-
assert "Unsupported neo-pytorch version: 1.6.1." in str(e)
307+
assert IMAGE_URI == model.image_uri
327308

328309

329310
@patch("sagemaker.session.Session")
@@ -347,3 +328,25 @@ def test_compile_with_pytorch_neo_in_ml_inf(session):
347328
)
348329
!= model.image_uri
349330
)
331+
332+
333+
def test_compile_validates_framework_version(sagemaker_session):
334+
sagemaker_session.wait_for_compilation_job = Mock(
335+
return_value={
336+
"CompilationJobStatus": "Completed",
337+
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
338+
"InferenceImage": None,
339+
}
340+
)
341+
model = _create_model(sagemaker_session)
342+
model.compile(
343+
target_instance_family="ml_c4",
344+
input_shape={"data": [1, 3, 1024, 1024]},
345+
output_path="s3://output",
346+
role="role",
347+
framework="pytorch",
348+
framework_version="1.6.1",
349+
job_name="compile-model",
350+
)
351+
352+
assert model.image_uri is None

tests/unit/test_mxnet.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868

6969
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
7070

71+
INFERENCE_IMAGE_URI = "inference-uri"
72+
7173

7274
@pytest.fixture()
7375
def sagemaker_session():
@@ -83,7 +85,10 @@ def sagemaker_session():
8385
)
8486

8587
describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
86-
describe_compilation = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}}
88+
describe_compilation = {
89+
"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"},
90+
"InferenceImage": INFERENCE_IMAGE_URI,
91+
}
8792
session.sagemaker_client.create_model_package.side_effect = MODEL_PKG_RESPONSE
8893
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
8994
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
@@ -195,12 +200,6 @@ def _create_compilation_job(input_shape, output_location):
195200
}
196201

197202

198-
def _neo_inference_image(mxnet_version):
199-
return "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-inference-{}:{}-cpu-py3".format(
200-
FRAMEWORK.lower(), mxnet_version
201-
)
202-
203-
204203
@patch("sagemaker.estimator.name_from_base")
205204
@patch("sagemaker.utils.create_tar_file", MagicMock())
206205
def test_create_model(
@@ -422,7 +421,7 @@ def test_mxnet_neo(time, strftime, sagemaker_session, neo_mxnet_version):
422421
actual_compile_model_args = sagemaker_session.method_calls[3][2]
423422
assert expected_compile_model_args == actual_compile_model_args
424423

425-
assert compiled_model.image_uri == _neo_inference_image(neo_mxnet_version)
424+
assert compiled_model.image_uri == INFERENCE_IMAGE_URI
426425

427426
predictor = mx.deploy(1, CPU, use_compiled_model=True)
428427
assert isinstance(predictor, MXNetPredictor)

0 commit comments

Comments
 (0)