Skip to content

breaking: use image_uris.retrieve() for XGBoost URIs #1719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.model import NEO_IMAGE_ACCOUNT
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
from sagemaker.xgboost.defaults import (
XGBOOST_1P_VERSIONS,
XGBOOST_LATEST_VERSION,
XGBOOST_NAME,
XGBOOST_SUPPORTED_VERSIONS,
XGBOOST_VERSION_EQUIVALENTS,
)
from sagemaker.xgboost.estimator import get_xgboost_image_uri

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -622,76 +614,5 @@ def get_image_uri(region_name, repo_name, repo_version=1):
"in SageMaker Python SDK v2."
)

repo_version = str(repo_version)

if repo_name == XGBOOST_NAME:

if repo_version in XGBOOST_1P_VERSIONS:
_warn_newer_xgboost_image()
return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version)

if "-" not in repo_version:
xgboost_version_matches = [
version
for version in XGBOOST_SUPPORTED_VERSIONS
if repo_version == version.split("-")[0]
]
if xgboost_version_matches:
# Assumes that XGBOOST_SUPPORTED_VERSIONS is sorted from oldest version to latest.
# When SageMaker version is not specified, we use the oldest one that matches
# XGBoost version for backward compatibility.
repo_version = xgboost_version_matches[0]

supported_framework_versions = [
version
for version in XGBOOST_SUPPORTED_VERSIONS
if repo_version in _generate_version_equivalents(version)
]

if not supported_framework_versions:
raise ValueError(
"SageMaker XGBoost version {} is not supported. Supported versions: {}".format(
repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS)
)
)

if not _is_latest_xgboost_version(repo_version):
_warn_newer_xgboost_image()

return get_xgboost_image_uri(region_name, supported_framework_versions[-1])

repo = "{}:{}".format(repo_name, repo_version)
return "{}/{}".format(registry(region_name, repo_name), repo)


def _warn_newer_xgboost_image():
    """Log a warning that a newer SageMaker XGBoost image is available.

    The message names ``XGBOOST_LATEST_VERSION`` and shows an example
    ``get_image_uri`` call that selects it.
    """
    # Use the module-level ``logger`` rather than the root ``logging`` functions so the
    # record carries this module's name and the call never triggers an implicit
    # ``logging.basicConfig()`` on the root logger.
    logger.warning(
        "There is a more up to date SageMaker XGBoost image. "
        "To use the newer image, please set 'repo_version'="
        "'%s'. For example:\n"
        "\tget_image_uri(region, '%s', '%s').",
        XGBOOST_LATEST_VERSION,
        XGBOOST_NAME,
        XGBOOST_LATEST_VERSION,
    )


def _is_latest_xgboost_version(repo_version):
    """Tell whether ``repo_version`` refers to the latest XGBoost image.

    Args:
        repo_version: Version string to check.

    Returns:
        bool: ``False`` when ``repo_version`` is listed in ``XGBOOST_1P_VERSIONS``;
        otherwise whether it is one of the equivalents generated for
        ``XGBOOST_LATEST_VERSION``.
    """
    # Short-circuits exactly like the guard clause it replaces: the equivalents
    # are only generated when the version is not in XGBOOST_1P_VERSIONS.
    return repo_version not in XGBOOST_1P_VERSIONS and (
        repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
    )


def _generate_version_equivalents(version):
    """Build the list of strings treated as equivalent to an XGBoost version.

    Args:
        version: Base version string.

    Returns:
        list: ``version`` combined with each suffix in
        ``XGBOOST_VERSION_EQUIVALENTS``, followed by ``version`` itself.
    """
    equivalents = [version + suffix for suffix in XGBOOST_VERSION_EQUIVALENTS]
    # The bare version always comes last, after every suffixed form.
    equivalents.append(version)
    return equivalents
22 changes: 0 additions & 22 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@
"framework_version is required for script mode estimator. "
"Please add framework_version={} to your constructor to avoid this error."
)
UNSUPPORTED_FRAMEWORK_VERSION_ERROR = (
"{} framework does not support version {}. Please use one of the following: {}."
)

VALID_PY_VERSIONS = ["py2", "py3", "py37"]
VALID_EIA_FRAMEWORKS = [
Expand Down Expand Up @@ -637,25 +634,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)


