@@ -410,6 +410,8 @@ def framework_name_from_image(image_uri):
410
410
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
411
411
current:
412
412
'<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
413
+ current:
414
+ '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-xgboost:<fw_version>-<container_version>'
413
415
414
416
Returns:
415
417
tuple: A tuple containing:
@@ -450,6 +452,15 @@ def framework_name_from_image(image_uri):
450
452
legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
451
453
if legacy_match is not None :
452
454
return (legacy_match .group (1 ), legacy_match .group (2 ), legacy_match .group (4 ), None )
455
+
456
+ # sagemaker-xgboost images are tagged with two aliases, e.g.:
457
+ # 1. Long-form: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1-cpu-py3"
458
+ # 2. Short-form: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1"
459
+ # Both tags point to the same image and the image postfixed with cpu can also supports gpu
460
+ short_xgboost_name_pattern = re .compile (r"^sagemaker-(xgboost):(.*)$" )
461
+ short_xgboost_match = short_xgboost_name_pattern .match (sagemaker_match .group (9 ))
462
+ if short_xgboost_match is not None :
463
+ return (short_xgboost_match .group (1 ), "py3" , short_xgboost_match .group (2 ), None )
453
464
return None , None , None , None
454
465
455
466
0 commit comments