Skip to content

change: make image_scope optional for some images in image_uris.retrieve() #1723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"scope": ["inference", "training"],
"scope": ["inference"],
"versions": {
"latest": {
"registries": {
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/image_uri_config/xgboost-neo.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"scope": ["inference", "training"],
"scope": ["inference"],
"versions": {
"latest": {
"registries": {
Expand Down
30 changes: 23 additions & 7 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,24 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
)
image_scope = "eia"

_validate_arg("image scope", image_scope, config.get("scope", config.keys()))
available_scopes = config.get("scope", config.keys())
if len(available_scopes) == 1:
if image_scope and image_scope != available_scopes[0]:
logger.warning(
"Defaulting to only supported image scope: %s. Ignoring image scope: %s.",
available_scopes[0],
image_scope,
)
image_scope = available_scopes[0]

if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
logger.info(
"Same images used for training and inference. Defaulting to image scope: %s.",
available_scopes[0],
)
image_scope = available_scopes[0]

_validate_arg(image_scope, available_scopes, "image scope")
return config if "scope" in config else config[image_scope]


Expand All @@ -116,8 +133,7 @@ def _validate_version_and_set_if_needed(version, config, framework):

return available_versions[0]

_validate_arg("{} version".format(framework), version, available_versions + aliased_versions)

_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
return version


Expand All @@ -132,7 +148,7 @@ def _version_for_config(version, config):

def _registry_from_region(region, registry_dict):
"""Returns the ECR registry (AWS account number) for the given region."""
_validate_arg("region", region, registry_dict.keys())
_validate_arg(region, registry_dict.keys(), "region")
return registry_dict[region]


Expand All @@ -159,7 +175,7 @@ def _processor(instance_type, available_processors):
family = instance_type.split(".")[1]
processor = "gpu" if family[0] in ("g", "p") else "cpu"

_validate_arg("processor", processor, available_processors)
_validate_arg(processor, available_processors, "processor")
return processor


Expand All @@ -179,11 +195,11 @@ def _validate_py_version_and_set_if_needed(py_version, version_config):
logger.info("Defaulting to only available Python version: %s", available_versions[0])
return available_versions[0]

_validate_arg("Python version", py_version, available_versions)
_validate_arg(py_version, available_versions, "Python version")
return py_version


def _validate_arg(arg_name, arg, available_options):
def _validate_arg(arg, available_options, arg_name):
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
if arg not in available_options:
raise ValueError(
Expand Down
20 changes: 9 additions & 11 deletions tests/unit/sagemaker/image_uris/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,19 @@ def test_algo_uris(algo):
accounts = _accounts_for_algo(algo)

for region in regions.regions():
for scope in ("training", "inference"):
uri = image_uris.retrieve(algo, region, image_scope=scope)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
uri = image_uris.retrieve(algo, region)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri


def test_lda():
algo = "lda"
accounts = _accounts_for_algo(algo)

for region in regions.regions():
for scope in ("training", "inference"):
if region in accounts:
uri = image_uris.retrieve(algo, region, image_scope=scope)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(algo, region, image_scope=scope)
assert "Unsupported region: {}.".format(region) in str(e.value)
if region in accounts:
uri = image_uris.retrieve(algo, region)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(algo, region)
assert "Unsupported region: {}.".format(region) in str(e.value)
17 changes: 8 additions & 9 deletions tests/unit/sagemaker/image_uris/test_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@
@pytest.mark.parametrize("algo", ALGO_NAMES)
def test_algo_uris(algo):
for region in regions.regions():
for scope in ("training", "inference"):
if region in NEO_REGION_LIST:
uri = image_uris.retrieve(algo, region, image_scope=scope)
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
assert expected == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(algo, region, image_scope=scope)
assert "Unsupported region: {}.".format(region) in str(e.value)
if region in NEO_REGION_LIST:
uri = image_uris.retrieve(algo, region)
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
assert expected == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(algo, region)
assert "Unsupported region: {}.".format(region) in str(e.value)
42 changes: 42 additions & 0 deletions tests/unit/sagemaker/image_uris/test_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,48 @@ def test_retrieve_unsupported_image_scope(config_for_framework):
assert "Unsupported image scope: invalid-image-scope." in str(e.value)
assert "Supported image scope(s): training, inference." in str(e.value)

config = copy.deepcopy(BASE_CONFIG)
config["scope"].append("eia")
config_for_framework.return_value = config

with pytest.raises(ValueError) as e:
image_uris.retrieve(
framework="useless-string",
version="1.0.0",
py_version="py3",
instance_type="ml.c4.xlarge",
region="us-west-2",
)
assert "Unsupported image scope: None." in str(e.value)
assert "Supported image scope(s): training, inference, eia." in str(e.value)


@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
def test_retrieve_default_image_scope(config_for_framework, caplog):
uri = image_uris.retrieve(
framework="useless-string",
version="1.0.0",
py_version="py3",
instance_type="ml.c4.xlarge",
region="us-west-2",
)
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri

config = copy.deepcopy(BASE_CONFIG)
config["scope"] = ["eia"]
config_for_framework.return_value = config

uri = image_uris.retrieve(
framework="useless-string",
version="1.0.0",
py_version="py3",
instance_type="ml.c4.xlarge",
region="us-west-2",
image_scope="ignorable-scope",
)
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
assert "Ignoring image scope: ignorable-scope." in caplog.text


@patch("sagemaker.image_uris.config_for_framework")
def test_retrieve_eia(config_for_framework, caplog):
Expand Down
45 changes: 20 additions & 25 deletions tests/unit/sagemaker/image_uris/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,35 +71,30 @@

def test_xgboost_framework(xgboost_framework_version):
for region in regions.regions():
for scope in ("training", "inference"):
uri = image_uris.retrieve(
framework="xgboost",
region=region,
version=xgboost_framework_version,
py_version="py3",
instance_type="ml.c4.xlarge",
image_scope=scope,
)
uri = image_uris.retrieve(
framework="xgboost",
region=region,
version=xgboost_framework_version,
py_version="py3",
instance_type="ml.c4.xlarge",
)

expected = expected_uris.framework_uri(
"sagemaker-xgboost",
xgboost_framework_version,
FRAMEWORK_REGISTRIES[region],
py_version="py3",
region=region,
)
assert expected == uri
expected = expected_uris.framework_uri(
"sagemaker-xgboost",
xgboost_framework_version,
FRAMEWORK_REGISTRIES[region],
py_version="py3",
region=region,
)
assert expected == uri


@pytest.mark.parametrize("xgboost_algo_version", ("1", "latest"))
def test_xgboost_algo(xgboost_algo_version):
for region in regions.regions():
for scope in ("training", "inference"):
uri = image_uris.retrieve(
framework="xgboost", region=region, version=xgboost_algo_version, image_scope=scope,
)
uri = image_uris.retrieve(framework="xgboost", region=region, version=xgboost_algo_version)

expected = expected_uris.algo_uri(
"xgboost", ALGO_REGISTRIES[region], region, version=xgboost_algo_version
)
assert expected == uri
expected = expected_uris.algo_uri(
"xgboost", ALGO_REGISTRIES[region], region, version=xgboost_algo_version
)
assert expected == uri