Skip to content

Commit b731021

Browse files
BruceChang9688
and
BruceZhang@eitug
authored
feature: Adding support for SageMaker Training Compiler PyTorch 1.13 (#3629)
Co-authored-by: BruceZhang@eitug <[email protected]>
1 parent 5cf3e44 commit b731021

File tree

4 files changed

+105
-22
lines changed

4 files changed

+105
-22
lines changed

src/sagemaker/image_uri_config/pytorch-training-compiler.json

+31-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
"gpu"
55
],
66
"version_aliases": {
7-
"1.12": "1.12.0"
7+
"1.12": "1.12.0",
8+
"1.13": "1.13.1"
89
},
910
"versions": {
1011
"1.12.0": {
@@ -35,6 +36,35 @@
3536
"us-west-2": "763104351884"
3637
},
3738
"repository": "pytorch-trcomp-training"
39+
},
40+
"1.13.1": {
41+
"py_versions": [
42+
"py39"
43+
],
44+
"registries": {
45+
"af-south-1": "626614931356",
46+
"ap-east-1": "871362719292",
47+
"ap-northeast-1": "763104351884",
48+
"ap-northeast-2": "763104351884",
49+
"ap-northeast-3": "364406365360",
50+
"ap-south-1": "763104351884",
51+
"ap-southeast-1": "763104351884",
52+
"ap-southeast-2": "763104351884",
53+
"ca-central-1": "763104351884",
54+
"eu-central-1": "763104351884",
55+
"eu-north-1": "763104351884",
56+
"eu-west-1": "763104351884",
57+
"eu-west-2": "763104351884",
58+
"eu-west-3": "763104351884",
59+
"eu-south-1": "692866216735",
60+
"me-south-1": "217643126080",
61+
"sa-east-1": "763104351884",
62+
"us-east-1": "763104351884",
63+
"us-east-2": "763104351884",
64+
"us-west-1": "763104351884",
65+
"us-west-2": "763104351884"
66+
},
67+
"repository": "pytorch-trcomp-training"
3868
}
3969
}
4070
}

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,16 @@ def huggingface_pytorch_latest_training_py_version(
337337
)
338338

339339

