Skip to content

Commit ab2b4f4

Browse files
author
Mark Bunday
committed
fix: Return ARM XGB/SKLearn tags if image_scope is inference_graviton
1 parent 885423c commit ab2b4f4

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

src/sagemaker/image_uris.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
XGBOOST_FRAMEWORK = "xgboost"
3535
SKLEARN_FRAMEWORK = "sklearn"
3636
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37+
INFERENCE_GRAVITON = "inference_graviton"
3738

3839

3940
@override_pipeline_parameter_var
@@ -234,6 +235,7 @@ def retrieve(
234235
tag = _get_image_tag(
235236
container_version,
236237
distribution,
238+
final_image_scope,
237239
framework,
238240
inference_tool,
239241
instance_type,
@@ -266,6 +268,7 @@ def _get_instance_type_family(instance_type):
266268
def _get_image_tag(
267269
container_version,
268270
distribution,
271+
final_image_scope,
269272
framework,
270273
inference_tool,
271274
instance_type,
@@ -276,9 +279,9 @@ def _get_image_tag(
276279
):
277280
"""Return image tag based on framework, container, and compute configuration(s)."""
278281
instance_type_family = _get_instance_type_family(instance_type)
279-
if (
280-
framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
281-
and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
282+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK) and (
283+
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
284+
or final_image_scope == INFERENCE_GRAVITON
282285
):
283286
version_to_arm64_tag_mapping = {
284287
"xgboost": {
@@ -375,7 +378,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375378
framework in GRAVITON_ALLOWED_FRAMEWORKS
376379
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377380
):
378-
return "inference_graviton"
381+
return INFERENCE_GRAVITON
379382
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
380383
# Preserves backwards compatibility with XGB/SKLearn configs which no
381384
# longer define top-level "scope" keys after introducing support for

tests/unit/sagemaker/image_uris/test_graviton.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
9090

9191

92-
def test_graviton_xgboost(graviton_xgboost_versions):
92+
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
9393
for xgboost_version in graviton_xgboost_versions:
9494
for instance_type in GRAVITON_INSTANCE_TYPES:
9595
uri = image_uris.retrieve(
@@ -102,6 +102,19 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102
assert expected == uri
103103

104104

105+
def test_graviton_xgboost_image_scope_specified(graviton_xgboost_versions):
106+
for xgboost_version in graviton_xgboost_versions:
107+
for instance_type in GRAVITON_INSTANCE_TYPES:
108+
uri = image_uris.retrieve(
109+
"xgboost", "us-west-2", version=xgboost_version, image_scope="inference_graviton"
110+
)
111+
expected = (
112+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+
f"{xgboost_version}-arm64"
114+
)
115+
assert expected == uri
116+
117+
105118
def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versions):
106119
for xgboost_version in graviton_xgboost_unsupported_versions:
107120
for instance_type in GRAVITON_INSTANCE_TYPES:
@@ -112,7 +125,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112125
assert f"Unsupported xgboost version: {xgboost_version}." in str(error)
113126

114127

115-
def test_graviton_sklearn(graviton_sklearn_versions):
128+
def test_graviton_sklearn_instance_type_specified(graviton_sklearn_versions):
116129
for sklearn_version in graviton_sklearn_versions:
117130
for instance_type in GRAVITON_INSTANCE_TYPES:
118131
uri = image_uris.retrieve(
@@ -125,6 +138,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125138
assert expected == uri
126139

127140

141+
def test_graviton_sklearn_image_scope_specified(graviton_sklearn_versions):
142+
for sklearn_version in graviton_sklearn_versions:
143+
for instance_type in GRAVITON_INSTANCE_TYPES:
144+
uri = image_uris.retrieve(
145+
"sklearn", "us-west-2", version=sklearn_version, image_scope="inference_graviton"
146+
)
147+
expected = (
148+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
149+
f"{sklearn_version}-arm64-cpu-py3"
150+
)
151+
assert expected == uri
152+
153+
128154
def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versions):
129155
for sklearn_version in graviton_sklearn_unsupported_versions:
130156
for instance_type in GRAVITON_INSTANCE_TYPES:

0 commit comments

Comments
 (0)