25
25
from sagemaker .jumpstart import artifacts
26
26
from sagemaker .workflow import is_pipeline_variable
27
27
from sagemaker .workflow .utilities import override_pipeline_parameter_var
28
- from sagemaker .fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY , GRAVITON_ALLOWED_FRAMEWORKS
28
+ from sagemaker .fw_utils import (
29
+ GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY ,
30
+ GRAVITON_ALLOWED_FRAMEWORKS ,
31
+ GRAVITON_ALLOWED_REGIONS ,
32
+ SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
33
+ XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
34
+ )
29
35
30
36
logger = logging .getLogger (__name__ )
31
37
32
38
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
33
39
HUGGING_FACE_FRAMEWORK = "huggingface"
40
+ XGBOOST_FRAMEWORK = "xgboost"
41
+ SKLEARN_FRAMEWORK = "sklearn"
42
+ INSTANCE_TYPE_REGEX = r"^ml[\._]([a-z\d]+)\.?\w*$"
34
43
35
44
36
45
@override_pipeline_parameter_var
@@ -244,6 +253,25 @@ def retrieve(
244
253
if key in container_versions :
245
254
tag = "-" .join ([tag , container_versions [key ]])
246
255
256
+ if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
257
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
258
+ if match and match [1 ] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY :
259
+ _validate_arg (region , GRAVITON_ALLOWED_REGIONS , "Graviton region" )
260
+ if framework == XGBOOST_FRAMEWORK :
261
+ _validate_arg (
262
+ version ,
263
+ XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
264
+ "xgboost version for Graviton instances"
265
+ )
266
+ tag = f"{ version } -arm64"
267
+ else :
268
+ _validate_arg (
269
+ version ,
270
+ SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
271
+ "sklearn version for Graviton instances"
272
+ )
273
+ tag = f"{ version } -arm64-cpu-py3"
274
+
247
275
if tag :
248
276
repo += ":{}" .format (tag )
249
277
@@ -295,7 +323,7 @@ def config_for_framework(framework):
295
323
def _get_image_scope_for_instance_type (framework , instance_type , image_scope ):
296
324
"""Extract the image scope from instance type."""
297
325
if framework in GRAVITON_ALLOWED_FRAMEWORKS and isinstance (instance_type , str ):
298
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
326
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
299
327
if match and match [1 ] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY :
300
328
return "inference_graviton"
301
329
return image_scope
@@ -304,7 +332,7 @@ def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
304
332
def _get_inference_tool (inference_tool , instance_type ):
305
333
"""Extract the inference tool name from instance type."""
306
334
if not inference_tool and instance_type :
307
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
335
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
308
336
if match and match [1 ].startswith ("inf" ):
309
337
return "neuron"
310
338
return inference_tool
@@ -385,7 +413,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
385
413
processor = "neuron"
386
414
else :
387
415
# looks for either "ml.<family>.<size>" or "ml_<family>"
388
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
416
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
389
417
if match :
390
418
family = match [1 ]
391
419
@@ -415,7 +443,7 @@ def _should_auto_select_container_version(instance_type, distribution):
415
443
p4d = False
416
444
if instance_type :
417
445
# looks for either "ml.<family>.<size>" or "ml_<family>"
418
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
446
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
419
447
if match :
420
448
family = match [1 ]
421
449
p4d = family == "p4d"
0 commit comments