@@ -31,8 +31,42 @@ def gpu_instance_type(request):
31
31
return "ml.p3.2xlarge"
32
32
33
33
34
+ @pytest .fixture (scope = "module" )
35
+ def imagenet_val_set (request , sagemaker_session , tmpdir_factory ):
36
+ """
37
+ Copies the dataset from the bucket it's hosted in to the local bucket in the test region
38
+ """
39
+ local_path = tmpdir_factory .mktemp ("trcomp_imagenet_val_set" )
40
+ sagemaker_session .download_data (
41
+ path = local_path ,
42
+ bucket = "collection-of-ml-datasets" ,
43
+ key_prefix = "Imagenet/TFRecords/validation" ,
44
+ )
45
+ train_input = sagemaker_session .upload_data (
46
+ path = local_path ,
47
+ key_prefix = "integ-test-data/trcomp/tensorflow/imagenet/val" ,
48
+ )
49
+ return train_input
50
+
51
+
52
+ @pytest .fixture (scope = "module" )
53
+ def huggingface_dummy_dataset (request , sagemaker_session ):
54
+ """
55
+ Copies the dataset from the local disk to the local bucket in the test region
56
+ """
57
+ data_path = os .path .join (DATA_DIR , "huggingface" )
58
+ train_input = sagemaker_session .upload_data (
59
+ path = os .path .join (data_path , "train" ),
60
+ key_prefix = "integ-test-data/trcomp/huggingface/dummy/train" ,
61
+ )
62
+ return train_input
63
+
64
+
34
65
@pytest .fixture (scope = "module" , autouse = True )
35
66
def skip_if_incompatible (request ):
67
+ """
68
+ These tests are for training compiler enabled images/estimators only.
69
+ """
36
70
if integ .test_region () not in integ .TRAINING_COMPILER_SUPPORTED_REGIONS :
37
71
pytest .skip ("SageMaker Training Compiler is not supported in this region" )
38
72
if integ .test_region () in integ .TRAINING_NO_P3_REGIONS :
@@ -45,7 +79,11 @@ def test_huggingface_pytorch(
45
79
gpu_instance_type ,
46
80
huggingface_training_compiler_latest_version ,
47
81
huggingface_training_compiler_pytorch_latest_version ,
82
+ huggingface_dummy_dataset ,
48
83
):
84
+ """
85
+ Test the HuggingFace estimator with PyTorch
86
+ """
49
87
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
50
88
data_path = os .path .join (DATA_DIR , "huggingface" )
51
89
@@ -73,12 +111,7 @@ def test_huggingface_pytorch(
73
111
compiler_config = HFTrainingCompilerConfig (),
74
112
)
75
113
76
- train_input = hf .sagemaker_session .upload_data (
77
- path = os .path .join (data_path , "train" ),
78
- key_prefix = "integ-test-data/huggingface/train" ,
79
- )
80
-
81
- hf .fit (train_input )
114
+ hf .fit (huggingface_dummy_dataset )
82
115
83
116
84
117
@pytest .mark .release
@@ -87,7 +120,11 @@ def test_huggingface_tensorflow(
87
120
gpu_instance_type ,
88
121
huggingface_training_compiler_latest_version ,
89
122
huggingface_training_compiler_tensorflow_latest_version ,
123
+ huggingface_dummy_dataset ,
90
124
):
125
+ """
126
+ Test the HuggingFace estimator with TensorFlow
127
+ """
91
128
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
92
129
data_path = os .path .join (DATA_DIR , "huggingface" )
93
130
@@ -112,19 +149,19 @@ def test_huggingface_tensorflow(
112
149
compiler_config = HFTrainingCompilerConfig (),
113
150
)
114
151
115
- train_input = hf .sagemaker_session .upload_data (
116
- path = os .path .join (data_path , "train" ), key_prefix = "integ-test-data/huggingface/train"
117
- )
118
-
119
- hf .fit (train_input )
152
+ hf .fit (huggingface_dummy_dataset )
120
153
121
154
122
155
@pytest .mark .release
123
156
def test_tensorflow (
124
157
sagemaker_session ,
125
158
gpu_instance_type ,
126
159
tensorflow_training_latest_version ,
160
+ imagenet_val_set ,
127
161
):
162
+ """
163
+ Test the TensorFlow estimator
164
+ """
128
165
if version .parse (tensorflow_training_latest_version ) < version .parse ("2.9" ):
129
166
pytest .skip ("Training Compiler only supports TF >= 2.9" )
130
167
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
@@ -173,7 +210,7 @@ def test_tensorflow(
173
210
)
174
211
175
212
tf .fit (
176
- inputs = "s3://collection-of-ml-datasets/Imagenet/TFRecords/validation" ,
213
+ inputs = imagenet_val_set ,
177
214
logs = True ,
178
215
wait = True ,
179
216
)
0 commit comments