33
33
34
34
RESOURCE_PATH = os .path .join (os .path .dirname (__file__ ), ".." , "data" )
35
35
MNIST_RESOURCE_PATH = os .path .join (RESOURCE_PATH , "tensorflow_mnist" )
36
- TFS_RESOURCE_PATH = os .path .join (
37
- RESOURCE_PATH , "tfs" , "tfs-test-entrypoint-with-handler"
38
- )
36
+ TFS_RESOURCE_PATH = os .path .join (RESOURCE_PATH , "tfs" , "tfs-test-entrypoint-with-handler" )
39
37
40
38
SCRIPT = "mnist.py"
41
39
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server" : {"enabled" : True }}
@@ -98,9 +96,7 @@ def test_mnist_with_checkpoint_config(
98
96
sagemaker_session = sagemaker_session ,
99
97
framework_version = tensorflow_training_latest_version ,
100
98
py_version = tensorflow_training_latest_py_version ,
101
- metric_definitions = [
102
- {"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }
103
- ],
99
+ metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
104
100
checkpoint_s3_uri = checkpoint_s3_uri ,
105
101
checkpoint_local_path = checkpoint_local_path ,
106
102
environment = ENV_INPUT ,
@@ -112,9 +108,7 @@ def test_mnist_with_checkpoint_config(
112
108
)
113
109
114
110
training_job_name = unique_name_from_base ("test-tf-sm-mnist" )
115
- with tests .integ .timeout .timeout (
116
- minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES
117
- ):
111
+ with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
118
112
estimator .fit (inputs = inputs , job_name = training_job_name )
119
113
assert_s3_file_patterns_exist (
120
114
sagemaker_session ,
@@ -128,15 +122,13 @@ def test_mnist_with_checkpoint_config(
128
122
"S3Uri" : checkpoint_s3_uri ,
129
123
"LocalPath" : checkpoint_local_path ,
130
124
}
131
- actual_training_checkpoint_config = (
132
- sagemaker_session .sagemaker_client .describe_training_job (
133
- TrainingJobName = training_job_name
134
- )["CheckpointConfig" ]
135
- )
125
+ actual_training_checkpoint_config = sagemaker_session .sagemaker_client .describe_training_job (
126
+ TrainingJobName = training_job_name
127
+ )["CheckpointConfig" ]
136
128
actual_training_environment_variable_config = (
137
- sagemaker_session .sagemaker_client .describe_training_job (
138
- TrainingJobName = training_job_name
139
- )[ "Environment" ]
129
+ sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )[
130
+ "Environment"
131
+ ]
140
132
)
141
133
142
134
expected_retry_strategy = {
@@ -179,18 +171,14 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
179
171
key_prefix = "scriptmode/mnist" ,
180
172
)
181
173
182
- with tests .integ .timeout .timeout (
183
- minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES
184
- ):
174
+ with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
185
175
estimator .fit (
186
176
inputs = inputs ,
187
177
job_name = unique_name_from_base ("test-server-side-encryption" ),
188
178
)
189
179
190
180
endpoint_name = unique_name_from_base ("test-server-side-encryption" )
191
- with timeout .timeout_and_delete_endpoint_by_name (
192
- endpoint_name , sagemaker_session
193
- ):
181
+ with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
194
182
estimator .deploy (
195
183
initial_instance_count = 1 ,
196
184
instance_type = "ml.c5.xlarge" ,
@@ -265,19 +253,12 @@ def test_mwms_gpu(
265
253
disable_profiler = True ,
266
254
)
267
255
268
- with tests .integ .timeout .timeout (
269
- minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES
270
- ):
271
- estimator .fit (
272
- inputs = imagenet_train_subset , job_name = unique_name_from_base ("test-tf-mwms" )
273
- )
256
+ with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
257
+ estimator .fit (inputs = imagenet_train_subset , job_name = unique_name_from_base ("test-tf-mwms" ))
274
258
275
259
captured = capsys .readouterr ()
276
260
logs = captured .out + captured .err
277
- assert (
278
- "Running distributed training job with multi_worker_mirrored_strategy setup"
279
- in logs
280
- )
261
+ assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
281
262
assert f"num_devices = 1, group_size = { instance_count } " in logs
282
263
raise NotImplementedError ("Check model saving" )
283
264
@@ -336,12 +317,8 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
336
317
key_prefix = "scriptmode/distributed_mnist" ,
337
318
)
338
319
339
- with tests .integ .timeout .timeout (
340
- minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES
341
- ):
342
- estimator .fit (
343
- inputs = inputs , job_name = unique_name_from_base ("test-tf-sm-distributed" )
344
- )
320
+ with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
321
+ estimator .fit (inputs = inputs , job_name = unique_name_from_base ("test-tf-sm-distributed" ))
345
322
assert_s3_file_patterns_exist (
346
323
sagemaker_session ,
347
324
estimator .model_dir ,
@@ -350,9 +327,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
350
327
351
328
352
329
@pytest .mark .slow_test
353
- def test_mnist_async (
354
- sagemaker_session , cpu_instance_type , tf_full_version , tf_full_py_version
355
- ):
330
+ def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version , tf_full_py_version ):
356
331
if tf_full_version == "2.7.0" :
357
332
tf_full_version = "2.7"
358
333
@@ -370,18 +345,14 @@ def test_mnist_async(
370
345
inputs = estimator .sagemaker_session .upload_data (
371
346
path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "scriptmode/mnist"
372
347
)
373
- estimator .fit (
374
- inputs = inputs , wait = False , job_name = unique_name_from_base ("test-tf-sm-async" )
375
- )
348
+ estimator .fit (inputs = inputs , wait = False , job_name = unique_name_from_base ("test-tf-sm-async" ))
376
349
training_job_name = estimator .latest_training_job .name
377
350
time .sleep (20 )
378
351
endpoint_name = training_job_name
379
352
_assert_training_job_tags_match (
380
353
sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
381
354
)
382
- with tests .integ .timeout .timeout_and_delete_endpoint_by_name (
383
- endpoint_name , sagemaker_session
384
- ):
355
+ with tests .integ .timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
385
356
estimator = TensorFlow .attach (
386
357
training_job_name = training_job_name , sagemaker_session = sagemaker_session
387
358
)
@@ -399,9 +370,7 @@ def test_mnist_async(
399
370
sagemaker_session .sagemaker_client , predictor .endpoint_name , TAGS
400
371
)
401
372
_assert_model_tags_match (sagemaker_session .sagemaker_client , model_name , TAGS )
402
- _assert_model_name_match (
403
- sagemaker_session .sagemaker_client , endpoint_name , model_name
404
- )
373
+ _assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
405
374
406
375
407
376
def test_deploy_with_input_handlers (
@@ -487,9 +456,7 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):
487
456
488
457
489
458
def _assert_endpoint_tags_match (sagemaker_client , endpoint_name , tags ):
490
- endpoint_description = sagemaker_client .describe_endpoint (
491
- EndpointName = endpoint_name
492
- )
459
+ endpoint_description = sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
493
460
494
461
_assert_tags_match (sagemaker_client , endpoint_description ["EndpointArn" ], tags )
495
462
@@ -498,15 +465,11 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
498
465
training_job_description = sagemaker_client .describe_training_job (
499
466
TrainingJobName = training_job_name
500
467
)
501
- _assert_tags_match (
502
- sagemaker_client , training_job_description ["TrainingJobArn" ], tags
503
- )
468
+ _assert_tags_match (sagemaker_client , training_job_description ["TrainingJobArn" ], tags )
504
469
505
470
506
471
def _assert_model_name_match (sagemaker_client , endpoint_config_name , model_name ):
507
472
endpoint_config_description = sagemaker_client .describe_endpoint_config (
508
473
EndpointConfigName = endpoint_config_name
509
474
)
510
- assert (
511
- model_name == endpoint_config_description ["ProductionVariants" ][0 ]["ModelName" ]
512
- )
475
+ assert model_name == endpoint_config_description ["ProductionVariants" ][0 ]["ModelName" ]
0 commit comments