Skip to content

Commit e4485b7

Browse files
authored
breaking: use image_uris.retrieve() for XGBoost URIs (#1719)
This also deprecates sagemaker.fw_utils.get_unsupported_framework_version_error(), as well as many of the constants defined in sagemaker.xgboost.defaults.
1 parent 81ab6e2 commit e4485b7

12 files changed

+52
-300
lines changed

src/sagemaker/amazon/amazon_estimator.py

-79
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,6 @@
2626
from sagemaker.inputs import FileSystemInput, TrainingInput
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
2828
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
29-
from sagemaker.xgboost.defaults import (
30-
XGBOOST_1P_VERSIONS,
31-
XGBOOST_LATEST_VERSION,
32-
XGBOOST_NAME,
33-
XGBOOST_SUPPORTED_VERSIONS,
34-
XGBOOST_VERSION_EQUIVALENTS,
35-
)
36-
from sagemaker.xgboost.estimator import get_xgboost_image_uri
3729

3830
logger = logging.getLogger(__name__)
3931

@@ -622,76 +614,5 @@ def get_image_uri(region_name, repo_name, repo_version=1):
622614
"in SageMaker Python SDK v2."
623615
)
624616

625-
repo_version = str(repo_version)
626-
627-
if repo_name == XGBOOST_NAME:
628-
629-
if repo_version in XGBOOST_1P_VERSIONS:
630-
_warn_newer_xgboost_image()
631-
return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version)
632-
633-
if "-" not in repo_version:
634-
xgboost_version_matches = [
635-
version
636-
for version in XGBOOST_SUPPORTED_VERSIONS
637-
if repo_version == version.split("-")[0]
638-
]
639-
if xgboost_version_matches:
640-
# Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest.
641-
# When SageMaker version is not specified, we use the oldest one that matches
642-
# XGBoost version for backward compatibility.
643-
repo_version = xgboost_version_matches[0]
644-
645-
supported_framework_versions = [
646-
version
647-
for version in XGBOOST_SUPPORTED_VERSIONS
648-
if repo_version in _generate_version_equivalents(version)
649-
]
650-
651-
if not supported_framework_versions:
652-
raise ValueError(
653-
"SageMaker XGBoost version {} is not supported. Supported versions: {}".format(
654-
repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS)
655-
)
656-
)
657-
658-
if not _is_latest_xgboost_version(repo_version):
659-
_warn_newer_xgboost_image()
660-
661-
return get_xgboost_image_uri(region_name, supported_framework_versions[-1])
662-
663617
repo = "{}:{}".format(repo_name, repo_version)
664618
return "{}/{}".format(registry(region_name, repo_name), repo)
665-
666-
667-
def _warn_newer_xgboost_image():
668-
"""Print a warning when there is a newer XGBoost image"""
669-
logging.warning(
670-
"There is a more up to date SageMaker XGBoost image. "
671-
"To use the newer image, please set 'repo_version'="
672-
"'%s'. For example:\n"
673-
"\tget_image_uri(region, '%s', '%s').",
674-
XGBOOST_LATEST_VERSION,
675-
XGBOOST_NAME,
676-
XGBOOST_LATEST_VERSION,
677-
)
678-
679-
680-
def _is_latest_xgboost_version(repo_version):
681-
"""Compare xgboost image version with latest version
682-
683-
Args:
684-
repo_version:
685-
"""
686-
if repo_version in XGBOOST_1P_VERSIONS:
687-
return False
688-
return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
689-
690-
691-
def _generate_version_equivalents(version):
692-
"""Returns a list of version equivalents for XGBoost
693-
694-
Args:
695-
version:
696-
"""
697-
return [version + suffix for suffix in XGBOOST_VERSION_EQUIVALENTS] + [version]

src/sagemaker/fw_utils.py

-22
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@
6565
"framework_version is required for script mode estimator. "
6666
"Please add framework_version={} to your constructor to avoid this error."
6767
)
68-
UNSUPPORTED_FRAMEWORK_VERSION_ERROR = (
69-
"{} framework does not support version {}. Please use one of the following: {}."
70-
)
7168

