 
 RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data")
 MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist")
-TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler")
+TFS_RESOURCE_PATH = os.path.join(
+    RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler"
+)
 
 SCRIPT = "mnist.py"
 PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
@@ -96,7 +98,9 @@ def test_mnist_with_checkpoint_config(
         sagemaker_session=sagemaker_session,
         framework_version=tensorflow_training_latest_version,
         py_version=tensorflow_training_latest_py_version,
-        metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
+        metric_definitions=[
+            {"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
+        ],
         checkpoint_s3_uri=checkpoint_s3_uri,
         checkpoint_local_path=checkpoint_local_path,
         environment=ENV_INPUT,
@@ -108,7 +112,9 @@ def test_mnist_with_checkpoint_config(
     )
 
     training_job_name = unique_name_from_base("test-tf-sm-mnist")
-    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
+    with tests.integ.timeout.timeout(
+        minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
+    ):
         estimator.fit(inputs=inputs, job_name=training_job_name)
     assert_s3_file_patterns_exist(
         sagemaker_session,
@@ -122,13 +128,15 @@ def test_mnist_with_checkpoint_config(
         "S3Uri": checkpoint_s3_uri,
         "LocalPath": checkpoint_local_path,
     }
-    actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
-        TrainingJobName=training_job_name
-    )["CheckpointConfig"]
+    actual_training_checkpoint_config = (
+        sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=training_job_name
+        )["CheckpointConfig"]
+    )
     actual_training_environment_variable_config = (
-        sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
-            "Environment"
-        ]
+        sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=training_job_name
+        )["Environment"]
     )
 
     expected_retry_strategy = {
@@ -143,7 +151,10 @@ def test_mnist_with_checkpoint_config(
 
 
 def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):
-    with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
+    with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (
+        bucket_with_kms,
+        kms_key,
+    ):
         output_path = os.path.join(
             bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
         )
@@ -164,16 +175,22 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
         )
 
         inputs = estimator.sagemaker_session.upload_data(
-            path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
+            path=os.path.join(MNIST_RESOURCE_PATH, "data"),
+            key_prefix="scriptmode/mnist",
         )
 
-        with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        with tests.integ.timeout.timeout(
+            minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
+        ):
             estimator.fit(
-                inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
+                inputs=inputs,
+                job_name=unique_name_from_base("test-server-side-encryption"),
             )
 
     endpoint_name = unique_name_from_base("test-server-side-encryption")
-    with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
+    with timeout.timeout_and_delete_endpoint_by_name(
+        endpoint_name, sagemaker_session
+    ):
         estimator.deploy(
             initial_instance_count=1,
             instance_type="ml.c5.xlarge",
@@ -198,7 +215,7 @@ def test_mwms_gpu(
     imagenet_train_subset,
     **kwargs,
 ):
-    instance_count = 2
+    instance_count = 2
     epochs = 1
     global_batch_size = 64
     train_steps = int(10**5 * epochs / global_batch_size)
@@ -248,12 +265,19 @@ def test_mwms_gpu(
         disable_profiler=True,
     )
 
-    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        estimator.fit(inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms"))
+    with tests.integ.timeout.timeout(
+        minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
+    ):
+        estimator.fit(
+            inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms")
+        )
 
     captured = capsys.readouterr()
     logs = captured.out + captured.err
-    assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
+    assert (
+        "Running distributed training job with multi_worker_mirrored_strategy setup"
+        in logs
+    )
     assert f"num_devices = 1, group_size = {instance_count}" in logs
     raise NotImplementedError("Check model saving")
 
@@ -308,11 +332,16 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
         disable_profiler=True,
     )
     inputs = estimator.sagemaker_session.upload_data(
-        path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/distributed_mnist"
+        path=os.path.join(MNIST_RESOURCE_PATH, "data"),
+        key_prefix="scriptmode/distributed_mnist",
     )
 
-    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
+    with tests.integ.timeout.timeout(
+        minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
+    ):
+        estimator.fit(
+            inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")
+        )
     assert_s3_file_patterns_exist(
         sagemaker_session,
         estimator.model_dir,
@@ -321,7 +350,9 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
 
 
 @pytest.mark.slow_test
-def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
+def test_mnist_async(
+    sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
+):
     if tf_full_version == "2.7.0":
         tf_full_version = "2.7"
 
@@ -339,14 +370,18 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_f
     inputs = estimator.sagemaker_session.upload_data(
         path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
     )
-    estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async"))
+    estimator.fit(
+        inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async")
+    )
     training_job_name = estimator.latest_training_job.name
     time.sleep(20)
     endpoint_name = training_job_name
     _assert_training_job_tags_match(
         sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
     )
-    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
+        endpoint_name, sagemaker_session
+    ):
         estimator = TensorFlow.attach(
             training_job_name=training_job_name, sagemaker_session=sagemaker_session
         )
@@ -364,7 +399,9 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_f
             sagemaker_session.sagemaker_client, predictor.endpoint_name, TAGS
         )
         _assert_model_tags_match(sagemaker_session.sagemaker_client, model_name, TAGS)
-        _assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
+        _assert_model_name_match(
+            sagemaker_session.sagemaker_client, endpoint_name, model_name
+        )
 
 
 def test_deploy_with_input_handlers(
@@ -450,7 +487,9 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):
 
 
 def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
-    endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
+    endpoint_description = sagemaker_client.describe_endpoint(
+        EndpointName=endpoint_name
+    )
 
     _assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags)
 
@@ -459,11 +498,15 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
     training_job_description = sagemaker_client.describe_training_job(
         TrainingJobName=training_job_name
     )
-    _assert_tags_match(sagemaker_client, training_job_description["TrainingJobArn"], tags)
+    _assert_tags_match(
+        sagemaker_client, training_job_description["TrainingJobArn"], tags
+    )
 
 
 def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
     endpoint_config_description = sagemaker_client.describe_endpoint_config(
         EndpointConfigName=endpoint_config_name
     )
-    assert model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]
+    assert (
+        model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]
+    )