Skip to content

Commit a839f9f

Browse files
author
Mark Bunday
committed
fix: Return ARM XGB/SKLearn tags if image_scope is inference_graviton
1 parent 885423c commit a839f9f

File tree

2 files changed

+69
-18
lines changed

2 files changed

+69
-18
lines changed

src/sagemaker/image_uris.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
XGBOOST_FRAMEWORK = "xgboost"
3535
SKLEARN_FRAMEWORK = "sklearn"
3636
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37+
INFERENCE_GRAVITON = "inference_graviton"
3738

3839

3940
@override_pipeline_parameter_var
@@ -146,8 +147,9 @@ def retrieve(
146147
)
147148

148149
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
150+
final_image_scope = image_scope
149151
config = _config_for_framework_and_scope(
150-
framework + "-training-compiler", image_scope, accelerator_type
152+
framework + "-training-compiler", final_image_scope, accelerator_type
151153
)
152154
else:
153155
_framework = framework
@@ -234,6 +236,7 @@ def retrieve(
234236
tag = _get_image_tag(
235237
container_version,
236238
distribution,
239+
final_image_scope,
237240
framework,
238241
inference_tool,
239242
instance_type,
@@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
266269
def _get_image_tag(
267270
container_version,
268271
distribution,
272+
final_image_scope,
269273
framework,
270274
inference_tool,
271275
instance_type,
@@ -276,20 +280,27 @@ def _get_image_tag(
276280
):
277281
"""Return image tag based on framework, container, and compute configuration(s)."""
278282
instance_type_family = _get_instance_type_family(instance_type)
279-
if (
280-
framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
281-
and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
282-
):
283-
version_to_arm64_tag_mapping = {
284-
"xgboost": {
285-
"1.5-1": "1.5-1-arm64",
286-
"1.3-1": "1.3-1-arm64",
287-
},
288-
"sklearn": {
289-
"1.0-1": "1.0-1-arm64-cpu-py3",
290-
},
291-
}
292-
tag = version_to_arm64_tag_mapping[framework][version]
283+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
284+
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
285+
_validate_arg(
286+
instance_type_family,
287+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
288+
"instance type for Graviton images",
289+
)
290+
if (
291+
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
292+
or final_image_scope == INFERENCE_GRAVITON
293+
):
294+
version_to_arm64_tag_mapping = {
295+
"xgboost": {
296+
"1.5-1": "1.5-1-arm64",
297+
"1.3-1": "1.3-1-arm64",
298+
},
299+
"sklearn": {
300+
"1.0-1": "1.0-1-arm64-cpu-py3",
301+
},
302+
}
303+
tag = version_to_arm64_tag_mapping[framework][version]
293304
else:
294305
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
295306

@@ -375,7 +386,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375386
framework in GRAVITON_ALLOWED_FRAMEWORKS
376387
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377388
):
378-
return "inference_graviton"
389+
return INFERENCE_GRAVITON
379390
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
380391
# Preserves backwards compatibility with XGB/SKLearn configs which no
381392
# longer define top-level "scope" keys after introducing support for

tests/unit/sagemaker/image_uris/test_graviton.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
9090

9191

92-
def test_graviton_xgboost(graviton_xgboost_versions):
92+
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
9393
for xgboost_version in graviton_xgboost_versions:
9494
for instance_type in GRAVITON_INSTANCE_TYPES:
9595
uri = image_uris.retrieve(
@@ -102,6 +102,33 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102
assert expected == uri
103103

104104

105+
def test_graviton_xgboost_image_scope_specified(graviton_xgboost_versions):
106+
for xgboost_version in graviton_xgboost_versions:
107+
for instance_type in GRAVITON_INSTANCE_TYPES:
108+
uri = image_uris.retrieve(
109+
"xgboost", "us-west-2", version=xgboost_version, image_scope="inference_graviton"
110+
)
111+
expected = (
112+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+
f"{xgboost_version}-arm64"
114+
)
115+
assert expected == uri
116+
117+
118+
def test_graviton_xgboost_image_scope_specified_x86_instance(graviton_xgboost_versions):
119+
for xgboost_version in graviton_xgboost_versions:
120+
for instance_type in GRAVITON_INSTANCE_TYPES:
121+
with pytest.raises(ValueError) as error:
122+
image_uris.retrieve(
123+
"xgboost",
124+
"us-west-2",
125+
version=xgboost_version,
126+
image_scope="inference_graviton",
127+
instance_type="ml.m5.xlarge",
128+
)
129+
assert f"Unsupported instance type for Graviton images: m5." in str(error)
130+
131+
105132
def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versions):
106133
for xgboost_version in graviton_xgboost_unsupported_versions:
107134
for instance_type in GRAVITON_INSTANCE_TYPES:
@@ -112,7 +139,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112139
assert f"Unsupported xgboost version: {xgboost_version}." in str(error)
113140

114141

115-
def test_graviton_sklearn(graviton_sklearn_versions):
142+
def test_graviton_sklearn_instance_type_specified(graviton_sklearn_versions):
116143
for sklearn_version in graviton_sklearn_versions:
117144
for instance_type in GRAVITON_INSTANCE_TYPES:
118145
uri = image_uris.retrieve(
@@ -125,6 +152,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125152
assert expected == uri
126153

127154

155+
def test_graviton_sklearn_image_scope_specified(graviton_sklearn_versions):
156+
for sklearn_version in graviton_sklearn_versions:
157+
for instance_type in GRAVITON_INSTANCE_TYPES:
158+
uri = image_uris.retrieve(
159+
"sklearn", "us-west-2", version=sklearn_version, image_scope="inference_graviton"
160+
)
161+
expected = (
162+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
163+
f"{sklearn_version}-arm64-cpu-py3"
164+
)
165+
assert expected == uri
166+
167+
128168
def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versions):
129169
for sklearn_version in graviton_sklearn_unsupported_versions:
130170
for instance_type in GRAVITON_INSTANCE_TYPES:

0 commit comments

Comments
 (0)