Skip to content

Commit fedcb99

Browse files
author
BruceZhang@eitug
committed
edit pytorch_training_py_version in test_pytorch_compiler
1 parent ac63570 commit fedcb99

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

tests/integ/test_training_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_pytorch(
150150
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
151151

152152
hf = PyTorch(
153-
py_version="py38",
153+
py_version="py39",
154154
source_dir=os.path.join(DATA_DIR, "huggingface_byoc"),
155155
entry_point="run_glue.py",
156156
role="SageMakerRole",

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,15 @@ def test_unsupported_distribution(
303303
@patch("time.time", return_value=TIME)
304304
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
305305
def test_pytorchxla_distribution(
306-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
306+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
307307
):
308308
if Version(pytorch_training_compiler_version) < Version("1.12"):
309309
pytest.skip("This test is intended for PyTorch 1.12 and above")
310310
compiler_config = TrainingCompilerConfig()
311311
instance_type = f"ml.{instance_class}.xlarge"
312312

313313
pt = PyTorch(
314-
py_version="py38",
314+
py_version=pytorch_training_py_version,
315315
entry_point=SCRIPT_PATH,
316316
role=ROLE,
317317
sagemaker_session=sagemaker_session,
@@ -357,13 +357,13 @@ def test_pytorchxla_distribution(
357357
@patch("time.time", return_value=TIME)
358358
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
359359
def test_default_compiler_config(
360-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
360+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
361361
):
362362
compiler_config = TrainingCompilerConfig()
363363
instance_type = f"ml.{instance_class}.xlarge"
364364

365365
pt = PyTorch(
366-
py_version="py38",
366+
py_version=pytorch_training_py_version,
367367
entry_point=SCRIPT_PATH,
368368
role=ROLE,
369369
sagemaker_session=sagemaker_session,
@@ -406,12 +406,12 @@ def test_default_compiler_config(
406406
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
407407
@patch("time.time", return_value=TIME)
408408
def test_debug_compiler_config(
409-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
409+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
410410
):
411411
compiler_config = TrainingCompilerConfig(debug=True)
412412

413413
pt = PyTorch(
414-
py_version="py38",
414+
py_version=pytorch_training_py_version,
415415
entry_point=SCRIPT_PATH,
416416
role=ROLE,
417417
sagemaker_session=sagemaker_session,
@@ -454,12 +454,12 @@ def test_debug_compiler_config(
454454
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
455455
@patch("time.time", return_value=TIME)
456456
def test_disable_compiler_config(
457-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
457+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
458458
):
459459
compiler_config = TrainingCompilerConfig(enabled=False)
460460

461461
pt = PyTorch(
462-
py_version="py38",
462+
py_version=pytorch_training_py_version,
463463
entry_point=SCRIPT_PATH,
464464
role=ROLE,
465465
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)