Skip to content

Commit 75466ce

Browse files
icywang86ruiknakad
authored andcommitted
feature: use deep learning images (aws#883)
1 parent 7a45955 commit 75466ce

File tree

4 files changed

+130
-16
lines changed

4 files changed

+130
-16
lines changed

src/sagemaker/fw_utils.py

+68-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,56 @@
5454
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
5555
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
5656

57+
MERGED_FRAMEWORKS_REPO_MAP = {
58+
"tensorflow-scriptmode": "tensorflow-training",
59+
"mxnet": "mxnet-training",
60+
"tensorflow-serving": "tensorflow-inference",
61+
"mxnet-serving": "mxnet-inference",
62+
}
63+
64+
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
65+
"tensorflow-scriptmode": [1, 13, 1],
66+
"mxnet": [1, 4, 1],
67+
"tensorflow-serving": [1, 13, 0],
68+
"mxnet-serving": [1, 4, 1],
69+
}
70+
71+
72+
def is_version_equal_or_higher(lowest_version, framework_version):
73+
"""Determine whether the ``framework_version`` is equal to or higher than ``lowest_version``
74+
75+
Args:
76+
lowest_version (List[int]): lowest version represented in an integer list
77+
framework_version (str): framework version string
78+
79+
Returns:
80+
bool: Whether or not framework_version is equal to or higher than lowest_version
81+
"""
82+
version_list = [int(s) for s in framework_version.split(".")]
83+
return version_list >= lowest_version[0 : len(version_list)]
84+
85+
86+
def _is_merged_versions(framework, framework_version):
87+
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
88+
if lowest_version_list:
89+
return is_version_equal_or_higher(lowest_version_list, framework_version)
90+
else:
91+
return False
92+
93+
94+
def _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
95+
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
96+
is_py3 = py_version == "py3" or py_version is None
97+
is_merged_versions = _is_merged_versions(framework, framework_version)
98+
return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None
99+
100+
101+
def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
102+
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
103+
return "763104351884"
104+
else:
105+
return VALID_ACCOUNTS_BY_REGION.get(region, account)
106+
57107

58108
def create_image_uri(
59109
region,
@@ -86,8 +136,15 @@ def create_image_uri(
86136
if py_version and py_version not in VALID_PY_VERSIONS:
87137
raise ValueError("invalid py_version argument: {}".format(py_version))
88138

89-
# Handle Account Number for Gov Cloud
90-
account = VALID_ACCOUNTS_BY_REGION.get(region, account)
139+
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
140+
account = _registry_id(
141+
region=region,
142+
framework=framework,
143+
py_version=py_version,
144+
account=account,
145+
accelerator_type=accelerator_type,
146+
framework_version=framework_version,
147+
)
91148

92149
# Handle Local Mode
93150
if instance_type.startswith("local"):
@@ -121,7 +178,14 @@ def create_image_uri(
121178
):
122179
framework += "-eia"
123180

124-
return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)
181+
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
182+
return "{}/{}:{}".format(
183+
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
184+
)
185+
else:
186+
return "{}/sagemaker-{}:{}".format(
187+
get_ecr_image_uri_prefix(account, region), framework, tag
188+
)
125189

126190

127191
def _accelerator_type_valid_for_framework(
@@ -264,7 +328,7 @@ def framework_name_from_image(image_name):
264328
# extract framework, python version and image tag
265329
# We must support both the legacy and current image name format.
266330
name_pattern = re.compile(
267-
r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
331+
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
268332
)
269333
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
270334

tests/integ/test_tf_script_mode.py

+4
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def test_mnist(sagemaker_session, instance_type):
6565
sagemaker_session=sagemaker_session,
6666
script_mode=True,
6767
framework_version=TensorFlow.LATEST_VERSION,
68+
py_version=tests.integ.PYTHON_VERSION,
6869
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6970
)
7071
inputs = estimator.sagemaker_session.upload_data(
@@ -98,6 +99,7 @@ def test_server_side_encryption(sagemaker_session):
9899
sagemaker_session=sagemaker_session,
99100
script_mode=True,
100101
framework_version=TensorFlow.LATEST_VERSION,
102+
py_version=tests.integ.PYTHON_VERSION,
101103
code_location=output_path,
102104
output_path=output_path,
103105
model_dir="/opt/ml/model",
@@ -144,6 +146,7 @@ def test_mnist_async(sagemaker_session):
144146
role=ROLE,
145147
train_instance_count=1,
146148
train_instance_type="ml.c5.4xlarge",
149+
py_version=tests.integ.PYTHON_VERSION,
147150
sagemaker_session=sagemaker_session,
148151
script_mode=True,
149152
framework_version=TensorFlow.LATEST_VERSION,
@@ -182,6 +185,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type):
182185
role=ROLE,
183186
train_instance_count=1,
184187
train_instance_type=instance_type,
188+
py_version=tests.integ.PYTHON_VERSION,
185189
sagemaker_session=sagemaker_session,
186190
script_mode=True,
187191
framework_version=TensorFlow.LATEST_VERSION,

tests/unit/test_fw_utils.py

+57-11
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,59 @@ def test_create_image_uri_gov_cloud():
136136
)
137137

