25
25
from sagemaker .jumpstart import artifacts
26
26
from sagemaker .workflow import is_pipeline_variable
27
27
from sagemaker .workflow .utilities import override_pipeline_parameter_var
28
- from sagemaker .fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY , GRAVITON_ALLOWED_FRAMEWORKS
28
+ from sagemaker .fw_utils import (
29
+ GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY ,
30
+ GRAVITON_ALLOWED_FRAMEWORKS ,
31
+ GRAVITON_ALLOWED_REGIONS ,
32
+ SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
33
+ XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS ,
34
+ )
29
35
30
36
logger = logging .getLogger (__name__ )
31
37
32
38
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
33
39
HUGGING_FACE_FRAMEWORK = "huggingface"
40
+ XGBOOST_FRAMEWORK = "xgboost"
41
+ SKLEARN_FRAMEWORK = "sklearn"
42
+ INSTANCE_TYPE_REGEX = r"^ml[\._]([a-z\d]+)\.?\w*$"
34
43
35
44
36
45
@override_pipeline_parameter_var
@@ -244,6 +253,18 @@ def retrieve(
244
253
if key in container_versions :
245
254
tag = "-" .join ([tag , container_versions [key ]])
246
255
256
+ if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
257
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
258
+ if match and match [1 ] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY :
259
+ _validate_arg (region , GRAVITON_ALLOWED_REGIONS , "Graviton region" )
260
+ arg_name = f"{ framework } version for Graviton instances"
261
+ if framework == XGBOOST_FRAMEWORK :
262
+ _validate_arg (version , XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS , arg_name )
263
+ tag = f"{ version } -arm64"
264
+ else :
265
+ _validate_arg (version , SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS , arg_name )
266
+ tag = f"{ version } -arm64-cpu-py3"
267
+
247
268
if tag :
248
269
repo += ":{}" .format (tag )
249
270
@@ -295,7 +316,7 @@ def config_for_framework(framework):
295
316
def _get_image_scope_for_instance_type (framework , instance_type , image_scope ):
296
317
"""Extract the image scope from instance type."""
297
318
if framework in GRAVITON_ALLOWED_FRAMEWORKS and isinstance (instance_type , str ):
298
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
319
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
299
320
if match and match [1 ] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY :
300
321
return "inference_graviton"
301
322
return image_scope
@@ -304,7 +325,7 @@ def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
304
325
def _get_inference_tool (inference_tool , instance_type ):
305
326
"""Extract the inference tool name from instance type."""
306
327
if not inference_tool and instance_type :
307
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
328
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
308
329
if match and match [1 ].startswith ("inf" ):
309
330
return "neuron"
310
331
return inference_tool
@@ -385,7 +406,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
385
406
processor = "neuron"
386
407
else :
387
408
# looks for either "ml.<family>.<size>" or "ml_<family>"
388
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
409
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
389
410
if match :
390
411
family = match [1 ]
391
412
@@ -415,7 +436,7 @@ def _should_auto_select_container_version(instance_type, distribution):
415
436
p4d = False
416
437
if instance_type :
417
438
# looks for either "ml.<family>.<size>" or "ml_<family>"
418
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
439
+ match = re .match (INSTANCE_TYPE_REGEX , instance_type )
419
440
if match :
420
441
family = match [1 ]
421
442
p4d = family == "p4d"
0 commit comments