Skip to content

Commit aa202cd

Browse files
author
Rohit Kumar Srivastava
committed
fixing black-check
1 parent 86f3594 commit aa202cd

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

tests/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def huggingface_training_compiler_tensorflow_version(huggingface_training_compil
248248
)[0]
249249

250250

251+
@pytest.fixture(scope="module")
252+
def huggingface_training_compiler_py_version(huggingface_training_compiler_tensorflow_version):
253+
return (
254+
"py37" if Version(huggingface_training_compiler_tensorflow_version) < Version("2.6") else "py38"
255+
)
256+
257+
251258
@pytest.fixture(scope="module")
252259
def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version):
253260
return (

tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py

+18-23
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,7 @@ def fixture_sagemaker_session():
7979

8080

8181
def _get_full_gpu_image_uri(
82-
version,
83-
base_framework_version,
84-
instance_type,
85-
training_compiler_config,
86-
py_version
82+
version, base_framework_version, instance_type, training_compiler_config, py_version
8783
):
8884
return image_uris.retrieve(
8985
"huggingface",
@@ -98,7 +94,9 @@ def _get_full_gpu_image_uri(
9894
)
9995

10096

101-
def _create_train_job(version, base_framework_version, instance_type, training_compiler_config, py_version):
97+
def _create_train_job(
98+
version, base_framework_version, instance_type, training_compiler_config, py_version
99+
):
102100
return {
103101
"image_uri": _get_full_gpu_image_uri(
104102
version, base_framework_version, instance_type, training_compiler_config, py_version
@@ -159,7 +157,7 @@ def _create_train_job(version, base_framework_version, instance_type, training_c
159157
def test_unsupported_BYOC(
160158
huggingface_training_compiler_version,
161159
huggingface_training_compiler_tensorflow_version,
162-
huggingface_training_compiler_py_version
160+
huggingface_training_compiler_py_version,
163161
):
164162
byoc = (
165163
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
@@ -186,7 +184,7 @@ def test_unsupported_cpu_instance(
186184
cpu_instance_type,
187185
huggingface_training_compiler_version,
188186
huggingface_training_compiler_tensorflow_version,
189-
huggingface_training_compiler_py_version
187+
huggingface_training_compiler_py_version,
190188
):
191189
with pytest.raises(ValueError):
192190
HuggingFace(
@@ -207,7 +205,7 @@ def test_unsupported_gpu_instance(
207205
unsupported_gpu_instance_class,
208206
huggingface_training_compiler_version,
209207
huggingface_training_compiler_tensorflow_version,
210-
huggingface_training_compiler_py_version
208+
huggingface_training_compiler_py_version,
211209
):
212210
with pytest.raises(ValueError):
213211
HuggingFace(
@@ -225,7 +223,7 @@ def test_unsupported_gpu_instance(
225223

226224
def test_unsupported_framework_version(
227225
huggingface_training_compiler_version,
228-
huggingface_training_compiler_py_version
226+
huggingface_training_compiler_py_version,
229227
):
230228
with pytest.raises(ValueError):
231229
HuggingFace(
@@ -245,7 +243,7 @@ def test_unsupported_framework_version(
245243

246244
def test_unsupported_framework_mxnet(
247245
huggingface_training_compiler_version,
248-
huggingface_training_compiler_py_version
246+
huggingface_training_compiler_py_version,
249247
):
250248
with pytest.raises(ValueError):
251249
HuggingFace(
@@ -263,7 +261,7 @@ def test_unsupported_framework_mxnet(
263261

264262
def test_unsupported_python_2(
265263
huggingface_training_compiler_version,
266-
huggingface_training_compiler_tensorflow_version
264+
huggingface_training_compiler_tensorflow_version,
267265
):
268266
with pytest.raises(ValueError):
269267
HuggingFace(
@@ -291,7 +289,7 @@ def test_default_compiler_config(
291289
huggingface_training_compiler_version,
292290
huggingface_training_compiler_tensorflow_version,
293291
instance_class,
294-
huggingface_training_compiler_py_version
292+
huggingface_training_compiler_py_version,
295293
):
296294
compiler_config = TrainingCompilerConfig()
297295
instance_type = f"ml.{instance_class}.xlarge"
@@ -323,7 +321,7 @@ def test_default_compiler_config(
323321
f"tensorflow{huggingface_training_compiler_tensorflow_version}",
324322
instance_type,
325323
compiler_config,
326-
huggingface_training_compiler_py_version
324+
huggingface_training_compiler_py_version,
327325
)
328326
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
329327
expected_train_args["enable_sagemaker_metrics"] = False
@@ -350,7 +348,7 @@ def test_debug_compiler_config(
350348
sagemaker_session,
351349
huggingface_training_compiler_version,
352350
huggingface_training_compiler_tensorflow_version,
353-
huggingface_training_compiler_py_version
351+
huggingface_training_compiler_py_version,
354352
):
355353
compiler_config = TrainingCompilerConfig(debug=True)
356354

@@ -381,7 +379,7 @@ def test_debug_compiler_config(
381379
f"tensorflow{huggingface_training_compiler_tensorflow_version}",
382380
INSTANCE_TYPE,
383381
compiler_config,
384-
huggingface_training_compiler_py_version
382+
huggingface_training_compiler_py_version,
385383
)
386384
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
387385
expected_train_args["enable_sagemaker_metrics"] = False
@@ -408,7 +406,7 @@ def test_disable_compiler_config(
408406
sagemaker_session,
409407
huggingface_training_compiler_version,
410408
huggingface_training_compiler_tensorflow_version,
411-
huggingface_training_compiler_py_version
409+
huggingface_training_compiler_py_version,
412410
):
413411
compiler_config = TrainingCompilerConfig(enabled=False)
414412

@@ -439,7 +437,7 @@ def test_disable_compiler_config(
439437
f"tensorflow{huggingface_training_compiler_tensorflow_version}",
440438
INSTANCE_TYPE,
441439
compiler_config,
442-
huggingface_training_compiler_py_version
440+
huggingface_training_compiler_py_version,
443441
)
444442
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
445443
expected_train_args["enable_sagemaker_metrics"] = False
@@ -460,10 +458,7 @@ def test_disable_compiler_config(
460458
["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)]
461459
)
462460
def test_attach(
463-
sagemaker_session,
464-
compiler_enabled,
465-
debug_enabled,
466-
huggingface_training_compiler_py_version
461+
sagemaker_session, compiler_enabled, debug_enabled, huggingface_training_compiler_py_version
467462
):
468463
training_image = (
469464
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
@@ -503,7 +498,7 @@ def test_attach(
503498

504499
estimator = HuggingFace.attach(training_job_name="hopper", sagemaker_session=sagemaker_session)
505500
assert estimator.latest_training_job.job_name == "hopper"
506-
assert estimator.py_version == "py38"
501+
assert estimator.py_version == huggingface_training_compiler_py_version
507502
assert estimator.framework_version == "4.17.0"
508503
assert estimator.tensorflow_version == "2.6.3"
509504
assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"

0 commit comments

Comments
 (0)