Skip to content

fix: blacklist unknown xgboost image versions #1519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from sagemaker.session import s3_input
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
from sagemaker.xgboost.defaults import (
XGBOOST_1P_VERSIONS,
XGBOOST_LATEST_VERSION,
XGBOOST_NAME,
XGBOOST_SUPPORTED_VERSIONS,
XGBOOST_VERSION_0_90_1,
XGBOOST_VERSION_0_90,
XGBOOST_VERSION_EQUIVALENTS,
)
from sagemaker.xgboost.estimator import get_xgboost_image_uri
Expand Down Expand Up @@ -621,41 +621,68 @@ def get_image_uri(region_name, repo_name, repo_version=1):
"in SageMaker Python SDK v2."
)

if repo_name == "xgboost":
if not _is_latest_xgboost_version(repo_version):
logging.warning(
"There is a more up to date SageMaker XGBoost image. "
"To use the newer image, please set 'repo_version'="
"'%s'. For example:\n"
"\tget_image_uri(region, 'xgboost', '%s').",
XGBOOST_LATEST_VERSION,
XGBOOST_LATEST_VERSION,
)
repo_version = str(repo_version)

if repo_name == XGBOOST_NAME:

if repo_version in XGBOOST_1P_VERSIONS:
_warn_newer_xgboost_image()
return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version)

if repo_version in [XGBOOST_VERSION_0_90] + _generate_version_equivalents(
XGBOOST_VERSION_0_90_1
):
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_0_90_1)
if "-" not in repo_version:
xgboost_version_matches = [
version
for version in XGBOOST_SUPPORTED_VERSIONS
if repo_version == version.split("-")[0]
]
if xgboost_version_matches:
# Assumes that XGBOOST_SUPPORTED_VERSIONS is sorted from oldest version to latest.
# When SageMaker version is not specified, we use the oldest one that matches
# XGBoost version for backward compatibility.
repo_version = xgboost_version_matches[0]

supported_version = [
supported_framework_versions = [
version
for version in XGBOOST_SUPPORTED_VERSIONS
if repo_version in _generate_version_equivalents(version)
]
if supported_version:
return get_xgboost_image_uri(region_name, supported_version[0])

if not supported_framework_versions:
raise ValueError(
"SageMaker XGBoost version {} is not supported. Supported versions: {}".format(
repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS)
)
)

if not _is_latest_xgboost_version(repo_version):
_warn_newer_xgboost_image()

return get_xgboost_image_uri(region_name, supported_framework_versions[-1])

repo = "{}:{}".format(repo_name, repo_version)
return "{}/{}".format(registry(region_name, repo_name), repo)


def _warn_newer_xgboost_image():
    """Log a warning that a more recent SageMaker XGBoost image is available.

    The message names ``XGBOOST_LATEST_VERSION`` as the value to pass for
    ``repo_version`` and shows an example ``get_image_uri`` call.
    """
    message_template = (
        "There is a more up to date SageMaker XGBoost image. "
        "To use the newer image, please set 'repo_version'="
        "'%s'. For example:\n"
        "\tget_image_uri(region, '%s', '%s')."
    )
    # Passed as lazy %-style arguments so logging does the interpolation.
    template_args = (XGBOOST_LATEST_VERSION, XGBOOST_NAME, XGBOOST_LATEST_VERSION)
    logging.warning(message_template, *template_args)


def _is_latest_xgboost_version(repo_version):
"""Compare xgboost image version with latest version

Args:
repo_version: The XGBoost image version to compare with the latest one.

Returns:
bool: ``False`` when ``repo_version`` matches the membership check at the
top of the body; otherwise whether ``repo_version`` is among the values
returned by ``_generate_version_equivalents(XGBOOST_LATEST_VERSION)``.
"""
# NOTE(review): the two membership checks below look like the removed and the
# added version of a single diff line -- confirm which one is current.
if repo_version in (1, "latest"):
if repo_version in XGBOOST_1P_VERSIONS:
# Versions matched above are never reported as the latest.
return False
return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)

Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/xgboost/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
from __future__ import absolute_import

