54
54
55
55
56
56
@pytest .fixture (scope = "module" , autouse = True )
57
- def skip_if_incompatible (tensorflow_training_version ):
57
+ def skip_if_incompatible (tensorflow_training_version , request ):
58
58
if version .parse (tensorflow_training_version ) < version .parse ("2.9" ):
59
59
pytest .skip ("Training Compiler only supports TF >= 2.9" )
60
60
@@ -158,6 +158,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
158
158
159
159
class TestUnsupportedConfig :
160
160
def test_cpu_instance (
161
+ self ,
161
162
cpu_instance_type ,
162
163
tensorflow_training_version ,
163
164
tensorflow_training_py_version ,
@@ -176,6 +177,7 @@ def test_cpu_instance(
176
177
177
178
@pytest .mark .parametrize ("unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES )
178
179
def test_gpu_instance (
180
+ self ,
179
181
unsupported_gpu_instance_class ,
180
182
tensorflow_training_version ,
181
183
tensorflow_training_py_version ,
@@ -193,6 +195,7 @@ def test_gpu_instance(
193
195
).fit ()
194
196
195
197
def test_framework_version (
198
+ self ,
196
199
tensorflow_training_py_version ,
197
200
):
198
201
with pytest .raises (ValueError ):
@@ -208,6 +211,7 @@ def test_framework_version(
208
211
).fit ()
209
212
210
213
def test_python_2 (
214
+ self ,
211
215
tensorflow_training_version ,
212
216
):
213
217
with pytest .raises (ValueError ):
@@ -230,6 +234,7 @@ def test_python_2(
230
234
class TestTrainingCompilerConfig :
231
235
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
232
236
def test_default (
237
+ self ,
233
238
time ,
234
239
name_from_base ,
235
240
sagemaker_session ,
@@ -282,6 +287,7 @@ def test_default(
282
287
), f"{ json .dumps (actual_train_args , indent = 2 )} != { json .dumps (expected_train_args , indent = 2 )} "
283
288
284
289
def test_debug_compiler_config (
290
+ self ,
285
291
time ,
286
292
name_from_base ,
287
293
sagemaker_session ,
@@ -332,6 +338,7 @@ def test_debug_compiler_config(
332
338
), f"{ json .dumps (actual_train_args , indent = 2 )} != { json .dumps (expected_train_args , indent = 2 )} "
333
339
334
340
def test_disable_compiler_config (
341
+ self ,
335
342
time ,
336
343
name_from_base ,
337
344
sagemaker_session ,
0 commit comments