     "worker": ["{}:8890".format(HOST2)],
     "ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
 }
-CLUSTER_WITH_MWMS = {
-    "worker": ["{}:8890".format(HOST) for HOST in (HOST1, HOST2)],
-}
+CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}
 
 MASTER_TASK = {"index": 0, "type": "master"}
 WORKER_TASK = {"index": 0, "type": "worker"}
@@ -54,7 +52,9 @@ def distributed_training_env():
     env = simple_training_env()
 
     env.hosts = HOST_LIST
-    env.additional_framework_parameters = {training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True}
+    env.additional_framework_parameters = {
+        training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
+    }
     return env
 
 
@@ -98,7 +98,9 @@ def test_single_machine(run_module, single_machine_training_env):
 
 @patch("sagemaker_training.entry_point.run")
 def test_train_horovod(run_module, single_machine_training_env):
-    single_machine_training_env.additional_framework_parameters["sagemaker_mpi_enabled"] = True
+    single_machine_training_env.additional_framework_parameters[
+        "sagemaker_mpi_enabled"
+    ] = True
 
     training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
 
     run_module.assert_called_with(
@@ -113,22 +115,32 @@ def test_train_horovod(run_module, single_machine_training_env):
 
 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.train.Server")
 @patch("sagemaker_training.entry_point.run")
 @patch("multiprocessing.Process", lambda target: target())
 @patch("time.sleep", MagicMock())
-def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
+def test_train_distributed_master(
+    run, tf_server, cluster_spec, distributed_training_env
+):
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
 
     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )
 
     tf_server.assert_called_with(
-        cluster_spec(), job_name="ps", task_index=0, config=tf.ConfigProto(device_count={"GPU": 0})
+        cluster_spec(),
+        job_name="ps",
+        task_index=0,
+        config=tf.ConfigProto(device_count={"GPU": 0}),
     )
     tf_server().join.assert_called_with()
 
@@ -152,24 +164,34 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
 
 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.train.Server")
 @patch("sagemaker_training.entry_point.run")
 @patch("multiprocessing.Process", lambda target: target())
 @patch("time.sleep", MagicMock())
-def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
+def test_train_distributed_worker(
+    run, tf_server, cluster_spec, distributed_training_env
+):
     distributed_training_env.current_host = HOST2
 
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
 
     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )
 
     tf_server.assert_called_with(
-        cluster_spec(), job_name="ps", task_index=1, config=tf.ConfigProto(device_count={"GPU": 0})
+        cluster_spec(),
+        job_name="ps",
+        task_index=1,
+        config=tf.ConfigProto(device_count={"GPU": 0}),
     )
     tf_server().join.assert_called_with()
 
@@ -248,8 +270,9 @@ def test_build_tf_config_for_ps():
 def test_build_tf_config_for_ps_error():
     with pytest.raises(ValueError) as error:
         training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
-    assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
-        error.value
+    assert (
+        "Cannot have a ps task if there are no parameter servers in the cluster"
+        in str(error.value)
     )
 
 
@@ -271,7 +294,9 @@ def test_log_model_missing_warning_no_model(logger):
 
 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_wrong_format(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_wrong_model"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_wrong_model")
+    )
     logger.warn.assert_called_with(
         "Your model will NOT be servable with SageMaker TensorFlow Serving container. "
         "The model artifact was not saved in the TensorFlow "
@@ -282,16 +307,22 @@ def test_log_model_missing_warning_wrong_format(logger):
 
 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_wrong_parent_dir(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_wrong_parent_dir"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_wrong_parent_dir")
+    )
     logger.warn.assert_called_with(
         "Your model will NOT be servable with SageMaker TensorFlow Serving containers. "
-        'The SavedModel bundle is under directory "{}", not a numeric name.'.format("not-digit")
+        'The SavedModel bundle is under directory "{}", not a numeric name.'.format(
+            "not-digit"
+        )
     )
 
 
 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_correct(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_correct_model"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_correct_model")
+    )
     logger.warn.assert_not_called()
 
 
@@ -323,7 +354,10 @@ def test_main(
 @patch("sagemaker_tensorflow_container.training.train")
 @patch("logging.Logger.setLevel")
 @patch("sagemaker_training.environment.Environment")
-@patch("sagemaker_training.environment.read_hyperparameters", return_value={"model_dir": MODEL_DIR})
+@patch(
+    "sagemaker_training.environment.read_hyperparameters",
+    return_value={"model_dir": MODEL_DIR},
+)
 @patch("sagemaker_tensorflow_container.s3_utils.configure")
 def test_main_simple_training_model_dir(
     configure_s3_env,
@@ -361,7 +395,9 @@ def test_main_tuning_model_dir(
     training_env.return_value = single_machine_training_env
     os.environ["SAGEMAKER_REGION"] = REGION
     training.main()
-    expected_model_dir = "{}/{}/model".format(MODEL_DIR, single_machine_training_env.job_name)
+    expected_model_dir = "{}/{}/model".format(
+        MODEL_DIR, single_machine_training_env.job_name
+    )
     configure_s3_env.assert_called_once_with(expected_model_dir, REGION)
 
 