138138

139+
def test_create_image_uri_merged():
140+
image_uri = fw_utils.create_image_uri(
141+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3"
142+
)
143+
assert (
144+
image_uri
145+
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13.1-gpu-py3"
146+
)
147+
148+
image_uri = fw_utils.create_image_uri(
149+
"us-west-2", "tensorflow-serving", "ml.c4.2xlarge", "1.13.1"
150+
)
151+
assert (
152+
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.13.1-cpu"
153+
)
154+
155+
image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3")
156+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.4.1-gpu-py3"
157+
158+
image_uri = fw_utils.create_image_uri(
159+
"us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3"
160+
)
161+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3"
162+
163+
164+
def test_create_image_uri_merged_py2():
165+
image_uri = fw_utils.create_image_uri(
166+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py2"
167+
)
168+
assert (
169+
image_uri
170+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2"
171+
)
172+
173+
image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2")
174+
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2"
175+
176+
image_uri = fw_utils.create_image_uri(
177+
"us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py2"
178+
)
179+
assert (
180+
image_uri
181+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py2"
182+
)
183+
184+
139185
def test_create_image_uri_accelerator_tf():
140186
image_uri = fw_utils.create_image_uri(
141-
MOCK_REGION,
142-
"tensorflow",
143-
"ml.p3.2xlarge",
144-
"1.0rc",
145-
"py3",
146-
accelerator_type="ml.eia1.medium",
187+
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
147188
)
148189
assert (
149190
image_uri
150-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3"
191+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0-gpu-py3"
151192
)
152193

153194

@@ -156,13 +197,13 @@ def test_create_image_uri_accelerator_mxnet_serving():
156197
MOCK_REGION,
157198
"mxnet-serving",
158199
"ml.p3.2xlarge",
159-
"1.0rc",
200+
"1.0",
160201
"py3",
161202
accelerator_type="ml.eia1.medium",
162203
)
163204
assert (
164205
image_uri
165-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3"
206+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0-gpu-py3"
166207
)
167208

168209

@@ -171,13 +212,13 @@ def test_create_image_uri_local_sagemaker_notebook_accelerator():
171212
MOCK_REGION,
172213
"mxnet",
173214
"ml.p3.2xlarge",
174-
"1.0rc",
215+
"1.0",
175216
"py3",
176217
accelerator_type="local_sagemaker_notebook",
177218
)
178219
assert (
179220
image_uri
180-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3"
221+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0-gpu-py3"
181222
)
182223

183224

@@ -555,6 +596,11 @@ def test_framework_name_from_image_tf_scriptmode():
555596
"scriptmode",
556597
) == fw_utils.framework_name_from_image(image_name)
557598

599+
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13-cpu-py3"
600+
assert ("tensorflow", "py3", "1.13-cpu-py3", "training") == fw_utils.framework_name_from_image(
601+
image_name
602+
)
603+
558604

559605
def test_framework_name_from_image_rl():
560606
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3"

tests/unit/test_tf_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def test_script_mode_tensorboard(
924924
sagemaker_session=sagemaker_session,
925925
train_instance_count=INSTANCE_COUNT,
926926
train_instance_type=INSTANCE_TYPE,
927-
framework_version="some_version",
927+
framework_version="1.0",
928928
script_mode=True,
929929
)
930930
popen().poll.return_value = None

0 commit comments

Comments
 (0)