def get_unsupported_framework_version_error(
    framework_name, unsupported_version, supported_versions
):
    """Build the error message for an unsupported framework version.

    The message also lists the supported versions for customers.

    :param framework_name: name of the framework, inserted into the message.
    :param unsupported_version: the version that was rejected.
    :param supported_versions: iterable of supported versions; each one is
        wrapped in double quotes and the results are joined with ", ".
    :return: ``UNSUPPORTED_FRAMEWORK_VERSION_ERROR`` formatted with the three values.
    """
    quoted_versions = ['"{}"'.format(version) for version in supported_versions]
    supported_listing = ", ".join(quoted_versions)
    return UNSUPPORTED_FRAMEWORK_VERSION_ERROR.format(
        framework_name, unsupported_version, supported_listing
    )


def python_deprecation_warning(framework, latest_supported_version):
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/xgboost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from sagemaker.xgboost.defaults import XGBOOST_NAME, XGBOOST_LATEST_VERSION # noqa: F401
from sagemaker.xgboost.defaults import XGBOOST_NAME # noqa: F401
from sagemaker.xgboost.estimator import XGBoost # noqa: F401
from sagemaker.xgboost.model import XGBoostModel, XGBoostPredictor # noqa: F401
17 changes: 0 additions & 17 deletions src/sagemaker/xgboost/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,3 @@
from __future__ import absolute_import

XGBOOST_NAME = "xgboost"
XGBOOST_1P_VERSIONS = ["1", "latest"]

XGBOOST_VERSION_0_90_1 = "0.90-1"
XGBOOST_VERSION_0_90_2 = "0.90-2"

XGBOOST_LATEST_VERSION = "1.0-1"

# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
XGBOOST_SUPPORTED_VERSIONS = [
XGBOOST_VERSION_0_90_1,
XGBOOST_VERSION_0_90_2,
XGBOOST_LATEST_VERSION,
]

# TODO: evaluate use of this constant. It is used in precisely one place, in a different module,
# and may be unnecessary indirection.
XGBOOST_VERSION_EQUIVALENTS = ["-cpu-py3"]
30 changes: 9 additions & 21 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@

import logging

from sagemaker import image_uris
from sagemaker.estimator import Framework, _TrainingJob
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
get_unsupported_framework_version_error,
UploadedCode,
)
from sagemaker.session import Session
Expand All @@ -31,12 +30,6 @@
logger = logging.getLogger("sagemaker")


def get_xgboost_image_uri(region, framework_version, py_version="py3"):
    """Return the XGBoost framework image URI for a region and version.

    The image tag is ``<framework_version>-cpu-<py_version>``; the URI itself
    is produced by ``default_framework_uri``.
    """
    # The processor part of the tag is always "cpu".
    tag = "{}-cpu-{}".format(framework_version, py_version)
    return default_framework_uri(XGBoost.__framework_name__, region, tag)


class XGBoost(Framework):
"""Handle end-to-end training and deployment of XGBoost booster training or training using
customer provided XGBoost entry point script."""
Expand Down Expand Up @@ -105,22 +98,17 @@ def __init__(
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)

if py_version == "py2":
raise AttributeError("XGBoost container does not support Python 2, please use Python 3")
self.py_version = py_version

if framework_version in defaults.XGBOOST_SUPPORTED_VERSIONS:
self.framework_version = framework_version
else:
raise ValueError(
get_unsupported_framework_version_error(
self.__framework_name__, framework_version, defaults.XGBOOST_SUPPORTED_VERSIONS
)
)
self.framework_version = framework_version

