Skip to content

Commit 4bdabac

Browse files
mabundayMark Bunday
authored and
Namrata Madan
committed
fix: Return ARM XGB/SKLearn tags if image_scope is inference_graviton (aws#3449)
Co-authored-by: Mark Bunday <[email protected]>
1 parent 663f635 commit 4bdabac

File tree

2 files changed

+87
-20
lines changed

2 files changed

+87
-20
lines changed

src/sagemaker/image_uris.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
XGBOOST_FRAMEWORK = "xgboost"
3535
SKLEARN_FRAMEWORK = "sklearn"
3636
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37+
INFERENCE_GRAVITON = "inference_graviton"
3738

3839

3940
@override_pipeline_parameter_var
@@ -75,8 +76,8 @@ def retrieve(
7576
accelerator_type (str): Elastic Inference accelerator type. For more, see
7677
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
7778
image_scope (str): The image type, i.e. what it is used for.
78-
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
79-
``image_scope`` is ignored.
79+
Valid values: "training", "inference", "inference_graviton", "eia".
80+
If ``accelerator_type`` is set, ``image_scope`` is ignored.
8081
container_version (str): the version of docker image.
8182
Ideally the value of parameter should be created inside the framework.
8283
For custom use, see the list of supported container versions:
@@ -146,8 +147,9 @@ def retrieve(
146147
)
147148

148149
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
150+
final_image_scope = image_scope
149151
config = _config_for_framework_and_scope(
150-
framework + "-training-compiler", image_scope, accelerator_type
152+
framework + "-training-compiler", final_image_scope, accelerator_type
151153
)
152154
else:
153155
_framework = framework
@@ -234,6 +236,7 @@ def retrieve(
234236
tag = _get_image_tag(
235237
container_version,
236238
distribution,
239+
final_image_scope,
237240
framework,
238241
inference_tool,
239242
instance_type,
@@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
266269
def _get_image_tag(
267270
container_version,
268271
distribution,
272+
final_image_scope,
269273
framework,
270274
inference_tool,
271275
instance_type,
@@ -276,20 +280,29 @@ def _get_image_tag(
276280
):
277281
"""Return image tag based on framework, container, and compute configuration(s)."""
278282
instance_type_family = _get_instance_type_family(instance_type)
279-
if (
280-
framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
281-
and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
282-
):
283-
version_to_arm64_tag_mapping = {
284-
"xgboost": {
285-
"1.5-1": "1.5-1-arm64",
286-
"1.3-1": "1.3-1-arm64",
287-
},
288-
"sklearn": {
289-
"1.0-1": "1.0-1-arm64-cpu-py3",
290-
},
291-
}
292-
tag = version_to_arm64_tag_mapping[framework][version]
283+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
284+
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
285+
_validate_arg(
286+
instance_type_family,
287+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
288+
"instance type",
289+
)
290+
if (
291+
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
292+
or final_image_scope == INFERENCE_GRAVITON
293+
):
294+
version_to_arm64_tag_mapping = {
295+
"xgboost": {
296+
"1.5-1": "1.5-1-arm64",
297+
"1.3-1": "1.3-1-arm64",
298+
},
299+
"sklearn": {
300+
"1.0-1": "1.0-1-arm64-cpu-py3",
301+
},
302+
}
303+
tag = version_to_arm64_tag_mapping[framework][version]
304+
else:
305+
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
293306
else:
294307
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
295308

@@ -375,7 +388,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375388
framework in GRAVITON_ALLOWED_FRAMEWORKS
376389
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377390
):
378-
return "inference_graviton"
391+
return INFERENCE_GRAVITON
379392
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
380393
# Preserves backwards compatibility with XGB/SKLearn configs which no
381394
# longer define top-level "scope" keys after introducing support for

tests/unit/sagemaker/image_uris/test_graviton.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
9090

9191

92-
def test_graviton_xgboost(graviton_xgboost_versions):
92+
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
9393
for xgboost_version in graviton_xgboost_versions:
9494
for instance_type in GRAVITON_INSTANCE_TYPES:
9595
uri = image_uris.retrieve(
@@ -102,6 +102,33 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102
assert expected == uri
103103

104104

105+
def test_graviton_xgboost_image_scope_specified(graviton_xgboost_versions):
106+
for xgboost_version in graviton_xgboost_versions:
107+
for instance_type in GRAVITON_INSTANCE_TYPES:
108+
uri = image_uris.retrieve(
109+
"xgboost", "us-west-2", version=xgboost_version, image_scope="inference_graviton"
110+
)
111+
expected = (
112+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+
f"{xgboost_version}-arm64"
114+
)
115+
assert expected == uri
116+
117+
118+
def test_graviton_xgboost_image_scope_specified_x86_instance(graviton_xgboost_versions):
119+
for xgboost_version in graviton_xgboost_versions:
120+
for instance_type in GRAVITON_INSTANCE_TYPES:
121+
with pytest.raises(ValueError) as error:
122+
image_uris.retrieve(
123+
"xgboost",
124+
"us-west-2",
125+
version=xgboost_version,
126+
image_scope="inference_graviton",
127+
instance_type="ml.m5.xlarge",
128+
)
129+
assert "Unsupported instance type: m5." in str(error)
130+
131+
105132
def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versions):
106133
for xgboost_version in graviton_xgboost_unsupported_versions:
107134
for instance_type in GRAVITON_INSTANCE_TYPES:
@@ -112,7 +139,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112139
assert f"Unsupported xgboost version: {xgboost_version}." in str(error)
113140

114141

115-
def test_graviton_sklearn(graviton_sklearn_versions):
142+
def test_graviton_sklearn_instance_type_specified(graviton_sklearn_versions):
116143
for sklearn_version in graviton_sklearn_versions:
117144
for instance_type in GRAVITON_INSTANCE_TYPES:
118145
uri = image_uris.retrieve(
@@ -125,6 +152,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125152
assert expected == uri
126153

127154

155+
def test_graviton_sklearn_image_scope_specified(graviton_sklearn_versions):
156+
for sklearn_version in graviton_sklearn_versions:
157+
for instance_type in GRAVITON_INSTANCE_TYPES:
158+
uri = image_uris.retrieve(
159+
"sklearn", "us-west-2", version=sklearn_version, image_scope="inference_graviton"
160+
)
161+
expected = (
162+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
163+
f"{sklearn_version}-arm64-cpu-py3"
164+
)
165+
assert expected == uri
166+
167+
128168
def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versions):
129169
for sklearn_version in graviton_sklearn_unsupported_versions:
130170
for instance_type in GRAVITON_INSTANCE_TYPES:
@@ -138,6 +178,20 @@ def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versi
138178
assert expected == uri
139179

140180

181+
def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_unsupported_versions):
182+
for sklearn_version in graviton_sklearn_unsupported_versions:
183+
for instance_type in GRAVITON_INSTANCE_TYPES:
184+
with pytest.raises(ValueError) as error:
185+
image_uris.retrieve(
186+
"sklearn",
187+
"us-west-2",
188+
version=sklearn_version,
189+
image_scope="inference_graviton",
190+
instance_type="ml.m5.xlarge",
191+
)
192+
assert "Unsupported instance type: m5." in str(error)
193+
194+
141195
def _expected_graviton_framework_uri(framework, version, region):
142196
return expected_uris.graviton_framework_uri(
143197
"{}-inference-graviton".format(framework),

0 commit comments

Comments
 (0)