Skip to content

Commit 5da90ca

Browse files
committed
Increasing dataset size for MWMS test.
1 parent 895324b commit 5da90ca

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

test/integration/sagemaker/test_multi_worker_mirrored.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ def test_keras_example(
4949
def test_tf_model_garden(
5050
sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys
5151
):
52-
epochs = 10
52+
epochs = 1
5353
global_batch_size = 64
54-
train_steps = int(1024 * epochs / global_batch_size)
55-
steps_per_loop = train_steps // 10
54+
train_steps = int(10**6 * epochs / global_batch_size)
55+
steps_per_loop = train_steps // 100
5656
overrides = (
5757
f"runtime.enable_xla=False,"
5858
f"runtime.num_gpus=1,"
5959
f"runtime.distribution_strategy=multi_worker_mirrored,"
6060
f"runtime.mixed_precision_dtype=float16,"
6161
f"task.train_data.global_batch_size={global_batch_size},"
62-
f"task.train_data.input_path=/opt/ml/input/data/training/validation*,"
62+
f"task.train_data.input_path=/opt/ml/input/data/training/train*,"
6363
f"task.train_data.cache=True,"
6464
f"trainer.train_steps={train_steps},"
6565
f"trainer.steps_per_loop={steps_per_loop},"
@@ -87,11 +87,14 @@ def test_tf_model_garden(
8787
"model_dir": "/opt/ml/model",
8888
"params_override": overrides,
8989
},
90-
max_run=60 * 60 * 1, # 1 hour
90+
environment={
91+
'NCCL_DEBUG': 'INFO',
92+
},
93+
max_run=60 * 60 * 12, # 12 hours
9194
role="SageMakerRole",
9295
)
9396
estimator.fit(
94-
inputs="s3://collection-of-ml-datasets/Imagenet/TFRecords/validation",
97+
inputs="s3://collection-of-ml-datasets/Imagenet/TFRecords/train",
9598
job_name=unique_name_from_base("test-tf-mwms"),
9699
)
97100
captured = capsys.readouterr()

0 commit comments

Comments
 (0)