Skip to content

Commit fbebd80

Browse files
fix: fix xgboost image incorrect latest version warning (#1434)
1 parent ddd5800 commit fbebd80

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
2828
from sagemaker.session import s3_input
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
30-
from sagemaker.xgboost.defaults import XGBOOST_VERSION_1, XGBOOST_SUPPORTED_VERSIONS
30+
from sagemaker.xgboost.defaults import (
31+
XGBOOST_LATEST_VERSION,
32+
XGBOOST_SUPPORTED_VERSIONS,
33+
XGBOOST_VERSION_1,
34+
XGBOOST_VERSION_EQUIVALENTS,
35+
)
3136
from sagemaker.xgboost.estimator import get_xgboost_image_uri
3237

3338
logger = logging.getLogger(__name__)
@@ -611,24 +616,46 @@ def get_image_uri(region_name, repo_name, repo_version=1):
611616
repo_version:
612617
"""
613618
if repo_name == "xgboost":
619+
if not _is_latest_xgboost_version(repo_version):
620+
logging.warning(
621+
"There is a more up to date SageMaker XGBoost image. "
622+
"To use the newer image, please set 'repo_version'="
623+
"'%s'. For example:\n"
624+
"\tget_image_uri(region, 'xgboost', '%s').",
625+
XGBOOST_LATEST_VERSION,
626+
XGBOOST_LATEST_VERSION,
627+
)
628+
614629
if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]:
615630
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1)
616631

617632
supported_version = [
618633
version
619634
for version in XGBOOST_SUPPORTED_VERSIONS
620-
if repo_version in (version, version + "-cpu-py3")
635+
if repo_version in _generate_version_equivalents(version)
621636
]
622637
if supported_version:
623638
return get_xgboost_image_uri(region_name, supported_version[0])
624639

625-
logging.warning(
626-
"There is a more up to date SageMaker XGBoost image. "
627-
"To use the newer image, please set 'repo_version'="
628-
"'%s'. For example:\n"
629-
"\tget_image_uri(region, 'xgboost', '%s').",
630-
XGBOOST_VERSION_1,
631-
XGBOOST_VERSION_1,
632-
)
633640
repo = "{}:{}".format(repo_name, repo_version)
634641
return "{}/{}".format(registry(region_name, repo_name), repo)
642+
643+
644+
def _is_latest_xgboost_version(repo_version):
645+
"""Compare xgboost image version with latest version
646+
647+
Args:
648+
repo_version:
649+
"""
650+
if repo_version in (1, "latest"):
651+
return False
652+
return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
653+
654+
655+
def _generate_version_equivalents(version):
656+
"""Returns a list of version equivalents for XGBoost
657+
658+
Args:
659+
version:
660+
"""
661+
return [version + suffix for suffix in XGBOOST_VERSION_EQUIVALENTS] + [version]

src/sagemaker/xgboost/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
XGBOOST_VERSION_1 = "0.90-1"
1818
XGBOOST_LATEST_VERSION = "0.90-2"
1919
XGBOOST_SUPPORTED_VERSIONS = [XGBOOST_VERSION_1, XGBOOST_LATEST_VERSION]
20+
XGBOOST_VERSION_EQUIVALENTS = ["-cpu-py3"]

tests/unit/test_amazon_estimator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
registry,
2525
get_image_uri,
2626
FileSystemRecordSet,
27+
_is_latest_xgboost_version,
2728
)
29+
from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION, XGBOOST_SUPPORTED_VERSIONS
2830

2931
COMMON_ARGS = {"role": "myrole", "train_instance_count": 1, "train_instance_type": "ml.c4.xlarge"}
3032

@@ -474,3 +476,13 @@ def test_regitry_throws_error_if_mapping_does_not_exist_for_default_algorithm():
474476
with pytest.raises(ValueError) as error:
475477
registry("broken_region_name")
476478
assert "Algorithm (None) is unsupported for region (broken_region_name)." in str(error)
479+
480+
481+
def test_is_latest_xgboost_version():
482+
for version in XGBOOST_SUPPORTED_VERSIONS:
483+
if version != XGBOOST_LATEST_VERSION:
484+
assert _is_latest_xgboost_version(version) is False
485+
486+
assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False
487+
488+
assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True

0 commit comments

Comments
 (0)