7269
VALID_PY_VERSIONS = ["py2", "py3", "py37"]
7370
VALID_EIA_FRAMEWORKS = [
@@ -637,25 +634,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
637634
logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
638635

639636

640-
def get_unsupported_framework_version_error(
641-
framework_name, unsupported_version, supported_versions
642-
):
643-
"""Return error message for unsupported framework version.
644-
645-
This should also return the supported versions for customers.
646-
647-
:param framework_name:
648-
:param unsupported_version:
649-
:param supported_versions:
650-
:return:
651-
"""
652-
return UNSUPPORTED_FRAMEWORK_VERSION_ERROR.format(
653-
framework_name,
654-
unsupported_version,
655-
", ".join('"{}"'.format(version) for version in supported_versions),
656-
)
657-
658-
659637
def python_deprecation_warning(framework, latest_supported_version):
660638
"""
661639
Args:

src/sagemaker/xgboost/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from sagemaker.xgboost.defaults import XGBOOST_NAME, XGBOOST_LATEST_VERSION # noqa: F401
14+
from sagemaker.xgboost.defaults import XGBOOST_NAME # noqa: F401
1515
from sagemaker.xgboost.estimator import XGBoost # noqa: F401
1616
from sagemaker.xgboost.model import XGBoostModel, XGBoostPredictor # noqa: F401

src/sagemaker/xgboost/defaults.py

-17
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,3 @@
1414
from __future__ import absolute_import
1515

1616
XGBOOST_NAME = "xgboost"
17-
XGBOOST_1P_VERSIONS = ["1", "latest"]
18-
19-
XGBOOST_VERSION_0_90_1 = "0.90-1"
20-
XGBOOST_VERSION_0_90_2 = "0.90-2"
21-
22-
XGBOOST_LATEST_VERSION = "1.0-1"
23-
24-
# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
25-
XGBOOST_SUPPORTED_VERSIONS = [
26-
XGBOOST_VERSION_0_90_1,
27-
XGBOOST_VERSION_0_90_2,
28-
XGBOOST_LATEST_VERSION,
29-
]
30-
31-
# TODO: evaluate use of this constant. it's used in precisely one place in different a module
32-
# may possibly be unnecessary indirection
33-
XGBOOST_VERSION_EQUIVALENTS = ["-cpu-py3"]

src/sagemaker/xgboost/estimator.py

+9-21
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515

1616
import logging
1717

18+
from sagemaker import image_uris
1819
from sagemaker.estimator import Framework, _TrainingJob
19-
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import (
2121
framework_name_from_image,
2222
framework_version_from_tag,
23-
get_unsupported_framework_version_error,
2423
UploadedCode,
2524
)
2625
from sagemaker.session import Session
@@ -31,12 +30,6 @@
3130
logger = logging.getLogger("sagemaker")
3231

3332

34-
def get_xgboost_image_uri(region, framework_version, py_version="py3"):
35-
"""Get XGBoost framework image URI"""
36-
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
37-
return default_framework_uri(XGBoost.__framework_name__, region, image_tag)
38-
39-
4033
class XGBoost(Framework):
4134
"""Handle end-to-end training and deployment of XGBoost booster training or training using
4235
customer provided XGBoost entry point script."""
@@ -105,22 +98,17 @@ def __init__(
10598
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
10699
)
107100

108-
if py_version == "py2":
109-
raise AttributeError("XGBoost container does not support Python 2, please use Python 3")
110101
self.py_version = py_version
111-
112-
if framework_version in defaults.XGBOOST_SUPPORTED_VERSIONS:
113-
self.framework_version = framework_version
114-
else:
115-
raise ValueError(
116-
get_unsupported_framework_version_error(
117-
self.__framework_name__, framework_version, defaults.XGBOOST_SUPPORTED_VERSIONS
118-
)
119-
)
102+
self.framework_version = framework_version
120103

121104
if image_uri is None:
122-
self.image_uri = get_xgboost_image_uri(
123-
self.sagemaker_session.boto_region_name, framework_version
105+
self.image_uri = image_uris.retrieve(
106+
self.__framework_name__,
107+
self.sagemaker_session.boto_region_name,
108+
version=framework_version,
109+
py_version=self.py_version,
110+
instance_type=kwargs.get("instance_type"),
111+
image_scope="training",
124112
)
125113

126114
def create_model(

src/sagemaker/xgboost/model.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker import image_uris
1920
from sagemaker.deserializers import CSVDeserializer
2021
from sagemaker.fw_utils import model_code_key_prefix
21-
from sagemaker.fw_registry import default_framework_uri
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.serializers import NumpySerializer
@@ -100,9 +100,6 @@ def __init__(
100100
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
101101
)
102102

103-
if py_version == "py2":
104-
raise AttributeError("XGBoost container does not support Python 2, please use Python 3")
105-
106103
self.py_version = py_version
107104
self.framework_version = framework_version
108105
self.model_server_workers = model_server_workers
@@ -136,17 +133,21 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
136133
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
137134
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
138135

139-
def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
136+
def serving_image_uri(self, region_name, instance_type):
140137
"""Create a URI for the serving image.
141138
142139
Args:
143140
region_name (str): AWS region where the image is uploaded.
144-
instance_type (str): SageMaker instance type. This parameter is unused because
145-
XGBoost supports only CPU.
141+
instance_type (str): SageMaker instance type. Must be a CPU instance type.
146142
147143
Returns:
148144
str: The appropriate image URI based on the given parameters.
149-
150145
"""
151-
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
152-
return default_framework_uri(self.__framework_name__, region_name, image_tag)
146+
return image_uris.retrieve(
147+
self.__framework_name__,
148+
region_name,
149+
version=self.framework_version,
150+
py_version=self.py_version,
151+
instance_type=instance_type,
152+
image_scope="inference",
153+
)

tests/conftest.py

-10
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,6 @@ def tf_full_py_version(tf_full_version):
242242
return "py37"
243243

244244

245-
@pytest.fixture(scope="module")
246-
def xgboost_full_version():
247-
return "1.0-1"
248-
249-
250-
@pytest.fixture(scope="module")
251-
def xgboost_full_py_version():
252-
return "py3"
253-
254-
255245
@pytest.fixture(scope="session")
256246
def cpu_instance_type(sagemaker_session, request):
257247
region = sagemaker_session.boto_session.region_name

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -552,13 +552,13 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
552552

553553
@pytest.mark.canary_quick
554554
def test_xgboost_airflow_config_uploads_data_source_to_s3(
555-
sagemaker_session, cpu_instance_type, xgboost_full_version, xgboost_full_py_version
555+
sagemaker_session, cpu_instance_type, xgboost_latest_version, xgboost_latest_py_version
556556
):
557557
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
558558
xgboost = XGBoost(
559559
entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
560-
framework_version=xgboost_full_version,
561-
py_version=xgboost_full_py_version,
560+
framework_version=xgboost_latest_version,
561+
py_version=xgboost_latest_py_version,
562562
role=ROLE,
563563
sagemaker_session=sagemaker_session,
564564
instance_type=cpu_instance_type,

tests/integ/test_inference_pipeline.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
timeout_and_delete_model_with_transformer,
2323
)
2424

25-
from sagemaker.amazon.amazon_estimator import get_image_uri
25+
from sagemaker import image_uris
2626
from sagemaker.content_types import CONTENT_TYPE_CSV
2727
from sagemaker.model import Model
2828
from sagemaker.pipeline import PipelineModel
@@ -66,7 +66,9 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type
6666
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
6767
sagemaker_session=sagemaker_session,
6868
)
69-
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
69+
xgb_image = image_uris.retrieve(
70+
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
71+
)
7072
xgb_model = Model(
7173
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
7274
)
@@ -115,7 +117,9 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
115117
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
116118
sagemaker_session=sagemaker_session,
117119
)
118-
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
120+
xgb_image = image_uris.retrieve(
121+
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
122+
)
119123
xgb_model = Model(
120124
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
121125
)
@@ -169,7 +173,9 @@ def test_inference_pipeline_model_deploy_and_update_endpoint(
169173
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
170174
sagemaker_session=sagemaker_session,
171175
)
172-
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
176+
xgb_image = image_uris.retrieve(
177+
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
178+
)
173179
xgb_model = Model(
174180
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
175181
)

tests/integ/test_multi_variant_endpoint.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,25 @@
1313
from __future__ import absolute_import
1414

1515
import json
16-
import os
1716
import math
17+
import os
18+
1819
import pytest
1920
import scipy.stats as st
2021

22+
from sagemaker import image_uris
2123
from sagemaker.s3 import S3Uploader
2224
from sagemaker.session import production_variant
2325
from sagemaker.sparkml import SparkMLModel
24-
from sagemaker.utils import sagemaker_timestamp
2526
from sagemaker.content_types import CONTENT_TYPE_CSV
2627
from sagemaker.utils import unique_name_from_base
27-
from sagemaker.amazon.amazon_estimator import get_image_uri
2828
from sagemaker.predictor import Predictor
2929
from sagemaker.serializers import CSVSerializer
30-
31-
3230
import tests.integ
3331

3432

3533
ROLE = "SageMakerRole"
36-
MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp())
34+
MODEL_NAME = unique_name_from_base("test-xgboost-model")
3735
DEFAULT_REGION = "us-west-2"
3836
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
3937
DEFAULT_INSTANCE_COUNT = 1
@@ -97,8 +95,13 @@ def multi_variant_endpoint(sagemaker_session):
9795
sagemaker_session=sagemaker_session,
9896
)
9997

100-
image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")
101-
98+
image_uri = image_uris.retrieve(
99+
"xgboost",
100+
sagemaker_session.boto_region_name,
101+
version="0.90-1",
102+
instance_type=DEFAULT_INSTANCE_TYPE,
103+
image_scope="inference",
104+
)
102105
multi_variant_endpoint_model = sagemaker_session.create_model(
103106
name=MODEL_NAME,
104107
role=ROLE,

0 commit comments

Comments
 (0)