@@ -49,17 +49,17 @@ def test_keras_example(
49
49
def test_tf_model_garden (
50
50
sagemaker_session , instance_type , image_uri , tmpdir , framework_version , capsys
51
51
):
52
- epochs = 10
52
+ epochs = 1
53
53
global_batch_size = 64
54
- train_steps = int (1024 * epochs / global_batch_size )
55
- steps_per_loop = train_steps // 10
54
+ train_steps = int (10 ** 6 * epochs / global_batch_size )
55
+ steps_per_loop = train_steps // 100
56
56
overrides = (
57
57
f"runtime.enable_xla=False,"
58
58
f"runtime.num_gpus=1,"
59
59
f"runtime.distribution_strategy=multi_worker_mirrored,"
60
60
f"runtime.mixed_precision_dtype=float16,"
61
61
f"task.train_data.global_batch_size={ global_batch_size } ,"
62
- f"task.train_data.input_path=/opt/ml/input/data/training/validation *,"
62
+ f"task.train_data.input_path=/opt/ml/input/data/training/train *,"
63
63
f"task.train_data.cache=True,"
64
64
f"trainer.train_steps={ train_steps } ,"
65
65
f"trainer.steps_per_loop={ steps_per_loop } ,"
@@ -87,11 +87,14 @@ def test_tf_model_garden(
87
87
"model_dir" : "/opt/ml/model" ,
88
88
"params_override" : overrides ,
89
89
},
90
- max_run = 60 * 60 * 1 , # 1 hour
90
+ environment = {
91
+ 'NCCL_DEBUG' : 'INFO' ,
92
+ },
93
+ max_run = 60 * 60 * 12 , # 1 hour
91
94
role = "SageMakerRole" ,
92
95
)
93
96
estimator .fit (
94
- inputs = "s3://collection-of-ml-datasets/Imagenet/TFRecords/validation " ,
97
+ inputs = "s3://collection-of-ml-datasets/Imagenet/TFRecords/train " ,
95
98
job_name = unique_name_from_base ("test-tf-mwms" ),
96
99
)
97
100
captured = capsys .readouterr ()
0 commit comments