340+
@pytest.fixture(scope="module")
341+
def pytorch_training_compiler_py_version(
342+
pytorch_training_compiler_version,
343+
):
344+
return "py39" if Version(pytorch_training_compiler_version) > Version("1.12") else "py38"
345+
346+
347+
# TODO: Create a fixture to get the latest py version from TRCOMP image_uri.
348+
349+
340350
@pytest.fixture(scope="module")
341351
def huggingface_pytorch_latest_inference_py_version(
342352
huggingface_inference_pytorch_latest_version,

tests/integ/test_training_compiler.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_pytorch(
150150
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
151151

152152
hf = PyTorch(
153-
py_version="py38",
153+
py_version="py39",
154154
source_dir=os.path.join(DATA_DIR, "huggingface_byoc"),
155155
entry_point="run_glue.py",
156156
role="SageMakerRole",
@@ -216,7 +216,10 @@ def test_huggingface_tensorflow(
216216

217217
@pytest.mark.release
218218
def test_tensorflow(
219-
sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set
219+
sagemaker_session,
220+
gpu_instance_type,
221+
tensorflow_training_latest_version,
222+
imagenet_val_set,
220223
):
221224
"""
222225
Test the TensorFlow estimator

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

+59-19
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,26 @@ def fixture_sagemaker_session():
8383
return session
8484

8585

86-
def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
86+
def _get_full_gpu_image_uri(version, instance_type, training_compiler_config, py_version):
8787
return image_uris.retrieve(
8888
"pytorch-training-compiler",
8989
REGION,
9090
version=version,
91-
py_version="py38",
91+
py_version=py_version,
9292
instance_type=instance_type,
9393
image_scope="training",
9494
container_version=None,
9595
training_compiler_config=training_compiler_config,
9696
)
9797

9898

99-
def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
99+
def _create_train_job(
100+
version, instance_type, training_compiler_config, py_version, instance_count=1
101+
):
100102
return {
101-
"image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
103+
"image_uri": _get_full_gpu_image_uri(
104+
version, instance_type, training_compiler_config, py_version
105+
),
102106
"input_mode": "File",
103107
"input_config": [
104108
{
@@ -303,15 +307,20 @@ def test_unsupported_distribution(
303307
@patch("time.time", return_value=TIME)
304308
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
305309
def test_pytorchxla_distribution(
306-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
310+
time,
311+
name_from_base,
312+
sagemaker_session,
313+
pytorch_training_compiler_version,
314+
instance_class,
315+
pytorch_training_compiler_py_version,
307316
):
308317
if Version(pytorch_training_compiler_version) < Version("1.12"):
309318
pytest.skip("This test is intended for PyTorch 1.12 and above")
310319
compiler_config = TrainingCompilerConfig()
311320
instance_type = f"ml.{instance_class}.xlarge"
312321

313322
pt = PyTorch(
314-
py_version="py38",
323+
py_version=pytorch_training_compiler_py_version,
315324
entry_point=SCRIPT_PATH,
316325
role=ROLE,
317326
sagemaker_session=sagemaker_session,
@@ -333,7 +342,11 @@ def test_pytorchxla_distribution(
333342
assert boto_call_names == ["resource"]
334343

335344
expected_train_args = _create_train_job(
336-
pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2
345+
pytorch_training_compiler_version,
346+
instance_type,
347+
compiler_config,
348+
pytorch_training_compiler_py_version,
349+
instance_count=2,
337350
)
338351
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
339352
expected_train_args["enable_sagemaker_metrics"] = False
@@ -357,13 +370,17 @@ def test_pytorchxla_distribution(
357370
@patch("time.time", return_value=TIME)
358371
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
359372
def test_default_compiler_config(
360-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
373+
time,
374+
name_from_base,
375+
sagemaker_session,
376+
pytorch_training_compiler_version,
377+
instance_class,
378+
pytorch_training_compiler_py_version,
361379
):
362380
compiler_config = TrainingCompilerConfig()
363381
instance_type = f"ml.{instance_class}.xlarge"
364-
365382
pt = PyTorch(
366-
py_version="py38",
383+
py_version=pytorch_training_compiler_py_version,
367384
entry_point=SCRIPT_PATH,
368385
role=ROLE,
369386
sagemaker_session=sagemaker_session,
@@ -384,7 +401,10 @@ def test_default_compiler_config(
384401
assert boto_call_names == ["resource"]
385402

386403
expected_train_args = _create_train_job(
387-
pytorch_training_compiler_version, instance_type, compiler_config
404+
pytorch_training_compiler_version,
405+
instance_type,
406+
compiler_config,
407+
pytorch_training_compiler_py_version,
388408
)
389409
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
390410
expected_train_args["enable_sagemaker_metrics"] = False
@@ -406,12 +426,16 @@ def test_default_compiler_config(
406426
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
407427
@patch("time.time", return_value=TIME)
408428
def test_debug_compiler_config(
409-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
429+
time,
430+
name_from_base,
431+
sagemaker_session,
432+
pytorch_training_compiler_version,
433+
pytorch_training_compiler_py_version,
410434
):
411435
compiler_config = TrainingCompilerConfig(debug=True)
412436

413437
pt = PyTorch(
414-
py_version="py38",
438+
py_version=pytorch_training_compiler_py_version,
415439
entry_point=SCRIPT_PATH,
416440
role=ROLE,
417441
sagemaker_session=sagemaker_session,
@@ -432,7 +456,10 @@ def test_debug_compiler_config(
432456
assert boto_call_names == ["resource"]
433457

434458
expected_train_args = _create_train_job(
435-
pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
459+
pytorch_training_compiler_version,
460+
INSTANCE_TYPE,
461+
compiler_config,
462+
pytorch_training_compiler_py_version,
436463
)
437464
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
438465
expected_train_args["enable_sagemaker_metrics"] = False
@@ -454,12 +481,16 @@ def test_debug_compiler_config(
454481
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
455482
@patch("time.time", return_value=TIME)
456483
def test_disable_compiler_config(
457-
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
484+
time,
485+
name_from_base,
486+
sagemaker_session,
487+
pytorch_training_compiler_version,
488+
pytorch_training_compiler_py_version,
458489
):
459490
compiler_config = TrainingCompilerConfig(enabled=False)
460491

461492
pt = PyTorch(
462-
py_version="py38",
493+
py_version=pytorch_training_compiler_py_version,
463494
entry_point=SCRIPT_PATH,
464495
role=ROLE,
465496
sagemaker_session=sagemaker_session,
@@ -480,7 +511,10 @@ def test_disable_compiler_config(
480511
assert boto_call_names == ["resource"]
481512

482513
expected_train_args = _create_train_job(
483-
pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
514+
pytorch_training_compiler_version,
515+
INSTANCE_TYPE,
516+
compiler_config,
517+
pytorch_training_compiler_py_version,
484518
)
485519
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
486520
expected_train_args["enable_sagemaker_metrics"] = False
@@ -508,7 +542,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
508542
"py38-cu113-ubuntu20.04"
509543
)
510544
returned_job_description = {
511-
"AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
545+
"AlgorithmSpecification": {
546+
"TrainingInputMode": "File",
547+
"TrainingImage": training_image,
548+
},
512549
"HyperParameters": {
513550
"sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
514551
"sagemaker_program": '"iris-dnn-classifier.py"',
@@ -530,7 +567,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
530567
"TrainingJobName": "trcomp",
531568
"TrainingJobStatus": "Completed",
532569
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp",
533-
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"},
570+
"OutputDataConfig": {
571+
"KmsKeyId": "",
572+
"S3OutputPath": "s3://place/output/trcomp",
573+
},
534574
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
535575
}
536576
sagemaker_session.sagemaker_client.describe_training_job = Mock(

0 commit comments

Comments
 (0)