34
34
XGBOOST_FRAMEWORK = "xgboost"
35
35
SKLEARN_FRAMEWORK = "sklearn"
36
36
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37
+ INFERENCE_GRAVITON = "inference_graviton"
37
38
38
39
39
40
@override_pipeline_parameter_var
@@ -75,8 +76,8 @@ def retrieve(
75
76
accelerator_type (str): Elastic Inference accelerator type. For more, see
76
77
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
77
78
image_scope (str): The image type, i.e. what it is used for.
78
- Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
79
- ``image_scope`` is ignored.
79
+ Valid values: "training", "inference", "inference_graviton", " eia".
80
+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
80
81
container_version (str): the version of docker image.
81
82
Ideally the value of parameter should be created inside the framework.
82
83
For custom use, see the list of supported container versions:
@@ -146,8 +147,9 @@ def retrieve(
146
147
)
147
148
148
149
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK ):
150
+ final_image_scope = image_scope
149
151
config = _config_for_framework_and_scope (
150
- framework + "-training-compiler" , image_scope , accelerator_type
152
+ framework + "-training-compiler" , final_image_scope , accelerator_type
151
153
)
152
154
else :
153
155
_framework = framework
@@ -234,6 +236,7 @@ def retrieve(
234
236
tag = _get_image_tag (
235
237
container_version ,
236
238
distribution ,
239
+ final_image_scope ,
237
240
framework ,
238
241
inference_tool ,
239
242
instance_type ,
@@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
266
269
def _get_image_tag (
267
270
container_version ,
268
271
distribution ,
272
+ final_image_scope ,
269
273
framework ,
270
274
inference_tool ,
271
275
instance_type ,
@@ -276,20 +280,29 @@ def _get_image_tag(
276
280
):
277
281
"""Return image tag based on framework, container, and compute configuration(s)."""
278
282
instance_type_family = _get_instance_type_family (instance_type )
279
- if (
280
- framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK )
281
- and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
282
- ):
283
- version_to_arm64_tag_mapping = {
284
- "xgboost" : {
285
- "1.5-1" : "1.5-1-arm64" ,
286
- "1.3-1" : "1.3-1-arm64" ,
287
- },
288
- "sklearn" : {
289
- "1.0-1" : "1.0-1-arm64-cpu-py3" ,
290
- },
291
- }
292
- tag = version_to_arm64_tag_mapping [framework ][version ]
283
+ if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
284
+ if instance_type_family and final_image_scope == INFERENCE_GRAVITON :
285
+ _validate_arg (
286
+ instance_type_family ,
287
+ GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY ,
288
+ "instance type" ,
289
+ )
290
+ if (
291
+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
292
+ or final_image_scope == INFERENCE_GRAVITON
293
+ ):
294
+ version_to_arm64_tag_mapping = {
295
+ "xgboost" : {
296
+ "1.5-1" : "1.5-1-arm64" ,
297
+ "1.3-1" : "1.3-1-arm64" ,
298
+ },
299
+ "sklearn" : {
300
+ "1.0-1" : "1.0-1-arm64-cpu-py3" ,
301
+ },
302
+ }
303
+ tag = version_to_arm64_tag_mapping [framework ][version ]
304
+ else :
305
+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
293
306
else :
294
307
tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
295
308
@@ -375,7 +388,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375
388
framework in GRAVITON_ALLOWED_FRAMEWORKS
376
389
and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377
390
):
378
- return "inference_graviton"
391
+ return INFERENCE_GRAVITON
379
392
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
380
393
# Preserves backwards compatibility with XGB/SKLearn configs which no
381
394
# longer define top-level "scope" keys after introducing support for
0 commit comments