XGBOOST_NAME = "xgboost"
XGBOOST_1P_VERSIONS = ["1", "latest"]
XGBOOST_VERSION_0_90 = "0.90"
XGBOOST_VERSION_0_90_1 = "0.90-1"
XGBOOST_VERSION_0_90_2 = "0.90-2"
XGBOOST_LATEST_VERSION = "1.0-1"
# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
XGBOOST_SUPPORTED_VERSIONS = [
XGBOOST_VERSION_0_90_1,
XGBOOST_VERSION_0_90_2,
Expand Down
56 changes: 52 additions & 4 deletions tests/unit/test_amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ def test_file_system_record_set_data_channel():
def test_get_xgboost_image_uri():
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost")
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", 1)
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", "latest")
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"

updated_xgb_image_uri = get_image_uri(REGION, "xgboost", "0.90-1")
assert (
Expand All @@ -465,6 +469,52 @@ def test_get_xgboost_image_uri():
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
)

assert (
get_image_uri(REGION, "xgboost", "0.90-2-cpu-py3")
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
)
assert (
get_image_uri(REGION, "xgboost", "0.90")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once again, we were previously returning the 0.90-1 image for this. Let's check with Rahul on whether or not we want this new behavior.

== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
)
assert (
get_image_uri(REGION, "xgboost", "1.0-1")
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
)
assert (
get_image_uri(REGION, "xgboost", "1.0-1-cpu-py3")
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
)
assert (
get_image_uri(REGION, "xgboost", "1.0")
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
)


def test_get_xgboost_image_uri_warning_with_legacy(caplog):
    """Requesting image version ``1`` must log the newer-image warning."""
    expected_warning = "There is a more up to date SageMaker XGBoost image."

    get_image_uri(REGION, "xgboost", 1)

    assert expected_warning in caplog.text


def test_get_xgboost_image_uri_warning_with_no_sagemaker_version(caplog):
    """A bare XGBoost version (``"0.90"``, no suffix) must log the newer-image warning."""
    expected_warning = "There is a more up to date SageMaker XGBoost image."

    get_image_uri(REGION, "xgboost", "0.90")

    assert expected_warning in caplog.text


def test_get_xgboost_image_uri_no_warning_with_latest(caplog):
    """The XGBoost part of the latest version must not log the newer-image warning."""
    unexpected_warning = "There is a more up to date SageMaker XGBoost image."
    # Keep only the text before the first "-" of XGBOOST_LATEST_VERSION.
    latest_xgboost_version, _, _ = XGBOOST_LATEST_VERSION.partition("-")

    get_image_uri(REGION, "xgboost", latest_xgboost_version)

    assert unexpected_warning not in caplog.text


def test_get_xgboost_image_uri_throws_error_for_unsupported_version():
    """Unsupported XGBoost image versions must make ``get_image_uri`` raise ``ValueError``.

    Covers an unknown version ("99.99-9") and a fully qualified tag that is
    not accepted ("0.90-1-gpu-py3").
    """
    with pytest.raises(ValueError) as error:
        get_image_uri(REGION, "xgboost", "99.99-9")
    # Check the raised exception itself (``error.value``); ``str()`` of the
    # ExceptionInfo wrapper is not the exception message in recent pytest versions.
    assert "SageMaker XGBoost version 99.99-9 is not supported" in str(error.value)

    with pytest.raises(ValueError) as error:
        get_image_uri(REGION, "xgboost", "0.90-1-gpu-py3")
    assert "SageMaker XGBoost version 0.90-1-gpu-py3 is not supported" in str(error.value)


def test_regitry_throws_error_if_mapping_does_not_exist_for_lda():
with pytest.raises(ValueError) as error:
Expand All @@ -482,10 +532,8 @@ def test_is_latest_xgboost_version():
for version in XGBOOST_SUPPORTED_VERSIONS:
if version != XGBOOST_LATEST_VERSION:
assert _is_latest_xgboost_version(version) is False

assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False

assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True
else:
assert _is_latest_xgboost_version(version) is True


def test_get_image_uri_warn(caplog):
Expand Down