Skip to content

Commit 762b509

Browse files
author
Edward J Kim
authored
fix: blacklist unknown xgboost image versions (#1519)
1 parent e86da62 commit 762b509

File tree

3 files changed

+101
-24
lines changed

3 files changed

+101
-24
lines changed

src/sagemaker/amazon/amazon_estimator.py

+47-20
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from sagemaker.session import s3_input
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
3030
from sagemaker.xgboost.defaults import (
31+
XGBOOST_1P_VERSIONS,
3132
XGBOOST_LATEST_VERSION,
33+
XGBOOST_NAME,
3234
XGBOOST_SUPPORTED_VERSIONS,
33-
XGBOOST_VERSION_0_90_1,
34-
XGBOOST_VERSION_0_90,
3535
XGBOOST_VERSION_EQUIVALENTS,
3636
)
3737
from sagemaker.xgboost.estimator import get_xgboost_image_uri
@@ -621,41 +621,68 @@ def get_image_uri(region_name, repo_name, repo_version=1):
621621
"in SageMaker Python SDK v2."
622622
)
623623

624-
if repo_name == "xgboost":
625-
if not _is_latest_xgboost_version(repo_version):
626-
logging.warning(
627-
"There is a more up to date SageMaker XGBoost image. "
628-
"To use the newer image, please set 'repo_version'="
629-
"'%s'. For example:\n"
630-
"\tget_image_uri(region, 'xgboost', '%s').",
631-
XGBOOST_LATEST_VERSION,
632-
XGBOOST_LATEST_VERSION,
633-
)
624+
repo_version = str(repo_version)
625+
626+
if repo_name == XGBOOST_NAME:
627+
628+
if repo_version in XGBOOST_1P_VERSIONS:
629+
_warn_newer_xgboost_image()
630+
return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version)
634631

635-
if repo_version in [XGBOOST_VERSION_0_90] + _generate_version_equivalents(
636-
XGBOOST_VERSION_0_90_1
637-
):
638-
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_0_90_1)
632+
if "-" not in repo_version:
633+
xgboost_version_matches = [
634+
version
635+
for version in XGBOOST_SUPPORTED_VERSIONS
636+
if repo_version == version.split("-")[0]
637+
]
638+
if xgboost_version_matches:
639+
# Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest.
640+
# When SageMaker version is not specified, we use the oldest one that matches
641+
# XGBoost version for backward compatibility.
642+
repo_version = xgboost_version_matches[0]
639643

640-
supported_version = [
644+
supported_framework_versions = [
641645
version
642646
for version in XGBOOST_SUPPORTED_VERSIONS
643647
if repo_version in _generate_version_equivalents(version)
644648
]
645-
if supported_version:
646-
return get_xgboost_image_uri(region_name, supported_version[0])
649+
650+
if not supported_framework_versions:
651+
raise ValueError(
652+
"SageMaker XGBoost version {} is not supported. Supported versions: {}".format(
653+
repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS)
654+
)
655+
)
656+
657+
if not _is_latest_xgboost_version(repo_version):
658+
_warn_newer_xgboost_image()
659+
660+
return get_xgboost_image_uri(region_name, supported_framework_versions[-1])
647661

648662
repo = "{}:{}".format(repo_name, repo_version)
649663
return "{}/{}".format(registry(region_name, repo_name), repo)
650664

651665

666+
def _warn_newer_xgboost_image():
667+
"""Print a warning when there is a newer XGBoost image"""
668+
logging.warning(
669+
"There is a more up to date SageMaker XGBoost image. "
670+
"To use the newer image, please set 'repo_version'="
671+
"'%s'. For example:\n"
672+
"\tget_image_uri(region, '%s', '%s').",
673+
XGBOOST_LATEST_VERSION,
674+
XGBOOST_NAME,
675+
XGBOOST_LATEST_VERSION,
676+
)
677+
678+
652679
def _is_latest_xgboost_version(repo_version):
653680
"""Compare xgboost image version with latest version
654681
655682
Args:
656683
repo_version:
657684
"""
658-
if repo_version in (1, "latest"):
685+
if repo_version in XGBOOST_1P_VERSIONS:
659686
return False
660687
return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
661688

