@@ -182,6 +182,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
182
182
)
183
183
184
184
185
+ @pytest .mark .slow_test
185
186
@pytest .mark .release
186
187
@pytest .mark .skipif (
187
188
tests .integ .test_region () in tests .integ .TRAINING_NO_P2_REGIONS
@@ -197,9 +198,10 @@ def test_mwms_gpu(
197
198
imagenet_train_subset ,
198
199
** kwargs ,
199
200
):
201
+ instance_count = 2
200
202
epochs = 1
201
203
global_batch_size = 64
202
- train_steps = int (10 ** 4 * epochs / global_batch_size )
204
+ train_steps = int (10 ** 5 * epochs / global_batch_size )
203
205
steps_per_loop = train_steps // 10
204
206
overrides = (
205
207
f"runtime.enable_xla=False,"
@@ -225,7 +227,7 @@ def test_mwms_gpu(
225
227
entry_point = "official/vision/train.py" ,
226
228
model_dir = False ,
227
229
instance_type = kwargs ["instance_type" ],
228
- instance_count = 2 ,
230
+ instance_count = instance_count ,
229
231
framework_version = tensorflow_training_latest_version ,
230
232
py_version = tensorflow_training_latest_py_version ,
231
233
distribution = MWMS_DISTRIBUTION ,
@@ -252,6 +254,7 @@ def test_mwms_gpu(
252
254
captured = capsys .readouterr ()
253
255
logs = captured .out + captured .err
254
256
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
257
+ assert f"num_devices = 1, group_size = { instance_count } " in logs
255
258
raise NotImplementedError ("Check model saving" )
256
259
257
260
0 commit comments