Skip to content

Commit 342f218

Browse files
mabundayMark Bunday
and
Mark Bunday
authored
fix: Add regex for short-form sagemaker-xgboost tags (#3358)
Co-authored-by: Mark Bunday <[email protected]>
1 parent e8ee340 commit 342f218

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/sagemaker/fw_utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ def framework_name_from_image(image_uri):
410410
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
411411
current:
412412
'<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
413+
current:
414+
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-xgboost:<fw_version>-<container_version>'
413415
414416
Returns:
415417
tuple: A tuple containing:
@@ -450,6 +452,16 @@ def framework_name_from_image(image_uri):
450452
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
451453
if legacy_match is not None:
452454
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
455+
456+
# sagemaker-xgboost images are tagged with two aliases, e.g.:
457+
# 1. Long tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1-cpu-py3"
458+
# 2. Short tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1"
459+
# Note 1: Both tags point to the same image
460+
# Note 2: Both tags have full GPU capabilities, despite "cpu" delineation in the long tag
461+
short_xgboost_tag_pattern = re.compile(r"^sagemaker-(xgboost):(.*)$")
462+
short_xgboost_tag_match = short_xgboost_tag_pattern.match(sagemaker_match.group(9))
463+
if short_xgboost_tag_match is not None:
464+
return (short_xgboost_tag_match.group(1), "py3", short_xgboost_tag_match.group(2), None)
453465
return None, None, None, None
454466

455467

@@ -459,12 +471,16 @@ def framework_version_from_tag(image_tag):
459471
Args:
460472
image_tag (str): Image tag, which should take the form
461473
'<framework_version>-<device>-<py_version>'
474+
'<xgboost_version>-<container_version>'
462475
463476
Returns:
464477
str: The framework version.
465478
"""
466479
tag_pattern = re.compile(r"^(.*)-(cpu|gpu)-(py2|py3\d*)$")
467480
tag_match = tag_pattern.match(image_tag)
481+
if tag_match is None:
482+
short_xgboost_tag_pattern = re.compile(r"^(\d\.\d+\-\d)$")
483+
tag_match = short_xgboost_tag_pattern.match(image_tag)
468484
return None if tag_match is None else tag_match.group(1)
469485

470486

tests/unit/test_fw_utils.py

+27
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,33 @@ def test_framework_version_from_tag_other():
511511
assert version is None
512512

513513

514+
def test_xgboost_version_from_tag():
515+
tags = (
516+
"1.5-1-cpu-py3",
517+
"1.5-1",
518+
)
519+
520+
for tag in tags:
521+
version = fw_utils.framework_version_from_tag(tag)
522+
assert "1.5-1" == version
523+
524+
525+
def test_framework_name_from_xgboost_image_short_tag():
526+
ecr_uri = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost"
527+
image_tag = "1.5-1"
528+
image_uri = f"{ecr_uri}:{image_tag}"
529+
expected_result = ("xgboost", "py3", "1.5-1", None)
530+
assert expected_result == fw_utils.framework_name_from_image(image_uri)
531+
532+
533+
def test_framework_name_from_xgboost_image_long_tag():
534+
ecr_uri = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost"
535+
image_tag = "1.5-1-cpu-py3"
536+
image_uri = f"{ecr_uri}:{image_tag}"
537+
expected_result = ("xgboost", "py3", "1.5-1-cpu-py3", None)
538+
assert expected_result == fw_utils.framework_name_from_image(image_uri)
539+
540+
514541
def test_model_code_key_prefix_with_all_values_present():
515542
key_prefix = fw_utils.model_code_key_prefix("prefix", "model_name", "image_uri")
516543
assert key_prefix == "prefix/model_name"

0 commit comments

Comments
 (0)