src/sagemaker/xgboost/defaults.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
XGBOOST_NAME = "xgboost"
17+
XGBOOST_1P_VERSIONS = ["1", "latest"]
1718
XGBOOST_VERSION_0_90 = "0.90"
1819
XGBOOST_VERSION_0_90_1 = "0.90-1"
1920
XGBOOST_VERSION_0_90_2 = "0.90-2"
2021
XGBOOST_LATEST_VERSION = "1.0-1"
22+
# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
2123
XGBOOST_SUPPORTED_VERSIONS = [
2224
XGBOOST_VERSION_0_90_1,
2325
XGBOOST_VERSION_0_90_2,

tests/unit/test_amazon_estimator.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ def test_file_system_record_set_data_channel():
452452
def test_get_xgboost_image_uri():
453453
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost")
454454
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
455+
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", 1)
456+
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
457+
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", "latest")
458+
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"
455459

456460
updated_xgb_image_uri = get_image_uri(REGION, "xgboost", "0.90-1")
457461
assert (
@@ -465,6 +469,52 @@ def test_get_xgboost_image_uri():
465469
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
466470
)
467471

472+
assert (
473+
get_image_uri(REGION, "xgboost", "0.90-2-cpu-py3")
474+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
475+
)
476+
assert (
477+
get_image_uri(REGION, "xgboost", "0.90")
478+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
479+
)
480+
assert (
481+
get_image_uri(REGION, "xgboost", "1.0-1")
482+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
483+
)
484+
assert (
485+
get_image_uri(REGION, "xgboost", "1.0-1-cpu-py3")
486+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
487+
)
488+
assert (
489+
get_image_uri(REGION, "xgboost", "1.0")
490+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
491+
)
492+
493+
494+
def test_get_xgboost_image_uri_warning_with_legacy(caplog):
495+
get_image_uri(REGION, "xgboost", 1)
496+
assert "There is a more up to date SageMaker XGBoost image." in caplog.text
497+
498+
499+
def test_get_xgboost_image_uri_warning_with_no_sagemaker_version(caplog):
500+
get_image_uri(REGION, "xgboost", "0.90")
501+
assert "There is a more up to date SageMaker XGBoost image." in caplog.text
502+
503+
504+
def test_get_xgboost_image_uri_no_warning_with_latest(caplog):
505+
get_image_uri(REGION, "xgboost", XGBOOST_LATEST_VERSION.split("-")[0])
506+
assert "There is a more up to date SageMaker XGBoost image." not in caplog.text
507+
508+
509+
def test_get_xgboost_image_uri_throws_error_for_unsupported_version():
510+
with pytest.raises(ValueError) as error:
511+
get_image_uri(REGION, "xgboost", "99.99-9")
512+
assert "SageMaker XGBoost version 99.99-9 is not supported" in str(error)
513+
514+
with pytest.raises(ValueError) as error:
515+
get_image_uri(REGION, "xgboost", "0.90-1-gpu-py3")
516+
assert "SageMaker XGBoost version 0.90-1-gpu-py3 is not supported" in str(error)
517+
468518

469519
def test_regitry_throws_error_if_mapping_does_not_exist_for_lda():
470520
with pytest.raises(ValueError) as error:
@@ -482,10 +532,8 @@ def test_is_latest_xgboost_version():
482532
for version in XGBOOST_SUPPORTED_VERSIONS:
483533
if version != XGBOOST_LATEST_VERSION:
484534
assert _is_latest_xgboost_version(version) is False
485-
486-
assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False
487-
488-
assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True
535+
else:
536+
assert _is_latest_xgboost_version(version) is True
489537

490538

491539
def test_get_image_uri_warn(caplog):

0 commit comments

Comments
 (0)