Skip to content

Commit 17610e0

Browse files
authored
Merge branch 'zwei' into attach-log
2 parents c5dafc5 + 413d05a commit 17610e0

File tree

7 files changed

+1712
-167
lines changed

7 files changed

+1712
-167
lines changed

src/sagemaker/image_uri_config/chainer.json

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"processors": ["cpu", "gpu"],
3+
"scope": ["inference", "training"],
34
"version_aliases": {
45
"4.0": "4.0.0",
56
"4.1": "4.1.0",

src/sagemaker/image_uri_config/tensorflow.json

+1,207
Large diffs are not rendered by default.

src/sagemaker/image_uris.py

+70-38
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,25 @@
1414
from __future__ import absolute_import
1515

1616
import json
17+
import logging
1718
import os
1819

1920
from sagemaker import utils
2021

22+
logger = logging.getLogger(__name__)
23+
2124
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
2225

2326

24-
def retrieve(framework, region, version=None, py_version=None, instance_type=None):
27+
def retrieve(
28+
framework,
29+
region,
30+
version=None,
31+
py_version=None,
32+
instance_type=None,
33+
accelerator_type=None,
34+
image_scope=None,
35+
):
2536
"""Retrieves the ECR URI for the Docker image matching the given arguments.
2637
2738
Args:
@@ -34,28 +45,48 @@ def retrieve(framework, region, version=None, py_version=None, instance_type=Non
3445
instance_type (str): The SageMaker instance type. For supported types, see
3546
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
3647
there are different images for different processor types.
48+
accelerator_type (str): Elastic Inference accelerator type. For more, see
49+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
50+
image_scope (str): The image type, i.e. what it is used for.
51+
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
52+
``image_scope`` is ignored.
3753
3854
Returns:
3955
str: the ECR URI for the corresponding SageMaker Docker image.
4056
4157
Raises:
42-
ValueError: If the framework version, Python version, processor type, or region is
43-
not supported given the other arguments.
58+
ValueError: If the combination of arguments specified is not supported.
4459
"""
45-
config = config_for_framework(framework)
60+
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
4661
version_config = config["versions"][_version_for_config(version, config, framework)]
4762

63+
py_version = _validate_py_version_and_set_if_needed(py_version, version_config)
64+
version_config = version_config.get(py_version) or version_config
65+
4866
registry = _registry_from_region(region, version_config["registries"])
4967
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
5068

5169
repo = version_config["repository"]
52-
53-
_validate_py_version(py_version, version_config["py_versions"], framework, version)
54-
tag = "{}-{}-{}".format(version, _processor(instance_type, config["processors"]), py_version)
70+
tag = _format_tag(version, _processor(instance_type, config["processors"]), py_version)
5571

5672
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)
5773

5874

75+
def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):
76+
"""Loads the JSON config for the given framework and image scope."""
77+
config = config_for_framework(framework)
78+
79+
if accelerator_type:
80+
if image_scope not in ("eia", "inference"):
81+
logger.warning(
82+
"Elastic inference is for inference only. Ignoring image scope: %s.", image_scope
83+
)
84+
image_scope = "eia"
85+
86+
_validate_arg("image scope", image_scope, config.get("scope", config.keys()))
87+
return config if "scope" in config else config[image_scope]
88+
89+
5990
def config_for_framework(framework):
6091
"""Loads the JSON config for the given framework."""
6192
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
@@ -69,27 +100,13 @@ def _version_for_config(version, config, framework):
69100
if version in config["version_aliases"].keys():
70101
return config["version_aliases"][version]
71102

72-
available_versions = config["versions"].keys()
73-
if version in available_versions:
74-
return version
75-
76-
raise ValueError(
77-
"Unsupported {} version: {}. "
78-
"You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
79-
"Supported version(s): {}.".format(framework, version, ", ".join(available_versions))
80-
)
103+
_validate_arg("{} version".format(framework), version, config["versions"].keys())
104+
return version
81105

82106

83107
def _registry_from_region(region, registry_dict):
84108
"""Returns the ECR registry (AWS account number) for the given region."""
85-
available_regions = registry_dict.keys()
86-
if region not in available_regions:
87-
raise ValueError(
88-
"Unsupported region: {}. You may need to upgrade "
89-
"your SDK version (pip install -U sagemaker) for newer regions. "
90-
"Supported region(s): {}.".format(region, ", ".join(available_regions))
91-
)
92-
109+
_validate_arg("region", region, registry_dict.keys())
93110
return registry_dict[region]
94111

95112

@@ -106,22 +123,37 @@ def _processor(instance_type, available_processors):
106123
family = instance_type.split(".")[1]
107124
processor = "gpu" if family[0] in ("g", "p") else "cpu"
108125

109-
if processor in available_processors:
110-
return processor
111-
112-
raise ValueError(
113-
"Unsupported processor type: {} (for {}). "
114-
"Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors))
115-
)
126+
_validate_arg("processor", processor, available_processors)
127+
return processor
116128

