Skip to content

Commit e6bc154

Browse files
committed
change: use image_uris.retrieve instead of fw_utils.create_image_uri for DLC frameworks
1 parent e4485b7 commit e6bc154

File tree

13 files changed

+125
-131
lines changed

13 files changed

+125
-131
lines changed

src/sagemaker/chainer/model.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker import image_uris
1920
from sagemaker.fw_utils import (
20-
create_image_uri,
2121
model_code_key_prefix,
2222
python_deprecation_warning,
2323
validate_version_or_image_args,
@@ -175,11 +175,12 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
175175
str: The appropriate image URI based on the given parameters.
176176
177177
"""
178-
return create_image_uri(
179-
region_name,
178+
return image_uris.retrieve(
180179
self.__framework_name__,
181-
instance_type,
182-
self.framework_version,
183-
self.py_version,
180+
region_name,
181+
version=self.framework_version,
182+
py_version=self.py_version,
183+
instance_type=instance_type,
184184
accelerator_type=accelerator_type,
185+
image_scope="inference",
185186
)

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from sagemaker.cli.compatibility.v2.modifiers import framework_version, matching
2424
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
25-
from sagemaker import fw_utils
25+
from sagemaker import image_uris
2626

2727
TF_NAMESPACES = ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator")
2828
LEGACY_MODE_PARAMETERS = (
@@ -169,9 +169,14 @@ def _image_uri_from_args(self, keywords):
169169
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None
170170

171171
if tf_version and instance_type:
172-
return fw_utils.create_image_uri(
173-
self.region, "tensorflow", instance_type, tf_version, "py2"
174-
)
172+
return image_uris.retrieve(
173+
"tensorflow",
174+
self.region,
175+
version=tf_version,
176+
py_version="py2",
177+
instance_type=instance_type,
178+
image_scope="training",
179+
).replace("-scriptmode", "")
175180

176181
return None
177182

src/sagemaker/estimator.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@
2525
from six import string_types
2626
from six.moves.urllib.parse import urlparse
2727
import sagemaker
28-
from sagemaker import git_utils
28+
from sagemaker import git_utils, image_uris
2929
from sagemaker.analytics import TrainingJobAnalytics
3030
from sagemaker.debugger import DebuggerHookConfig
3131
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3232
from sagemaker.debugger import get_rule_container_image_uri
3333
from sagemaker.s3 import S3Uploader
3434

3535
from sagemaker.fw_utils import (
36-
create_image_uri,
3736
tar_and_upload_dir,
3837
parse_s3_url,
3938
UploadedCode,
@@ -1822,12 +1821,13 @@ def train_image(self):
18221821
"""
18231822
if self.image_uri:
18241823
return self.image_uri
1825-
return create_image_uri(
1826-
self.sagemaker_session.boto_region_name,
1824+
return image_uris.retrieve(
18271825
self.__framework_name__,
1828-
self.instance_type,
1829-
self.framework_version, # pylint: disable=no-member
1826+
self.sagemaker_session.boto_region_name,
1827+
instance_type=self.instance_type,
1828+
version=self.framework_version, # pylint: disable=no-member
18301829
py_version=self.py_version, # pylint: disable=no-member
1830+
image_scope="training",
18311831
)
18321832

18331833
@classmethod

src/sagemaker/mxnet/model.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import packaging.version
1919

2020
import sagemaker
21+
from sagemaker import image_uris
2122
from sagemaker.deserializers import JSONDeserializer
2223
from sagemaker.fw_utils import (
23-
create_image_uri,
2424
model_code_key_prefix,
2525
python_deprecation_warning,
2626
validate_version_or_image_args,
@@ -183,17 +183,14 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
183183
str: The appropriate image URI based on the given parameters.
184184
185185
"""
186-
framework_name = self.__framework_name__
187-
if self._is_mms_version():
188-
framework_name = "{}-serving".format(framework_name)
189-
190-
return create_image_uri(
186+
return image_uris.retrieve(
187+
self.__framework_name__,
191188
region_name,
192-
framework_name,
193-
instance_type,
194-
self.framework_version,
195-
self.py_version,
189+
version=self.framework_version,
190+
py_version=self.py_version,
191+
instance_type=instance_type,
196192
accelerator_type=accelerator_type,
193+
image_scope="inference",
197194
)
198195

199196
def _is_mms_version(self):

src/sagemaker/pytorch/model.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import packaging.version
1818

1919
import sagemaker
20+
from sagemaker import image_uris
2021
from sagemaker.deserializers import NumpyDeserializer
2122
from sagemaker.fw_utils import (
22-
create_image_uri,
2323
model_code_key_prefix,
2424
python_deprecation_warning,
2525
validate_version_or_image_args,
@@ -182,17 +182,14 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
182182
str: The appropriate image URI based on the given parameters.
183183
184184
"""
185-
framework_name = self.__framework_name__
186-
if self._is_mms_version():
187-
framework_name = "{}-serving".format(framework_name)
188-
189-
return create_image_uri(
185+
return image_uris.retrieve(
186+
self.__framework_name__,
190187
region_name,
191-
framework_name,
192-
instance_type,
193-
self.framework_version,
194-
self.py_version,
188+
version=self.framework_version,
189+
py_version=self.py_version,
190+
instance_type=instance_type,
195191
accelerator_type=accelerator_type,
192+
image_scope="inference",
196193
)
197194

198195
def _is_mms_version(self):

src/sagemaker/tensorflow/estimator.py

+7-20
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from packaging import version
2020

21-
from sagemaker import utils
21+
from sagemaker import image_uris, utils
2222
from sagemaker.debugger import DebuggerHookConfig
2323
from sagemaker.estimator import Framework
2424
import sagemaker.fw_utils as fw
@@ -34,7 +34,6 @@ class TensorFlow(Framework):
3434
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
3535

3636
__framework_name__ = "tensorflow"
37-
_ECR_REPO_NAME = "tensorflow-scriptmode"
3837

3938
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
4039
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
@@ -151,12 +150,13 @@ def _validate_args(self, py_version):
151150
raise AttributeError(msg)
152151

153152
if self.image_uri is None and self._only_legacy_mode_supported():
154-
legacy_image_uri = fw.create_image_uri(
155-
self.sagemaker_session.boto_region_name,
153+
legacy_image_uri = image_uris.retrieve(
156154
"tensorflow",
157-
self.instance_type,
158-
self.framework_version,
159-
self.py_version,
155+
self.sagemaker_session.boto_region_name,
156+
instance_type=self.instance_type,
157+
version=self.framework_version,
158+
py_version=self.py_version,
159+
image_scope="training",
160160
)
161161

162162
msg = (
@@ -355,19 +355,6 @@ def _validate_and_set_debugger_configs(self):
355355
# Set defaults for debugging.
356356
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
357357

358-
def train_image(self):
359-
"""Placeholder docstring"""
360-
if self.image_uri:
361-
return self.image_uri
362-
363-
return fw.create_image_uri(
364-
self.sagemaker_session.boto_region_name,
365-
self._ECR_REPO_NAME,
366-
self.instance_type,
367-
self.framework_version,
368-
self.py_version,
369-
)
370-
371358
def transformer(
372359
self,
373360
instance_count,

src/sagemaker/tensorflow/model.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker import image_uris
1920
from sagemaker.content_types import CONTENT_TYPE_JSON
2021
from sagemaker.deserializers import JSONDeserializer
21-
from sagemaker.fw_utils import create_image_uri
2222
from sagemaker.predictor import Predictor
2323
from sagemaker.serializers import JSONSerializer
2424

@@ -122,7 +122,7 @@ def predict(self, data, initial_args=None):
122122
class TensorFlowModel(sagemaker.model.FrameworkModel):
123123
"""A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""
124124

125-
__framework_name__ = "tensorflow-serving"
125+
__framework_name__ = "tensorflow"
126126
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
127127
LOG_LEVEL_MAP = {
128128
logging.DEBUG: "debug",
@@ -286,13 +286,13 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
286286
if self.image_uri:
287287
return self.image_uri
288288

289-
region_name = self.sagemaker_session.boto_region_name
290-
return create_image_uri(
291-
region_name,
289+
return image_uris.retrieve(
292290
self.__framework_name__,
293-
instance_type,
294-
self.framework_version,
291+
self.sagemaker_session.boto_region_name,
292+
version=self.framework_version,
293+
instance_type=instance_type,
295294
accelerator_type=accelerator_type,
295+
image_scope="inference",
296296
)
297297

298298
def serving_image_uri(

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_node_should_be_modified_random_function_call():
8181

8282

8383
@patch("boto3.Session")
84-
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
85-
def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session):
84+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
85+
def test_modify_node_set_model_dir_and_image_name(retrieve_image_uri, boto_session):
8686
boto_session.return_value.region_name = REGION_NAME
8787

8888
tf_constructors = (
@@ -97,14 +97,19 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session
9797
modifier.modify_node(node)
9898

9999
assert "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
100-
create_image_uri.assert_called_with(
101-
REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2"
100+
retrieve_image_uri.assert_called_with(
101+
"tensorflow",
102+
REGION_NAME,
103+
instance_type="ml.m4.xlarge",
104+
version="1.11.0",
105+
py_version="py2",
106+
image_scope="training",
102107
)
103108

104109

105110
@patch("boto3.Session")
106-
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
107-
def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
111+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
112+
def test_modify_node_set_image_name_from_args(retrieve_image_uri, boto_session):
108113
boto_session.return_value.region_name = REGION_NAME
109114

110115
tf_constructor = "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')"
@@ -113,7 +118,14 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
113118
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
114119
modifier.modify_node(node)
115120

116-
create_image_uri.assert_called_with(REGION_NAME, "tensorflow", "ml.p2.xlarge", "1.4.0", "py2")
121+
retrieve_image_uri.assert_called_with(
122+
"tensorflow",
123+
REGION_NAME,
124+
instance_type="ml.p2.xlarge",
125+
version="1.4.0",
126+
py_version="py2",
127+
image_scope="training",
128+
)
117129

118130
expected_string = (
119131
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
@@ -123,8 +135,8 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
123135

124136

125137
@patch("boto3.Session", MagicMock())
126-
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
127-
def test_modify_node_set_hyperparameters(create_image_uri):
138+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
139+
def test_modify_node_set_hyperparameters(retrieve_image_uri):
128140
tf_constructor = """TensorFlow(
129141
checkpoint_path='s3://foo/bar',
130142
training_steps=100,
@@ -147,8 +159,8 @@ def test_modify_node_set_hyperparameters(create_image_uri):
147159

148160

149161
@patch("boto3.Session", MagicMock())
150-
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
151-
def test_modify_node_preserve_other_hyperparameters(create_image_uri):
162+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
163+
def test_modify_node_preserve_other_hyperparameters(retrieve_image_uri):
152164
tf_constructor = """sagemaker.tensorflow.TensorFlow(
153165
training_steps=100,
154166
evaluation_steps=10,
@@ -173,8 +185,8 @@ def test_modify_node_preserve_other_hyperparameters(create_image_uri):
173185

174186

175187
@patch("boto3.Session", MagicMock())
176-
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
177-
def test_modify_node_prefer_param_over_hyperparameter(create_image_uri):
188+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
189+
def test_modify_node_prefer_param_over_hyperparameter(retrieve_image_uri):
178190
tf_constructor = """sagemaker.tensorflow.TensorFlow(
179191
training_steps=100,
180192
requirements_file='source/requirements.txt',

tests/unit/sagemaker/tensorflow/test_estimator.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,10 @@ def test_create_model(
187187
container_log_level=container_log_level,
188188
base_job_name=base_job_name,
189189
enable_network_isolation=True,
190+
output_path="s3://mybucket/output",
190191
)
191192

192-
job_name = "doing something"
193-
tf.fit(inputs="s3://mybucket/train", job_name=job_name)
193+
tf._current_job_name = "doing something"
194194

195195
model_name = "doing something else"
196196
name_from_base.return_value = model_name
@@ -233,10 +233,10 @@ def test_create_model_with_optional_params(
233233
base_job_name="job",
234234
source_dir=source_dir,
235235
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
236+
output_path="s3://mybucket/output",
236237
)
237238

238-
job_name = "doing something"
239-
tf.fit(inputs="s3://mybucket/train", job_name=job_name)
239+
tf._current_job_name = "doing something"
240240

241241
new_role = "role"
242242
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}

0 commit comments

Comments
 (0)