if image_uri is None:
self.image_uri = get_xgboost_image_uri(
self.sagemaker_session.boto_region_name, framework_version
self.image_uri = image_uris.retrieve(
self.__framework_name__,
self.sagemaker_session.boto_region_name,
version=framework_version,
py_version=self.py_version,
instance_type=kwargs.get("instance_type"),
image_scope="training",
)

def create_model(
Expand Down
21 changes: 11 additions & 10 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import logging

import sagemaker
from sagemaker import image_uris
from sagemaker.deserializers import CSVDeserializer
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
Expand Down Expand Up @@ -100,9 +100,6 @@ def __init__(
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

if py_version == "py2":
raise AttributeError("XGBoost container does not support Python 2, please use Python 3")

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down Expand Up @@ -136,17 +133,21 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
def serving_image_uri(self, region_name, instance_type):
"""Create a URI for the serving image.

Args:
region_name (str): AWS region where the image is uploaded.
instance_type (str): SageMaker instance type. This parameter is unused because
XGBoost supports only CPU.
instance_type (str): SageMaker instance type. Must be a CPU instance type.

Returns:
str: The appropriate image URI based on the given parameters.

"""
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
return default_framework_uri(self.__framework_name__, region_name, image_tag)
return image_uris.retrieve(
self.__framework_name__,
region_name,
version=self.framework_version,
py_version=self.py_version,
instance_type=instance_type,
image_scope="inference",
)
10 changes: 0 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,6 @@ def tf_full_py_version(tf_full_version):
return "py37"


@pytest.fixture(scope="module")
def xgboost_full_version():
return "1.0-1"


@pytest.fixture(scope="module")
def xgboost_full_py_version():
return "py3"


@pytest.fixture(scope="session")
def cpu_instance_type(sagemaker_session, request):
region = sagemaker_session.boto_session.region_name
Expand Down
6 changes: 3 additions & 3 deletions tests/integ/test_airflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,13 +552,13 @@ def test_tf_airflow_config_uploads_data_source_to_s3(

@pytest.mark.canary_quick
def test_xgboost_airflow_config_uploads_data_source_to_s3(
sagemaker_session, cpu_instance_type, xgboost_full_version, xgboost_full_py_version
sagemaker_session, cpu_instance_type, xgboost_latest_version, xgboost_latest_py_version
):
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
xgboost = XGBoost(
entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
framework_version=xgboost_full_version,
py_version=xgboost_full_py_version,
framework_version=xgboost_latest_version,
py_version=xgboost_latest_py_version,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_type=cpu_instance_type,
Expand Down
14 changes: 10 additions & 4 deletions tests/integ/test_inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
timeout_and_delete_model_with_transformer,
)

from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker import image_uris
from sagemaker.content_types import CONTENT_TYPE_CSV
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
Expand Down Expand Up @@ -66,7 +66,9 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
sagemaker_session=sagemaker_session,
)
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
xgb_image = image_uris.retrieve(
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
)
xgb_model = Model(
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
)
Expand Down Expand Up @@ -115,7 +117,9 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
sagemaker_session=sagemaker_session,
)
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
xgb_image = image_uris.retrieve(
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
)
xgb_model = Model(
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
)
Expand Down Expand Up @@ -169,7 +173,9 @@ def test_inference_pipeline_model_deploy_and_update_endpoint(
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
sagemaker_session=sagemaker_session,
)
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
xgb_image = image_uris.retrieve(
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
)
xgb_model = Model(
model_data=xgb_model_data, image_uri=xgb_image, sagemaker_session=sagemaker_session
)
Expand Down
19 changes: 11 additions & 8 deletions tests/integ/test_multi_variant_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,25 @@
from __future__ import absolute_import

import json
import os
import math
import os

import pytest
import scipy.stats as st

from sagemaker import image_uris
from sagemaker.s3 import S3Uploader
from sagemaker.session import production_variant
from sagemaker.sparkml import SparkMLModel
from sagemaker.utils import sagemaker_timestamp
from sagemaker.content_types import CONTENT_TYPE_CSV
from sagemaker.utils import unique_name_from_base
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer


import tests.integ


ROLE = "SageMakerRole"
MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp())
MODEL_NAME = unique_name_from_base("test-xgboost-model")
DEFAULT_REGION = "us-west-2"
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_INSTANCE_COUNT = 1
Expand Down Expand Up @@ -97,8 +95,13 @@ def multi_variant_endpoint(sagemaker_session):
sagemaker_session=sagemaker_session,
)

image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")

image_uri = image_uris.retrieve(
"xgboost",
sagemaker_session.boto_region_name,
version="0.90-1",
instance_type=DEFAULT_INSTANCE_TYPE,
image_scope="inference",
)
multi_variant_endpoint_model = sagemaker_session.create_model(
name=MODEL_NAME,
role=ROLE,
Expand Down
Loading