45
45
REGION = "us-east-1"
46
46
GPU = "ml.p3.2xlarge"
47
47
SUPPORTED_GPU_INSTANCE_CLASSES = {"p3" , "p3dn" , "g4dn" , "p4d" , "g5" }
48
- UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
48
+ UNSUPPORTED_GPU_INSTANCE_CLASSES = (
49
+ EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
50
+ )
49
51
50
52
LIST_TAGS_RESULT = {"Tags" : [{"Key" : "TagtestKey" , "Value" : "TagtestValue" }]}
51
53
@@ -96,9 +98,13 @@ def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
96
98
)
97
99
98
100
99
- def _create_train_job (version , instance_type , training_compiler_config , instance_count = 1 ):
101
+ def _create_train_job (
102
+ version , instance_type , training_compiler_config , instance_count = 1
103
+ ):
100
104
return {
101
- "image_uri" : _get_full_gpu_image_uri (version , instance_type , training_compiler_config ),
105
+ "image_uri" : _get_full_gpu_image_uri (
106
+ version , instance_type , training_compiler_config
107
+ ),
102
108
"input_mode" : "File" ,
103
109
"input_config" : [
104
110
{
@@ -183,7 +189,9 @@ def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_v
183
189
).fit ()
184
190
185
191
186
- @pytest .mark .parametrize ("unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES )
192
+ @pytest .mark .parametrize (
193
+ "unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES
194
+ )
187
195
def test_unsupported_gpu_instance (
188
196
unsupported_gpu_instance_class , pytorch_training_compiler_version
189
197
):
@@ -303,7 +311,12 @@ def test_unsupported_distribution(
303
311
@patch ("time.time" , return_value = TIME )
304
312
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
305
313
def test_pytorchxla_distribution (
306
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class , pytorch_training_py_version
314
+ time ,
315
+ name_from_base ,
316
+ sagemaker_session ,
317
+ pytorch_training_compiler_version ,
318
+ instance_class ,
319
+ pytorch_training_py_version ,
307
320
):
308
321
if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
309
322
pytest .skip ("This test is intended for PyTorch 1.12 and above" )
@@ -333,17 +346,24 @@ def test_pytorchxla_distribution(
333
346
assert boto_call_names == ["resource" ]
334
347
335
348
expected_train_args = _create_train_job (
336
- pytorch_training_compiler_version , instance_type , compiler_config , instance_count = 2
349
+ pytorch_training_compiler_version ,
350
+ instance_type ,
351
+ compiler_config ,
352
+ instance_count = 2 ,
337
353
)
338
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
354
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
355
+ "S3Uri"
356
+ ] = inputs
339
357
expected_train_args ["enable_sagemaker_metrics" ] = False
340
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig .HP_ENABLE_COMPILER ] = json .dumps (
358
+ expected_train_args ["hyperparameters" ][
359
+ TrainingCompilerConfig .HP_ENABLE_COMPILER
360
+ ] = json .dumps (True )
361
+ expected_train_args ["hyperparameters" ][PyTorch .LAUNCH_PT_XLA_ENV_NAME ] = json .dumps (
341
362
True
342
363
)
343
- expected_train_args ["hyperparameters" ][PyTorch .LAUNCH_PT_XLA_ENV_NAME ] = json .dumps (True )
344
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig .HP_ENABLE_DEBUG ] = json .dumps (
345
- False
346
- )
364
+ expected_train_args ["hyperparameters" ][
365
+ TrainingCompilerConfig .HP_ENABLE_DEBUG
366
+ ] = json .dumps (False )
347
367
348
368
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
349
369
assert (
@@ -357,7 +377,12 @@ def test_pytorchxla_distribution(
357
377
@patch ("time.time" , return_value = TIME )
358
378
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
359
379
def test_default_compiler_config (
360
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class , pytorch_training_py_version
380
+ time ,
381
+ name_from_base ,
382
+ sagemaker_session ,
383
+ pytorch_training_compiler_version ,
384
+ instance_class ,
385
+ pytorch_training_py_version ,
361
386
):
362
387
compiler_config = TrainingCompilerConfig ()
363
388
instance_type = f"ml.{ instance_class } .xlarge"
@@ -386,14 +411,16 @@ def test_default_compiler_config(
386
411
expected_train_args = _create_train_job (
387
412
pytorch_training_compiler_version , instance_type , compiler_config
388
413
)
389
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
414
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
415
+ "S3Uri"
416
+ ] = inputs
390
417
expected_train_args ["enable_sagemaker_metrics" ] = False
391
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
392
- True
393
- )
394
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
395
- False
396
- )
418
+ expected_train_args ["hyperparameters" ][
419
+ TrainingCompilerConfig . HP_ENABLE_COMPILER
420
+ ] = json . dumps ( True )
421
+ expected_train_args ["hyperparameters" ][
422
+ TrainingCompilerConfig . HP_ENABLE_DEBUG
423
+ ] = json . dumps ( False )
397
424
398
425
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
399
426
assert (
@@ -406,7 +433,11 @@ def test_default_compiler_config(
406
433
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
407
434
@patch ("time.time" , return_value = TIME )
408
435
def test_debug_compiler_config (
409
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , pytorch_training_py_version
436
+ time ,
437
+ name_from_base ,
438
+ sagemaker_session ,
439
+ pytorch_training_compiler_version ,
440
+ pytorch_training_py_version ,
410
441
):
411
442
compiler_config = TrainingCompilerConfig (debug = True )
412
443
@@ -434,14 +465,16 @@ def test_debug_compiler_config(
434
465
expected_train_args = _create_train_job (
435
466
pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
436
467
)
437
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
468
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
469
+ "S3Uri"
470
+ ] = inputs
438
471
expected_train_args ["enable_sagemaker_metrics" ] = False
439
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
440
- True
441
- )
442
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
443
- True
444
- )
472
+ expected_train_args ["hyperparameters" ][
473
+ TrainingCompilerConfig . HP_ENABLE_COMPILER
474
+ ] = json . dumps ( True )
475
+ expected_train_args ["hyperparameters" ][
476
+ TrainingCompilerConfig . HP_ENABLE_DEBUG
477
+ ] = json . dumps ( True )
445
478
446
479
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
447
480
assert (
@@ -454,7 +487,11 @@ def test_debug_compiler_config(
454
487
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
455
488
@patch ("time.time" , return_value = TIME )
456
489
def test_disable_compiler_config (
457
- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , pytorch_training_py_version
490
+ time ,
491
+ name_from_base ,
492
+ sagemaker_session ,
493
+ pytorch_training_compiler_version ,
494
+ pytorch_training_py_version ,
458
495
):
459
496
compiler_config = TrainingCompilerConfig (enabled = False )
460
497
@@ -482,14 +519,16 @@ def test_disable_compiler_config(
482
519
expected_train_args = _create_train_job (
483
520
pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
484
521
)
485
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
522
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
523
+ "S3Uri"
524
+ ] = inputs
486
525
expected_train_args ["enable_sagemaker_metrics" ] = False
487
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
488
- False
489
- )
490
- expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
491
- False
492
- )
526
+ expected_train_args ["hyperparameters" ][
527
+ TrainingCompilerConfig . HP_ENABLE_COMPILER
528
+ ] = json . dumps ( False )
529
+ expected_train_args ["hyperparameters" ][
530
+ TrainingCompilerConfig . HP_ENABLE_DEBUG
531
+ ] = json . dumps ( False )
493
532
494
533
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
495
534
assert (
@@ -508,7 +547,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
508
547
"py38-cu113-ubuntu20.04"
509
548
)
510
549
returned_job_description = {
511
- "AlgorithmSpecification" : {"TrainingInputMode" : "File" , "TrainingImage" : training_image },
550
+ "AlgorithmSpecification" : {
551
+ "TrainingInputMode" : "File" ,
552
+ "TrainingImage" : training_image ,
553
+ },
512
554
"HyperParameters" : {
513
555
"sagemaker_submit_directory" : '"s3://some/sourcedir.tar.gz"' ,
514
556
"sagemaker_program" : '"iris-dnn-classifier.py"' ,
@@ -530,14 +572,19 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
530
572
"TrainingJobName" : "trcomp" ,
531
573
"TrainingJobStatus" : "Completed" ,
532
574
"TrainingJobArn" : "arn:aws:sagemaker:us-west-2:336:training-job/trcomp" ,
533
- "OutputDataConfig" : {"KmsKeyId" : "" , "S3OutputPath" : "s3://place/output/trcomp" },
575
+ "OutputDataConfig" : {
576
+ "KmsKeyId" : "" ,
577
+ "S3OutputPath" : "s3://place/output/trcomp" ,
578
+ },
534
579
"TrainingJobOutput" : {"S3TrainingJobOutput" : "s3://here/output.tar.gz" },
535
580
}
536
581
sagemaker_session .sagemaker_client .describe_training_job = Mock (
537
582
name = "describe_training_job" , return_value = returned_job_description
538
583
)
539
584
540
- estimator = PyTorch .attach (training_job_name = "trcomp" , sagemaker_session = sagemaker_session )
585
+ estimator = PyTorch .attach (
586
+ training_job_name = "trcomp" , sagemaker_session = sagemaker_session
587
+ )
541
588
assert estimator .latest_training_job .job_name == "trcomp"
542
589
assert estimator .py_version == "py38"
543
590
assert estimator .framework_version == "1.12.0"
@@ -549,12 +596,12 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
549
596
assert estimator .output_path == "s3://place/output/trcomp"
550
597
assert estimator .output_kms_key == ""
551
598
assert estimator .hyperparameters ()["training_steps" ] == "100"
552
- assert estimator .hyperparameters ()[TrainingCompilerConfig . HP_ENABLE_COMPILER ] == json . dumps (
553
- compiler_enabled
554
- )
555
- assert estimator .hyperparameters ()[TrainingCompilerConfig . HP_ENABLE_DEBUG ] == json . dumps (
556
- debug_enabled
557
- )
599
+ assert estimator .hyperparameters ()[
600
+ TrainingCompilerConfig . HP_ENABLE_COMPILER
601
+ ] == json . dumps ( compiler_enabled )
602
+ assert estimator .hyperparameters ()[
603
+ TrainingCompilerConfig . HP_ENABLE_DEBUG
604
+ ] == json . dumps ( debug_enabled )
558
605
assert estimator .source_dir == "s3://some/sourcedir.tar.gz"
559
606
assert estimator .entry_point == "iris-dnn-classifier.py"
560
607
0 commit comments