Skip to content

Commit cbac484

Browse files
committed
fix: syntax error in trcomp tests
1 parent befb150 commit cbac484

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555

5656
@pytest.fixture(scope="module", autouse=True)
57-
def skip_if_incompatible(tensorflow_training_version):
57+
def skip_if_incompatible(tensorflow_training_version, request):
5858
if version.parse(tensorflow_training_version) < version.parse("2.9"):
5959
pytest.skip("Training Compiler only supports TF >= 2.9")
6060

@@ -158,6 +158,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
158158

159159
class TestUnsupportedConfig:
160160
def test_cpu_instance(
161+
self,
161162
cpu_instance_type,
162163
tensorflow_training_version,
163164
tensorflow_training_py_version,
@@ -176,6 +177,7 @@ def test_cpu_instance(
176177

177178
@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
178179
def test_gpu_instance(
180+
self,
179181
unsupported_gpu_instance_class,
180182
tensorflow_training_version,
181183
tensorflow_training_py_version,
@@ -193,6 +195,7 @@ def test_gpu_instance(
193195
).fit()
194196

195197
def test_framework_version(
198+
self,
196199
tensorflow_training_py_version,
197200
):
198201
with pytest.raises(ValueError):
@@ -208,6 +211,7 @@ def test_framework_version(
208211
).fit()
209212

210213
def test_python_2(
214+
self,
211215
tensorflow_training_version,
212216
):
213217
with pytest.raises(ValueError):
@@ -230,6 +234,7 @@ def test_python_2(
230234
class TestTrainingCompilerConfig:
231235
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
232236
def test_default(
237+
self,
233238
time,
234239
name_from_base,
235240
sagemaker_session,
@@ -282,6 +287,7 @@ def test_default(
282287
), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
283288

284289
def test_debug_compiler_config(
290+
self,
285291
time,
286292
name_from_base,
287293
sagemaker_session,
@@ -332,6 +338,7 @@ def test_debug_compiler_config(
332338
), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
333339

334340
def test_disable_compiler_config(
341+
self,
335342
time,
336343
name_from_base,
337344
sagemaker_session,

0 commit comments

Comments
 (0)