117129

118-
def _validate_py_version(py_version, available_versions, framework, fw_version):
130+
def _validate_py_version_and_set_if_needed(py_version, version_config):
119131
"""Checks if the Python version is one of the supported versions."""
120-
if py_version not in available_versions:
132+
available_versions = version_config.get("py_versions", version_config.keys())
133+
134+
if len(available_versions) == 0:
135+
if py_version:
136+
logger.info("Ignoring unnecessary Python version: %s.", py_version)
137+
return None
138+
139+
if py_version is None and len(available_versions) == 1:
140+
logger.info("Defaulting to only available Python version: %s", available_versions[0])
141+
return available_versions[0]
142+
143+
_validate_arg("Python version", py_version, available_versions)
144+
return py_version
145+
146+
147+
def _validate_arg(arg_name, arg, available_options):
148+
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
149+
if arg not in available_options:
121150
raise ValueError(
122-
"Unsupported Python version for {} {}: {}. You may need to upgrade "
123-
"your SDK version (pip install -U sagemaker) for newer versions. "
124-
"Supported Python version(s): {}.".format(
125-
framework, fw_version, py_version, ", ".join(available_versions)
126-
)
151+
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
152+
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
153+
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
127154
)
155+
156+
157+
def _format_tag(version, processor, py_version):
158+
"""Creates a tag for the image URI."""
159+
return "-".join([x for x in (version, processor, py_version) if x])

tests/conftest.py

+29-48
Original file line numberDiff line numberDiff line change
@@ -164,50 +164,20 @@ def xgboost_version(request):
164164
return request.param
165165

166166

167-
@pytest.fixture(
168-
scope="module",
169-
params=[
170-
"1.4",
171-
"1.4.1",
172-
"1.5",
173-
"1.5.0",
174-
"1.6",
175-
"1.6.0",
176-
"1.7",
177-
"1.7.0",
178-
"1.8",
179-
"1.8.0",
180-
"1.9",
181-
"1.9.0",
182-
"1.10",
183-
"1.10.0",
184-
"1.11",
185-
"1.11.0",
186-
"1.12",
187-
"1.12.0",
188-
"1.13",
189-
"1.14",
190-
"1.14.0",
191-
"1.15",
192-
"1.15.0",
193-
"1.15.2",
194-
"2.0",
195-
"2.0.0",
196-
"2.0.1",
197-
"2.1",
198-
"2.1.0",
199-
],
200-
)
201-
def tf_version(request):
202-
return request.param
167+
@pytest.fixture(scope="module")
168+
def tf_version(tensorflow_training_version):
169+
# TODO: remove this fixture and update tests
170+
if tensorflow_training_version in ("1.13.1", "2.2", "2.2.0"):
171+
pytest.skip("version isn't compatible with both training and inference.")
172+
return tensorflow_training_version
203173

204174

205175
@pytest.fixture(scope="module", params=["py2", "py3"])
206-
def tf_py_version(tf_version, request):
207-
version = [int(val) for val in tf_version.split(".")]
208-
if version < [1, 11]:
176+
def tf_py_version(tensorflow_training_version, request):
177+
version = Version(tensorflow_training_version)
178+
if version < Version("1.11"):
209179
return "py2"
210-
if version < [2, 2]:
180+
if version < Version("2.2"):
211181
return request.param
212182
return "py37"
213183

@@ -401,11 +371,22 @@ def pytest_generate_tests(metafunc):
401371
params.append("ml.p2.xlarge")
402372
metafunc.parametrize("instance_type", params, scope="session")
403373

404-
for fw in ("chainer",):
405-
fixture_name = "{}_version".format(fw)
406-
if fixture_name in metafunc.fixturenames:
407-
config = image_uris.config_for_framework(fw)
408-
versions = list(config["versions"].keys()) + list(
409-
config.get("version_aliases", {}).keys()
410-
)
411-
metafunc.parametrize(fixture_name, versions, scope="session")
374+
_generate_all_framework_version_fixtures(metafunc)
375+
376+
377+
def _generate_all_framework_version_fixtures(metafunc):
378+
for fw in ("chainer", "tensorflow"):
379+
config = image_uris.config_for_framework(fw)
380+
if "scope" in config:
381+
_parametrize_framework_version_fixture(metafunc, "{}_version".format(fw), config)
382+
else:
383+
for image_scope in config.keys():
384+
_parametrize_framework_version_fixture(
385+
metafunc, "{}_{}_version".format(fw, image_scope), config[image_scope]
386+
)
387+
388+
389+
def _parametrize_framework_version_fixture(metafunc, fixture_name, config):
390+
if fixture_name in metafunc.fixturenames:
391+
versions = list(config["versions"].keys()) + list(config.get("version_aliases", {}).keys())
392+
metafunc.parametrize(fixture_name, versions, scope="session")

0 commit comments

Comments
 (0)