@@ -303,15 +303,15 @@ def test_unsupported_distribution(
@patch("time.time", return_value=TIME)
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
def test_pytorchxla_distribution(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
):
    if Version(pytorch_training_compiler_version) < Version("1.12"):
        pytest.skip("This test is intended for PyTorch 1.12 and above")
    compiler_config = TrainingCompilerConfig()
    instance_type = f"ml.{instance_class}.xlarge"

    pt = PyTorch(
-        py_version="py38",
+        py_version=pytorch_training_py_version,
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
@@ -357,13 +357,13 @@ def test_pytorchxla_distribution(
@patch("time.time", return_value=TIME)
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
def test_default_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class, pytorch_training_py_version
):
    compiler_config = TrainingCompilerConfig()
    instance_type = f"ml.{instance_class}.xlarge"

    pt = PyTorch(
-        py_version="py38",
+        py_version=pytorch_training_py_version,
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
@@ -406,12 +406,12 @@ def test_default_compiler_config(
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
@patch("time.time", return_value=TIME)
def test_debug_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
):
    compiler_config = TrainingCompilerConfig(debug=True)

    pt = PyTorch(
-        py_version="py38",
+        py_version=pytorch_training_py_version,
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
@@ -454,12 +454,12 @@ def test_debug_compiler_config(
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
@patch("time.time", return_value=TIME)
def test_disable_compiler_config(
-    time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, pytorch_training_py_version
):
    compiler_config = TrainingCompilerConfig(enabled=False)

    pt = PyTorch(
-        py_version="py38",
+        py_version=pytorch_training_py_version,
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
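Note: these hunks replace the hardcoded py_version="py38" with a pytorch_training_py_version pytest fixture that the test suite's conftest.py is expected to provide. A minimal sketch of such a fixture is shown below; the version-to-Python-tag mapping is hypothetical and only illustrates how the fixture could resolve a py tag per framework version.

# Hypothetical conftest.py sketch -- the real sagemaker-python-sdk conftest
# may resolve the Python tag differently (e.g., from image config data).
import pytest
from packaging.version import Version


@pytest.fixture()
def pytorch_training_py_version(pytorch_training_compiler_version):
    # Assumed mapping: newer Training Compiler containers ship Python 3.9,
    # older ones Python 3.8. pytorch_training_compiler_version is another
    # fixture already used by these tests.
    if Version(pytorch_training_compiler_version) >= Version("1.13"):
        return "py39"
    return "py38"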