@@ -83,22 +83,26 @@ def fixture_sagemaker_session():
83
83
return session
84
84
85
85
86
def _get_full_gpu_image_uri(version, instance_type, training_compiler_config, py_version):
    """Build the expected ECR image URI for the pytorch-training-compiler container.

    Thin test helper that delegates to ``image_uris.retrieve`` with the
    training-compiler framework name and the module-level ``REGION`` constant.

    Args:
        version: Framework version string to resolve.
        instance_type: EC2 instance type (e.g. ``ml.p3.2xlarge``) used for
            accelerator/arch selection.
        training_compiler_config: ``TrainingCompilerConfig`` instance forwarded
            to the URI resolver.
        py_version: Python version tag (e.g. ``py38``) baked into the image URI.

    Returns:
        The fully qualified container image URI string.
    """
    # Collect the keyword arguments separately so the retrieve() call stays short.
    retrieve_kwargs = {
        "version": version,
        "py_version": py_version,
        "instance_type": instance_type,
        "image_scope": "training",
        # container_version is intentionally unset — the resolver picks it.
        "container_version": None,
        "training_compiler_config": training_compiler_config,
    }
    return image_uris.retrieve("pytorch-training-compiler", REGION, **retrieve_kwargs)
98
98
99
- def _create_train_job (version , instance_type , training_compiler_config , instance_count = 1 ):
99
+ def _create_train_job (
100
+ version , instance_type , training_compiler_config , py_version , instance_count = 1
101
+ ):
100
102
return {
101
- "image_uri" : _get_full_gpu_image_uri (version , instance_type , training_compiler_config ),
103
+ "image_uri" : _get_full_gpu_image_uri (
104
+ version , instance_type , training_compiler_config , py_version
105
+ ),
102
106
"input_mode" : "File" ,
103
107
"input_config" : [
104
108
{
@@ -303,15 +307,20 @@ def test_unsupported_distribution(
303
307
@patch ("time.time" , return_value = TIME )
304
308
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
305
309
def test_pytorchxla_distribution (
306
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
310
+ time ,
311
+ name_from_base ,
312
+ sagemaker_session ,
313
+ pytorch_training_compiler_version ,
314
+ instance_class ,
315
+ pytorch_training_compiler_py_version ,
307
316
):
308
317
if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
309
318
pytest .skip ("This test is intended for PyTorch 1.12 and above" )
310
319
compiler_config = TrainingCompilerConfig ()
311
320
instance_type = f"ml.{ instance_class } .xlarge"
312
321
313
322
pt = PyTorch (
314
- py_version = "py38" ,
323
+ py_version = pytorch_training_compiler_py_version ,
315
324
entry_point = SCRIPT_PATH ,
316
325
role = ROLE ,
317
326
sagemaker_session = sagemaker_session ,
@@ -333,7 +342,11 @@ def test_pytorchxla_distribution(
333
342
assert boto_call_names == ["resource" ]
334
343
335
344
expected_train_args = _create_train_job (
336
- pytorch_training_compiler_version , instance_type , compiler_config , instance_count = 2
345
+ pytorch_training_compiler_version ,
346
+ instance_type ,
347
+ compiler_config ,
348
+ pytorch_training_compiler_py_version ,
349
+ instance_count = 2 ,
337
350
)
338
351
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
339
352
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -357,13 +370,17 @@ def test_pytorchxla_distribution(
357
370
@patch ("time.time" , return_value = TIME )
358
371
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
359
372
def test_default_compiler_config (
360
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
373
+ time ,
374
+ name_from_base ,
375
+ sagemaker_session ,
376
+ pytorch_training_compiler_version ,
377
+ instance_class ,
378
+ pytorch_training_compiler_py_version ,
361
379
):
362
380
compiler_config = TrainingCompilerConfig ()
363
381
instance_type = f"ml.{ instance_class } .xlarge"
364
-
365
382
pt = PyTorch (
366
- py_version = "py38" ,
383
+ py_version = pytorch_training_compiler_py_version ,
367
384
entry_point = SCRIPT_PATH ,
368
385
role = ROLE ,
369
386
sagemaker_session = sagemaker_session ,
@@ -384,7 +401,10 @@ def test_default_compiler_config(
384
401
assert boto_call_names == ["resource" ]
385
402
386
403
expected_train_args = _create_train_job (
387
- pytorch_training_compiler_version , instance_type , compiler_config
404
+ pytorch_training_compiler_version ,
405
+ instance_type ,
406
+ compiler_config ,
407
+ pytorch_training_compiler_py_version ,
388
408
)
389
409
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
390
410
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -406,12 +426,16 @@ def test_default_compiler_config(
406
426
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
407
427
@patch ("time.time" , return_value = TIME )
408
428
def test_debug_compiler_config (
409
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version
429
+ time ,
430
+ name_from_base ,
431
+ sagemaker_session ,
432
+ pytorch_training_compiler_version ,
433
+ pytorch_training_compiler_py_version ,
410
434
):
411
435
compiler_config = TrainingCompilerConfig (debug = True )
412
436
413
437
pt = PyTorch (
414
- py_version = "py38" ,
438
+ py_version = pytorch_training_compiler_py_version ,
415
439
entry_point = SCRIPT_PATH ,
416
440
role = ROLE ,
417
441
sagemaker_session = sagemaker_session ,
@@ -432,7 +456,10 @@ def test_debug_compiler_config(
432
456
assert boto_call_names == ["resource" ]
433
457
434
458
expected_train_args = _create_train_job (
435
- pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
459
+ pytorch_training_compiler_version ,
460
+ INSTANCE_TYPE ,
461
+ compiler_config ,
462
+ pytorch_training_compiler_py_version ,
436
463
)
437
464
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
438
465
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -454,12 +481,16 @@ def test_debug_compiler_config(
454
481
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
455
482
@patch ("time.time" , return_value = TIME )
456
483
def test_disable_compiler_config (
457
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version
484
+ time ,
485
+ name_from_base ,
486
+ sagemaker_session ,
487
+ pytorch_training_compiler_version ,
488
+ pytorch_training_compiler_py_version ,
458
489
):
459
490
compiler_config = TrainingCompilerConfig (enabled = False )
460
491
461
492
pt = PyTorch (
462
- py_version = "py38" ,
493
+ py_version = pytorch_training_compiler_py_version ,
463
494
entry_point = SCRIPT_PATH ,
464
495
role = ROLE ,
465
496
sagemaker_session = sagemaker_session ,
@@ -480,7 +511,10 @@ def test_disable_compiler_config(
480
511
assert boto_call_names == ["resource" ]
481
512
482
513
expected_train_args = _create_train_job (
483
- pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
514
+ pytorch_training_compiler_version ,
515
+ INSTANCE_TYPE ,
516
+ compiler_config ,
517
+ pytorch_training_compiler_py_version ,
484
518
)
485
519
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
486
520
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -508,7 +542,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
508
542
"py38-cu113-ubuntu20.04"
509
543
)
510
544
returned_job_description = {
511
- "AlgorithmSpecification" : {"TrainingInputMode" : "File" , "TrainingImage" : training_image },
545
+ "AlgorithmSpecification" : {
546
+ "TrainingInputMode" : "File" ,
547
+ "TrainingImage" : training_image ,
548
+ },
512
549
"HyperParameters" : {
513
550
"sagemaker_submit_directory" : '"s3://some/sourcedir.tar.gz"' ,
514
551
"sagemaker_program" : '"iris-dnn-classifier.py"' ,
@@ -530,7 +567,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
530
567
"TrainingJobName" : "trcomp" ,
531
568
"TrainingJobStatus" : "Completed" ,
532
569
"TrainingJobArn" : "arn:aws:sagemaker:us-west-2:336:training-job/trcomp" ,
533
- "OutputDataConfig" : {"KmsKeyId" : "" , "S3OutputPath" : "s3://place/output/trcomp" },
570
+ "OutputDataConfig" : {
571
+ "KmsKeyId" : "" ,
572
+ "S3OutputPath" : "s3://place/output/trcomp" ,
573
+ },
534
574
"TrainingJobOutput" : {"S3TrainingJobOutput" : "s3://here/output.tar.gz" },
535
575
}
536
576
sagemaker_session .sagemaker_client .describe_training_job = Mock (
0 commit comments