@@ -81,8 +81,8 @@ def test_node_should_be_modified_random_function_call():
81
81
82
82
83
83
@patch ("boto3.Session" )
84
- @patch ("sagemaker.fw_utils.create_image_uri " , return_value = IMAGE_URI )
85
- def test_modify_node_set_model_dir_and_image_name (create_image_uri , boto_session ):
84
+ @patch ("sagemaker.image_uris.retrieve " , return_value = IMAGE_URI )
85
+ def test_modify_node_set_model_dir_and_image_name (retrieve_image_uri , boto_session ):
86
86
boto_session .return_value .region_name = REGION_NAME
87
87
88
88
tf_constructors = (
@@ -97,14 +97,19 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session
97
97
modifier .modify_node (node )
98
98
99
99
assert "TensorFlow(image_uri='{}', model_dir=False)" .format (IMAGE_URI ) == pasta .dump (node )
100
- create_image_uri .assert_called_with (
101
- REGION_NAME , "tensorflow" , "ml.m4.xlarge" , "1.11.0" , "py2"
100
+ retrieve_image_uri .assert_called_with (
101
+ "tensorflow" ,
102
+ REGION_NAME ,
103
+ instance_type = "ml.m4.xlarge" ,
104
+ version = "1.11.0" ,
105
+ py_version = "py2" ,
106
+ image_scope = "training" ,
102
107
)
103
108
104
109
105
110
@patch ("boto3.Session" )
106
- @patch ("sagemaker.fw_utils.create_image_uri " , return_value = IMAGE_URI )
107
- def test_modify_node_set_image_name_from_args (create_image_uri , boto_session ):
111
+ @patch ("sagemaker.image_uris.retrieve " , return_value = IMAGE_URI )
112
+ def test_modify_node_set_image_name_from_args (retrieve_image_uri , boto_session ):
108
113
boto_session .return_value .region_name = REGION_NAME
109
114
110
115
tf_constructor = "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')"
@@ -113,7 +118,14 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
113
118
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
114
119
modifier .modify_node (node )
115
120
116
- create_image_uri .assert_called_with (REGION_NAME , "tensorflow" , "ml.p2.xlarge" , "1.4.0" , "py2" )
121
+ retrieve_image_uri .assert_called_with (
122
+ "tensorflow" ,
123
+ REGION_NAME ,
124
+ instance_type = "ml.p2.xlarge" ,
125
+ version = "1.4.0" ,
126
+ py_version = "py2" ,
127
+ image_scope = "training" ,
128
+ )
117
129
118
130
expected_string = (
119
131
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
@@ -123,8 +135,8 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
123
135
124
136
125
137
@patch ("boto3.Session" , MagicMock ())
126
- @patch ("sagemaker.fw_utils.create_image_uri " , return_value = IMAGE_URI )
127
- def test_modify_node_set_hyperparameters (create_image_uri ):
138
+ @patch ("sagemaker.image_uris.retrieve " , return_value = IMAGE_URI )
139
+ def test_modify_node_set_hyperparameters (retrieve_image_uri ):
128
140
tf_constructor = """TensorFlow(
129
141
checkpoint_path='s3://foo/bar',
130
142
training_steps=100,
@@ -147,8 +159,8 @@ def test_modify_node_set_hyperparameters(create_image_uri):
147
159
148
160
149
161
@patch ("boto3.Session" , MagicMock ())
150
- @patch ("sagemaker.fw_utils.create_image_uri " , return_value = IMAGE_URI )
151
- def test_modify_node_preserve_other_hyperparameters (create_image_uri ):
162
+ @patch ("sagemaker.image_uris.retrieve " , return_value = IMAGE_URI )
163
+ def test_modify_node_preserve_other_hyperparameters (retrieve_image_uri ):
152
164
tf_constructor = """sagemaker.tensorflow.TensorFlow(
153
165
training_steps=100,
154
166
evaluation_steps=10,
@@ -173,8 +185,8 @@ def test_modify_node_preserve_other_hyperparameters(create_image_uri):
173
185
174
186
175
187
@patch ("boto3.Session" , MagicMock ())
176
- @patch ("sagemaker.fw_utils.create_image_uri " , return_value = IMAGE_URI )
177
- def test_modify_node_prefer_param_over_hyperparameter (create_image_uri ):
188
+ @patch ("sagemaker.image_uris.retrieve " , return_value = IMAGE_URI )
189
+ def test_modify_node_prefer_param_over_hyperparameter (retrieve_image_uri ):
178
190
tf_constructor = """sagemaker.tensorflow.TensorFlow(
179
191
training_steps=100,
180
192
requirements_file='source/requirements.txt',
0 commit comments