Skip to content

Commit fdd3fc8

Browse files
committed
change: make image_scope optional for some images in image_uris.retrieve()
1 parent e4485b7 commit fdd3fc8

File tree

7 files changed

+104
-54
lines changed

7 files changed

+104
-54
lines changed

src/sagemaker/image_uri_config/image-classification-neo.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"scope": ["inference", "training"],
2+
"scope": ["inference"],
33
"versions": {
44
"latest": {
55
"registries": {

src/sagemaker/image_uri_config/xgboost-neo.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"scope": ["inference", "training"],
2+
"scope": ["inference"],
33
"versions": {
44
"latest": {
55
"registries": {

src/sagemaker/image_uris.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,24 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
8989
)
9090
image_scope = "eia"
9191

92-
_validate_arg("image scope", image_scope, config.get("scope", config.keys()))
92+
available_scopes = config.get("scope", config.keys())
93+
if len(available_scopes) == 1:
94+
if image_scope and image_scope != available_scopes[0]:
95+
logger.warning(
96+
"Defaulting to only supported image scope: %s. Ignoring image scope: %s.",
97+
available_scopes[0],
98+
image_scope,
99+
)
100+
image_scope = available_scopes[0]
101+
102+
if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
103+
logger.info(
104+
"Same images used for training and inference. Defaulting to image scope: %s.",
105+
available_scopes[0],
106+
)
107+
image_scope = available_scopes[0]
108+
109+
_validate_arg(image_scope, available_scopes, "image scope")
93110
return config if "scope" in config else config[image_scope]
94111

95112

@@ -116,8 +133,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
116133

117134
return available_versions[0]
118135

119-
_validate_arg("{} version".format(framework), version, available_versions + aliased_versions)
120-
136+
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
121137
return version
122138

123139

@@ -132,7 +148,7 @@ def _version_for_config(version, config):
132148

133149
def _registry_from_region(region, registry_dict):
134150
"""Returns the ECR registry (AWS account number) for the given region."""
135-
_validate_arg("region", region, registry_dict.keys())
151+
_validate_arg(region, registry_dict.keys(), "region")
136152
return registry_dict[region]
137153

138154

@@ -159,7 +175,7 @@ def _processor(instance_type, available_processors):
159175
family = instance_type.split(".")[1]
160176
processor = "gpu" if family[0] in ("g", "p") else "cpu"
161177

162-
_validate_arg("processor", processor, available_processors)
178+
_validate_arg(processor, available_processors, "processor")
163179
return processor
164180

165181

@@ -179,11 +195,11 @@ def _validate_py_version_and_set_if_needed(py_version, version_config):
179195
logger.info("Defaulting to only available Python version: %s", available_versions[0])
180196
return available_versions[0]
181197

182-
_validate_arg("Python version", py_version, available_versions)
198+
_validate_arg(py_version, available_versions, "Python version")
183199
return py_version
184200

185201

186-
def _validate_arg(arg_name, arg, available_options):
202+
def _validate_arg(arg, available_options, arg_name):
187203
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
188204
if arg not in available_options:
189205
raise ValueError(

tests/unit/sagemaker/image_uris/test_algos.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -170,21 +170,19 @@ def test_algo_uris(algo):
170170
accounts = _accounts_for_algo(algo)
171171

172172
for region in regions.regions():
173-
for scope in ("training", "inference"):
174-
uri = image_uris.retrieve(algo, region, image_scope=scope)
175-
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
173+
uri = image_uris.retrieve(algo, region)
174+
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
176175

177176

178177
def test_lda():
179178
algo = "lda"
180179
accounts = _accounts_for_algo(algo)
181180

182181
for region in regions.regions():
183-
for scope in ("training", "inference"):
184-
if region in accounts:
185-
uri = image_uris.retrieve(algo, region, image_scope=scope)
186-
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
187-
else:
188-
with pytest.raises(ValueError) as e:
189-
image_uris.retrieve(algo, region, image_scope=scope)
190-
assert "Unsupported region: {}.".format(region) in str(e.value)
182+
if region in accounts:
183+
uri = image_uris.retrieve(algo, region)
184+
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
185+
else:
186+
with pytest.raises(ValueError) as e:
187+
image_uris.retrieve(algo, region)
188+
assert "Unsupported region: {}.".format(region) in str(e.value)

tests/unit/sagemaker/image_uris/test_neo.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,11 @@
4949
@pytest.mark.parametrize("algo", ALGO_NAMES)
5050
def test_algo_uris(algo):
5151
for region in regions.regions():
52-
for scope in ("training", "inference"):
53-
if region in NEO_REGION_LIST:
54-
uri = image_uris.retrieve(algo, region, image_scope=scope)
55-
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
56-
assert expected == uri
57-
else:
58-
with pytest.raises(ValueError) as e:
59-
image_uris.retrieve(algo, region, image_scope=scope)
60-
assert "Unsupported region: {}.".format(region) in str(e.value)
52+
if region in NEO_REGION_LIST:
53+
uri = image_uris.retrieve(algo, region)
54+
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
55+
assert expected == uri
56+
else:
57+
with pytest.raises(ValueError) as e:
58+
image_uris.retrieve(algo, region)
59+
assert "Unsupported region: {}.".format(region) in str(e.value)

tests/unit/sagemaker/image_uris/test_retrieve.py

+42
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,48 @@ def test_retrieve_unsupported_image_scope(config_for_framework):
6565
assert "Unsupported image scope: invalid-image-scope." in str(e.value)
6666
assert "Supported image scope(s): training, inference." in str(e.value)
6767

68+
config = copy.deepcopy(BASE_CONFIG)
69+
config["scope"].append("eia")
70+
config_for_framework.return_value = config
71+
72+
with pytest.raises(ValueError) as e:
73+
image_uris.retrieve(
74+
framework="useless-string",
75+
version="1.0.0",
76+
py_version="py3",
77+
instance_type="ml.c4.xlarge",
78+
region="us-west-2",
79+
)
80+
assert "Unsupported image scope: None." in str(e.value)
81+
assert "Supported image scope(s): training, inference, eia." in str(e.value)
82+
83+
84+
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
85+
def test_retrieve_default_image_scope(config_for_framework, caplog):
86+
uri = image_uris.retrieve(
87+
framework="useless-string",
88+
version="1.0.0",
89+
py_version="py3",
90+
instance_type="ml.c4.xlarge",
91+
region="us-west-2",
92+
)
93+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
94+
95+
config = copy.deepcopy(BASE_CONFIG)
96+
config["scope"] = ["eia"]
97+
config_for_framework.return_value = config
98+
99+
uri = image_uris.retrieve(
100+
framework="useless-string",
101+
version="1.0.0",
102+
py_version="py3",
103+
instance_type="ml.c4.xlarge",
104+
region="us-west-2",
105+
image_scope="ignorable-scope",
106+
)
107+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
108+
assert "Ignoring image scope: ignorable-scope." in caplog.text
109+
68110

69111
@patch("sagemaker.image_uris.config_for_framework")
70112
def test_retrieve_eia(config_for_framework, caplog):

tests/unit/sagemaker/image_uris/test_xgboost.py

+20-25
Original file line numberDiff line numberDiff line change
@@ -71,35 +71,30 @@
7171

7272
def test_xgboost_framework(xgboost_framework_version):
7373
for region in regions.regions():
74-
for scope in ("training", "inference"):
75-
uri = image_uris.retrieve(
76-
framework="xgboost",
77-
region=region,
78-
version=xgboost_framework_version,
79-
py_version="py3",
80-
instance_type="ml.c4.xlarge",
81-
image_scope=scope,
82-
)
74+
uri = image_uris.retrieve(
75+
framework="xgboost",
76+
region=region,
77+
version=xgboost_framework_version,
78+
py_version="py3",
79+
instance_type="ml.c4.xlarge",
80+
)
8381

84-
expected = expected_uris.framework_uri(
85-
"sagemaker-xgboost",
86-
xgboost_framework_version,
87-
FRAMEWORK_REGISTRIES[region],
88-
py_version="py3",
89-
region=region,
90-
)
91-
assert expected == uri
82+
expected = expected_uris.framework_uri(
83+
"sagemaker-xgboost",
84+
xgboost_framework_version,
85+
FRAMEWORK_REGISTRIES[region],
86+
py_version="py3",
87+
region=region,
88+
)
89+
assert expected == uri
9290

9391

9492
@pytest.mark.parametrize("xgboost_algo_version", ("1", "latest"))
9593
def test_xgboost_algo(xgboost_algo_version):
9694
for region in regions.regions():
97-
for scope in ("training", "inference"):
98-
uri = image_uris.retrieve(
99-
framework="xgboost", region=region, version=xgboost_algo_version, image_scope=scope,
100-
)
95+
uri = image_uris.retrieve(framework="xgboost", region=region, version=xgboost_algo_version)
10196

102-
expected = expected_uris.algo_uri(
103-
"xgboost", ALGO_REGISTRIES[region], region, version=xgboost_algo_version
104-
)
105-
assert expected == uri
97+
expected = expected_uris.algo_uri(
98+
"xgboost", ALGO_REGISTRIES[region], region, version=xgboost_algo_version
99+
)
100+
assert expected == uri

0 commit comments

Comments
 (0)