Skip to content

Commit ebfcac8

Browse files
committed
Fixing MWMS tests for TF2
1 parent d0036eb commit ebfcac8

File tree

1 file changed

+71
-28
lines changed

1 file changed

+71
-28
lines changed

tests/integ/test_tf.py

+71 -28
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,9 @@
3333

3434
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data")
3535
MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist")
36-
TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler")
36+
TFS_RESOURCE_PATH = os.path.join(
37+
RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler"
38+
)
3739

3840
SCRIPT = "mnist.py"
3941
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
@@ -96,7 +98,9 @@ def test_mnist_with_checkpoint_config(
9698
sagemaker_session=sagemaker_session,
9799
framework_version=tensorflow_training_latest_version,
98100
py_version=tensorflow_training_latest_py_version,
99-
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
101+
metric_definitions=[
102+
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
103+
],
100104
checkpoint_s3_uri=checkpoint_s3_uri,
101105
checkpoint_local_path=checkpoint_local_path,
102106
environment=ENV_INPUT,
@@ -108,7 +112,9 @@ def test_mnist_with_checkpoint_config(
108112
)
109113

110114
training_job_name = unique_name_from_base("test-tf-sm-mnist")
111-
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
115+
with tests.integ.timeout.timeout(
116+
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
117+
):
112118
estimator.fit(inputs=inputs, job_name=training_job_name)
113119
assert_s3_file_patterns_exist(
114120
sagemaker_session,
@@ -122,13 +128,15 @@ def test_mnist_with_checkpoint_config(
122128
"S3Uri": checkpoint_s3_uri,
123129
"LocalPath": checkpoint_local_path,
124130
}
125-
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
126-
TrainingJobName=training_job_name
127-
)["CheckpointConfig"]
131+
actual_training_checkpoint_config = (
132+
sagemaker_session.sagemaker_client.describe_training_job(
133+
TrainingJobName=training_job_name
134+
)["CheckpointConfig"]
135+
)
128136
actual_training_environment_variable_config = (
129-
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
130-
"Environment"
131-
]
137+
sagemaker_session.sagemaker_client.describe_training_job(
138+
TrainingJobName=training_job_name
139+
)["Environment"]
132140
)
133141

134142
expected_retry_strategy = {
@@ -143,7 +151,10 @@ def test_mnist_with_checkpoint_config(
143151

144152

145153
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):
146-
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
154+
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (
155+
bucket_with_kms,
156+
kms_key,
157+
):
147158
output_path = os.path.join(
148159
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
149160
)
@@ -164,16 +175,22 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
164175
)
165176

166177
inputs = estimator.sagemaker_session.upload_data(
167-
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
178+
path=os.path.join(MNIST_RESOURCE_PATH, "data"),
179+
key_prefix="scriptmode/mnist",
168180
)
169181

170-
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
182+
with tests.integ.timeout.timeout(
183+
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
184+
):
171185
estimator.fit(
172-
inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
186+
inputs=inputs,
187+
job_name=unique_name_from_base("test-server-side-encryption"),
173188
)
174189

175190
endpoint_name = unique_name_from_base("test-server-side-encryption")
176-
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
191+
with timeout.timeout_and_delete_endpoint_by_name(
192+
endpoint_name, sagemaker_session
193+
):
177194
estimator.deploy(
178195
initial_instance_count=1,
179196
instance_type="ml.c5.xlarge",
@@ -198,7 +215,7 @@ def test_mwms_gpu(
198215
imagenet_train_subset,
199216
**kwargs,
200217
):
201-
instance_count=2
218+
instance_count = 2
202219
epochs = 1
203220
global_batch_size = 64
204221
train_steps = int(10**5 * epochs / global_batch_size)
@@ -248,12 +265,19 @@ def test_mwms_gpu(
248265
disable_profiler=True,
249266
)
250267

251-
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
252-
estimator.fit(inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms"))
268+
with tests.integ.timeout.timeout(
269+
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
270+
):
271+
estimator.fit(
272+
inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms")
273+
)
253274

254275
captured = capsys.readouterr()
255276
logs = captured.out + captured.err
256-
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
277+
assert (
278+
"Running distributed training job with multi_worker_mirrored_strategy setup"
279+
in logs
280+
)
257281
assert f"num_devices = 1, group_size = {instance_count}" in logs
258282
raise NotImplementedError("Check model saving")
259283

@@ -308,11 +332,16 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
308332
disable_profiler=True,
309333
)
310334
inputs = estimator.sagemaker_session.upload_data(
311-
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/distributed_mnist"
335+
path=os.path.join(MNIST_RESOURCE_PATH, "data"),
336+
key_prefix="scriptmode/distributed_mnist",
312337
)
313338

314-
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
315-
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
339+
with tests.integ.timeout.timeout(
340+
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
341+
):
342+
estimator.fit(
343+
inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")
344+
)
316345
assert_s3_file_patterns_exist(
317346
sagemaker_session,
318347
estimator.model_dir,
@@ -321,7 +350,9 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
321350

322351

323352
@pytest.mark.slow_test
324-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
353+
def test_mnist_async(
354+
sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
355+
):
325356
if tf_full_version == "2.7.0":
326357
tf_full_version = "2.7"
327358

@@ -339,14 +370,18 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_f
339370
inputs = estimator.sagemaker_session.upload_data(
340371
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
341372
)
342-
estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async"))
373+
estimator.fit(
374+
inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async")
375+
)
343376
training_job_name = estimator.latest_training_job.name
344377
time.sleep(20)
345378
endpoint_name = training_job_name
346379
_assert_training_job_tags_match(
347380
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
348381
)
349-
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
382+
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
383+
endpoint_name, sagemaker_session
384+
):
350385
estimator = TensorFlow.attach(
351386
training_job_name=training_job_name, sagemaker_session=sagemaker_session
352387
)
@@ -364,7 +399,9 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_f
364399
sagemaker_session.sagemaker_client, predictor.endpoint_name, TAGS
365400
)
366401
_assert_model_tags_match(sagemaker_session.sagemaker_client, model_name, TAGS)
367-
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
402+
_assert_model_name_match(
403+
sagemaker_session.sagemaker_client, endpoint_name, model_name
404+
)
368405

369406

370407
def test_deploy_with_input_handlers(
@@ -450,7 +487,9 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):
450487

451488

452489
def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
453-
endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
490+
endpoint_description = sagemaker_client.describe_endpoint(
491+
EndpointName=endpoint_name
492+
)
454493

455494
_assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags)
456495

@@ -459,11 +498,15 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
459498
training_job_description = sagemaker_client.describe_training_job(
460499
TrainingJobName=training_job_name
461500
)
462-
_assert_tags_match(sagemaker_client, training_job_description["TrainingJobArn"], tags)
501+
_assert_tags_match(
502+
sagemaker_client, training_job_description["TrainingJobArn"], tags
503+
)
463504

464505

465506
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
466507
endpoint_config_description = sagemaker_client.describe_endpoint_config(
467508
EndpointConfigName=endpoint_config_name
468509
)
469-
assert model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]
510+
assert (
511+
model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]
512+
)

0 commit comments

Comments
 (0)