Skip to content

Commit 74ef7cd

Browse files
committed
Fixing MWMS tests for TF2
1 parent ebfcac8 commit 74ef7cd

File tree

1 file changed

+23
-60
lines changed

1 file changed

+23
-60
lines changed

tests/integ/test_tf.py

+23-60
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333

3434
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data")
3535
MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist")
36-
TFS_RESOURCE_PATH = os.path.join(
37-
RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler"
38-
)
36+
TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler")
3937

4038
SCRIPT = "mnist.py"
4139
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
@@ -98,9 +96,7 @@ def test_mnist_with_checkpoint_config(
9896
sagemaker_session=sagemaker_session,
9997
framework_version=tensorflow_training_latest_version,
10098
py_version=tensorflow_training_latest_py_version,
101-
metric_definitions=[
102-
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
103-
],
99+
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
104100
checkpoint_s3_uri=checkpoint_s3_uri,
105101
checkpoint_local_path=checkpoint_local_path,
106102
environment=ENV_INPUT,
@@ -112,9 +108,7 @@ def test_mnist_with_checkpoint_config(
112108
)
113109

114110
training_job_name = unique_name_from_base("test-tf-sm-mnist")
115-
with tests.integ.timeout.timeout(
116-
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
117-
):
111+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
118112
estimator.fit(inputs=inputs, job_name=training_job_name)
119113
assert_s3_file_patterns_exist(
120114
sagemaker_session,
@@ -128,15 +122,13 @@ def test_mnist_with_checkpoint_config(
128122
"S3Uri": checkpoint_s3_uri,
129123
"LocalPath": checkpoint_local_path,
130124
}
131-
actual_training_checkpoint_config = (
132-
sagemaker_session.sagemaker_client.describe_training_job(
133-
TrainingJobName=training_job_name
134-
)["CheckpointConfig"]
135-
)
125+
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
126+
TrainingJobName=training_job_name
127+
)["CheckpointConfig"]
136128
actual_training_environment_variable_config = (
137-
sagemaker_session.sagemaker_client.describe_training_job(
138-
TrainingJobName=training_job_name
139-
)["Environment"]
129+
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
130+
"Environment"
131+
]
140132
)
141133

142134
expected_retry_strategy = {
@@ -179,18 +171,14 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
179171
key_prefix="scriptmode/mnist",
180172
)
181173

182-
with tests.integ.timeout.timeout(
183-
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
184-
):
174+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
185175
estimator.fit(
186176
inputs=inputs,
187177
job_name=unique_name_from_base("test-server-side-encryption"),
188178
)
189179

190180
endpoint_name = unique_name_from_base("test-server-side-encryption")
191-
with timeout.timeout_and_delete_endpoint_by_name(
192-
endpoint_name, sagemaker_session
193-
):
181+
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
194182
estimator.deploy(
195183
initial_instance_count=1,
196184
instance_type="ml.c5.xlarge",
@@ -265,19 +253,12 @@ def test_mwms_gpu(
265253
disable_profiler=True,
266254
)
267255

268-
with tests.integ.timeout.timeout(
269-
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
270-
):
271-
estimator.fit(
272-
inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms")
273-
)
256+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
257+
estimator.fit(inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms"))
274258

275259
captured = capsys.readouterr()
276260
logs = captured.out + captured.err
277-
assert (
278-
"Running distributed training job with multi_worker_mirrored_strategy setup"
279-
in logs
280-
)
261+
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
281262
assert f"num_devices = 1, group_size = {instance_count}" in logs
282263
raise NotImplementedError("Check model saving")
283264

@@ -336,12 +317,8 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
336317
key_prefix="scriptmode/distributed_mnist",
337318
)
338319

339-
with tests.integ.timeout.timeout(
340-
minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES
341-
):
342-
estimator.fit(
343-
inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")
344-
)
320+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
321+
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
345322
assert_s3_file_patterns_exist(
346323
sagemaker_session,
347324
estimator.model_dir,
@@ -350,9 +327,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
350327

351328

352329
@pytest.mark.slow_test
353-
def test_mnist_async(
354-
sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
355-
):
330+
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
356331
if tf_full_version == "2.7.0":
357332
tf_full_version = "2.7"
358333

@@ -370,18 +345,14 @@ def test_mnist_async(
370345
inputs = estimator.sagemaker_session.upload_data(
371346
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
372347
)
373-
estimator.fit(
374-
inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async")
375-
)
348+
estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async"))
376349
training_job_name = estimator.latest_training_job.name
377350
time.sleep(20)
378351
endpoint_name = training_job_name
379352
_assert_training_job_tags_match(
380353
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
381354
)
382-
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
383-
endpoint_name, sagemaker_session
384-
):
355+
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
385356
estimator = TensorFlow.attach(
386357
training_job_name=training_job_name, sagemaker_session=sagemaker_session
387358
)
@@ -399,9 +370,7 @@ def test_mnist_async(
399370
sagemaker_session.sagemaker_client, predictor.endpoint_name, TAGS
400371
)
401372
_assert_model_tags_match(sagemaker_session.sagemaker_client, model_name, TAGS)
402-
_assert_model_name_match(
403-
sagemaker_session.sagemaker_client, endpoint_name, model_name
404-
)
373+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
405374

406375

407376
def test_deploy_with_input_handlers(
@@ -487,9 +456,7 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):
487456

488457

489458
def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
490-
endpoint_description = sagemaker_client.describe_endpoint(
491-
EndpointName=endpoint_name
492-
)
459+
endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
493460

494461
_assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags)
495462

@@ -498,15 +465,11 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
498465
training_job_description = sagemaker_client.describe_training_job(
499466
TrainingJobName=training_job_name
500467
)
501-
_assert_tags_match(
502-
sagemaker_client, training_job_description["TrainingJobArn"], tags
503-
)
468+
_assert_tags_match(sagemaker_client, training_job_description["TrainingJobArn"], tags)
504469

505470

506471
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
507472
endpoint_config_description = sagemaker_client.describe_endpoint_config(
508473
EndpointConfigName=endpoint_config_name
509474
)
510-
assert (
511-
model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]
512-
)
475+
assert model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]

0 commit comments

Comments
 (0)