Skip to content

Commit d0036eb

Browse files
committed
Fixing MWMS tests for TF2
1 parent 5e725e8 commit d0036eb

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

tests/integ/test_tf.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
182182
)
183183

184184

185+
@pytest.mark.slow_test
185186
@pytest.mark.release
186187
@pytest.mark.skipif(
187188
tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
@@ -197,9 +198,10 @@ def test_mwms_gpu(
197198
imagenet_train_subset,
198199
**kwargs,
199200
):
201+
instance_count=2
200202
epochs = 1
201203
global_batch_size = 64
202-
train_steps = int(10**4 * epochs / global_batch_size)
204+
train_steps = int(10**5 * epochs / global_batch_size)
203205
steps_per_loop = train_steps // 10
204206
overrides = (
205207
f"runtime.enable_xla=False,"
@@ -225,7 +227,7 @@ def test_mwms_gpu(
225227
entry_point="official/vision/train.py",
226228
model_dir=False,
227229
instance_type=kwargs["instance_type"],
228-
instance_count=2,
230+
instance_count=instance_count,
229231
framework_version=tensorflow_training_latest_version,
230232
py_version=tensorflow_training_latest_py_version,
231233
distribution=MWMS_DISTRIBUTION,
@@ -252,6 +254,7 @@ def test_mwms_gpu(
252254
captured = capsys.readouterr()
253255
logs = captured.out + captured.err
254256
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
257+
assert f"num_devices = 1, group_size = {instance_count}" in logs
255258
raise NotImplementedError("Check model saving")
256259

257260

tests/unit/sagemaker/tensorflow/test_estimator.py

+10 -5
Original file line number | Diff line number | Diff line change
@@ -547,15 +547,20 @@ def test_fit_mwms(time, strftime, sagemaker_session):
547547

548548
expected_train_args = _create_train_job("2.9.1", py_version="py39")
549549
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
550-
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True)
551550
expected_train_args[
552551
"image_uri"
553552
] = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}"
554553
expected_train_args["job_name"] = f"tensorflow-training-{TIMESTAMP}"
555-
expected_train_args["hyperparameters"]["sagemaker_job_name"] = expected_train_args["job_name"]
556-
expected_train_args["hyperparameters"][
557-
"sagemaker_submit_directory"
558-
] = f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/source/sourcedir.tar.gz"
554+
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True)
555+
expected_train_args["hyperparameters"]["sagemaker_job_name"] = json.dumps(
556+
expected_train_args["job_name"]
557+
)
558+
expected_train_args["hyperparameters"]["sagemaker_submit_directory"] = json.dumps(
559+
f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/source/sourcedir.tar.gz"
560+
)
561+
expected_train_args["hyperparameters"]["model_dir"] = json.dumps(
562+
f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/model"
563+
)
559564

560565
actual_train_args = sagemaker_session.method_calls[0][2]
561566
assert actual_train_args == expected_train_args

0 commit comments

Comments
 (0)