Skip to content

Commit 25c49d4

Browse files
Mike Schneider and Shibo Xing authored
feature: Add PyTorch 1.13.1 to SDK (#3587)
Co-authored-by: Shibo Xing <[email protected]>
1 parent 4d95b05 commit 25c49d4

File tree

6 files changed

+160
-26
lines changed

6 files changed

+160
-26
lines changed

src/sagemaker/fw_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
"1.12",
132132
"1.12.0",
133133
"1.12.1",
134+
"1.13.1",
134135
],
135136
}
136137

@@ -143,6 +144,7 @@
143144
"1.12",
144145
"1.12.0",
145146
"1.12.1",
147+
"1.13.1",
146148
]
147149

148150

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@
7474
"1.9": "1.9.1",
7575
"1.10": "1.10.2",
7676
"1.11": "1.11.0",
77-
"1.12": "1.12.1"
77+
"1.12": "1.12.1",
78+
"1.13": "1.13.1"
7879
},
7980
"versions": {
8081
"0.4.0": {
@@ -783,6 +784,42 @@
783784
"us-west-2": "763104351884"
784785
},
785786
"repository": "pytorch-inference"
787+
},
788+
"1.13.1": {
789+
"py_versions": [
790+
"py39"
791+
],
792+
"registries": {
793+
"af-south-1": "626614931356",
794+
"ap-east-1": "871362719292",
795+
"ap-northeast-1": "763104351884",
796+
"ap-northeast-2": "763104351884",
797+
"ap-northeast-3": "364406365360",
798+
"ap-south-1": "763104351884",
799+
"ap-southeast-1": "763104351884",
800+
"ap-southeast-2": "763104351884",
801+
"ap-southeast-3": "907027046896",
802+
"ca-central-1": "763104351884",
803+
"cn-north-1": "727897471807",
804+
"cn-northwest-1": "727897471807",
805+
"eu-central-1": "763104351884",
806+
"eu-north-1": "763104351884",
807+
"eu-west-1": "763104351884",
808+
"eu-west-2": "763104351884",
809+
"eu-west-3": "763104351884",
810+
"eu-south-1": "692866216735",
811+
"me-south-1": "217643126080",
812+
"sa-east-1": "763104351884",
813+
"us-east-1": "763104351884",
814+
"us-east-2": "763104351884",
815+
"us-gov-east-1": "446045086412",
816+
"us-gov-west-1": "442386744353",
817+
"us-iso-east-1": "886529160074",
818+
"us-isob-east-1": "094389454867",
819+
"us-west-1": "763104351884",
820+
"us-west-2": "763104351884"
821+
},
822+
"repository": "pytorch-inference"
786823
}
787824
}
788825
},
@@ -855,7 +892,8 @@
855892
"1.9": "1.9.1",
856893
"1.10": "1.10.2",
857894
"1.11": "1.11.0",
858-
"1.12": "1.12.1"
895+
"1.12": "1.12.1",
896+
"1.13": "1.13.1"
859897
},
860898
"versions": {
861899
"0.4.0": {
@@ -1520,6 +1558,42 @@
15201558
"us-west-2": "763104351884"
15211559
},
15221560
"repository": "pytorch-training"
1561+
},
1562+
"1.13.1": {
1563+
"py_versions": [
1564+
"py39"
1565+
],
1566+
"registries": {
1567+
"af-south-1": "626614931356",
1568+
"ap-east-1": "871362719292",
1569+
"ap-northeast-1": "763104351884",
1570+
"ap-northeast-2": "763104351884",
1571+
"ap-northeast-3": "364406365360",
1572+
"ap-south-1": "763104351884",
1573+
"ap-southeast-1": "763104351884",
1574+
"ap-southeast-2": "763104351884",
1575+
"ap-southeast-3": "907027046896",
1576+
"ca-central-1": "763104351884",
1577+
"cn-north-1": "727897471807",
1578+
"cn-northwest-1": "727897471807",
1579+
"eu-central-1": "763104351884",
1580+
"eu-north-1": "763104351884",
1581+
"eu-west-1": "763104351884",
1582+
"eu-west-2": "763104351884",
1583+
"eu-west-3": "763104351884",
1584+
"eu-south-1": "692866216735",
1585+
"me-south-1": "217643126080",
1586+
"sa-east-1": "763104351884",
1587+
"us-east-1": "763104351884",
1588+
"us-east-2": "763104351884",
1589+
"us-gov-east-1": "446045086412",
1590+
"us-gov-west-1": "442386744353",
1591+
"us-iso-east-1": "886529160074",
1592+
"us-isob-east-1": "094389454867",
1593+
"us-west-1": "763104351884",
1594+
"us-west-2": "763104351884"
1595+
},
1596+
"repository": "pytorch-training"
15231597
}
15241598
}
15251599
}

tests/conftest.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@
8686
"huggingface_training_compiler",
8787
)
8888

89+
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
90+
8991

9092
def pytest_addoption(parser):
9193
parser.addoption("--sagemaker-client-config", action="store", default=None)
@@ -221,22 +223,26 @@ def mxnet_eia_latest_py_version():
221223

222224
@pytest.fixture(scope="module", params=["py2", "py3"])
223225
def pytorch_training_py_version(pytorch_training_version, request):
224-
if Version(pytorch_training_version) < Version("1.5.0"):
225-
return request.param
226+
if Version(pytorch_training_version) >= Version("1.13"):
227+
return "py39"
226228
elif Version(pytorch_training_version) >= Version("1.9"):
227229
return "py38"
228-
else:
230+
elif Version(pytorch_training_version) >= Version("1.5.0"):
229231
return "py3"
232+
else:
233+
return request.param
230234

231235

232236
@pytest.fixture(scope="module", params=["py2", "py3"])
233237
def pytorch_inference_py_version(pytorch_inference_version, request):
234-
if Version(pytorch_inference_version) < Version("1.4.0"):
235-
return request.param
238+
if Version(pytorch_inference_version) >= Version("1.13"):
239+
return "py39"
236240
elif Version(pytorch_inference_version) >= Version("1.9"):
237241
return "py38"
238-
else:
242+
elif Version(pytorch_inference_version) >= Version("1.4.0"):
239243
return "py3"
244+
else:
245+
return request.param
240246

241247

242248
@pytest.fixture(scope="module")
@@ -252,9 +258,13 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version
252258

253259

254260
@pytest.fixture(scope="module")
255-
def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_version):
261+
def huggingface_training_compiler_pytorch_version(
262+
huggingface_training_compiler_version,
263+
):
256264
versions = _huggingface_base_fm_version(
257-
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
265+
huggingface_training_compiler_version,
266+
"pytorch",
267+
"huggingface_training_compiler",
258268
)
259269
if not versions:
260270
pytest.skip(
@@ -265,9 +275,13 @@ def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_
265275

266276

267277
@pytest.fixture(scope="module")
268-
def huggingface_training_compiler_tensorflow_version(huggingface_training_compiler_version):
278+
def huggingface_training_compiler_tensorflow_version(
279+
huggingface_training_compiler_version,
280+
):
269281
versions = _huggingface_base_fm_version(
270-
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
282+
huggingface_training_compiler_version,
283+
"tensorflow",
284+
"huggingface_training_compiler",
271285
)
272286
if not versions:
273287
pytest.skip(
@@ -289,19 +303,25 @@ def huggingface_training_compiler_tensorflow_py_version(
289303

290304

291305
@pytest.fixture(scope="module")
292-
def huggingface_training_compiler_pytorch_py_version(huggingface_training_compiler_pytorch_version):
306+
def huggingface_training_compiler_pytorch_py_version(
307+
huggingface_training_compiler_pytorch_version,
308+
):
293309
return "py38"
294310

295311

296312
@pytest.fixture(scope="module")
297-
def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version):
313+
def huggingface_pytorch_latest_training_py_version(
314+
huggingface_training_pytorch_latest_version,
315+
):
298316
return (
299317
"py38" if Version(huggingface_training_pytorch_latest_version) >= Version("1.9") else "py36"
300318
)
301319

302320

303321
@pytest.fixture(scope="module")
304-
def huggingface_pytorch_latest_inference_py_version(huggingface_inference_pytorch_latest_version):
322+
def huggingface_pytorch_latest_inference_py_version(
323+
huggingface_inference_pytorch_latest_version,
324+
):
305325
return (
306326
"py38"
307327
if Version(huggingface_inference_pytorch_latest_version) >= Version("1.9")
@@ -477,7 +497,8 @@ def pytorch_ddp_py_version():
477497

478498

479499
@pytest.fixture(
480-
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
500+
scope="module",
501+
params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"],
481502
)
482503
def pytorch_ddp_framework_version(request):
483504
return request.param
@@ -511,6 +532,23 @@ def gpu_instance_type(sagemaker_session, request):
511532
return "ml.p3.2xlarge"
512533

513534

535+
@pytest.fixture()
536+
def gpu_pytorch_instance_type(sagemaker_session, request):
537+
if "pytorch_inference_version" in request.fixturenames:
538+
fw_version = request.getfixturevalue("pytorch_inference_version")
539+
else:
540+
fw_version = request.param
541+
542+
region = sagemaker_session.boto_session.region_name
543+
if region in NO_P3_REGIONS:
544+
if Version(fw_version) >= Version("1.13"):
545+
return PYTORCH_RENEWED_GPU
546+
else:
547+
return "ml.p2.xlarge"
548+
else:
549+
return "ml.p3.2xlarge"
550+
551+
514552
@pytest.fixture(scope="session")
515553
def gpu_instance_type_list(sagemaker_session, request):
516554
region = sagemaker_session.boto_session.region_name

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tests.unit.sagemaker.image_uris import expected_uris
1919

2020
INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu"))
21+
RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.g4dn.xlarge", "gpu"))
2122
REGION = "us-west-2"
2223

2324
DLC_ACCOUNT = "763104351884"
@@ -70,7 +71,12 @@ def _test_image_uris(
7071
"image_scope": scope,
7172
}
7273

73-
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
74+
TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS
75+
if framework == "pytorch" and Version(fw_version) >= Version("1.13"):
76+
"""Handle P2 deprecation"""
77+
TYPES_AND_PROCESSORS = RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS
78+
79+
for instance_type, processor in TYPES_AND_PROCESSORS:
7480
uri = image_uris.retrieve(region=REGION, instance_type=instance_type, **base_args)
7581

7682
expected = expected_fn(processor=processor, **expected_fn_args)

tests/unit/test_fw_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def test_validate_smdataparallel_args_not_raises():
912912
("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled),
913913
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled),
914914
("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
915+
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
915916
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
916917
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
917918
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -932,6 +933,7 @@ def test_validate_smdataparallel_args_not_raises():
932933
("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled_custom_mpi),
933934
("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled_custom_mpi),
934935
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
936+
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
935937
]
936938
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
937939
fw_utils._validate_smdataparallel_args(

tests/unit/test_pytorch.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,12 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
302302
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
303303
@patch("time.time", return_value=TIME)
304304
def test_pytorch(
305-
time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version
305+
time,
306+
name_from_base,
307+
sagemaker_session,
308+
pytorch_inference_version,
309+
pytorch_inference_py_version,
310+
gpu_pytorch_instance_type,
306311
):
307312
pytorch = PyTorch(
308313
entry_point=SCRIPT_PATH,
@@ -339,24 +344,29 @@ def test_pytorch(
339344
REGION,
340345
version=pytorch_inference_version,
341346
py_version=pytorch_inference_py_version,
342-
instance_type=GPU,
347+
instance_type=gpu_pytorch_instance_type,
343348
image_scope="inference",
344349
)
345350

346-
actual_environment = model.prepare_container_def(GPU)
351+
actual_environment = model.prepare_container_def(gpu_pytorch_instance_type)
347352
submit_directory = actual_environment["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"]
348353
model_url = actual_environment["ModelDataUrl"]
349354
expected_environment = _get_environment(submit_directory, model_url, expected_image_uri)
350355
assert actual_environment == expected_environment
351356

352357
assert "cpu" in model.prepare_container_def(CPU)["Image"]
353-
predictor = pytorch.deploy(1, GPU)
358+
predictor = pytorch.deploy(1, gpu_pytorch_instance_type)
354359
assert isinstance(predictor, PyTorchPredictor)
355360

356361

357362
@patch("sagemaker.utils.repack_model", MagicMock())
358363
@patch("sagemaker.utils.create_tar_file", MagicMock())
359-
def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version):
364+
def test_model(
365+
sagemaker_session,
366+
pytorch_inference_version,
367+
pytorch_inference_py_version,
368+
gpu_pytorch_instance_type,
369+
):
360370
model = PyTorchModel(
361371
MODEL_DATA,
362372
role=ROLE,
@@ -365,21 +375,22 @@ def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_p
365375
py_version=pytorch_inference_py_version,
366376
sagemaker_session=sagemaker_session,
367377
)
368-
predictor = model.deploy(1, GPU)
378+
predictor = model.deploy(1, gpu_pytorch_instance_type)
369379
assert isinstance(predictor, PyTorchPredictor)
370380

371381

372382
@patch("sagemaker.utils.create_tar_file", MagicMock())
373383
@patch("sagemaker.utils.repack_model")
374-
def test_mms_model(repack_model, sagemaker_session):
384+
@pytest.mark.parametrize("gpu_pytorch_instance_type", ["1.2"], indirect=True)
385+
def test_mms_model(repack_model, sagemaker_session, gpu_pytorch_instance_type):
375386
PyTorchModel(
376387
MODEL_DATA,
377388
role=ROLE,
378389
entry_point=SCRIPT_PATH,
379390
sagemaker_session=sagemaker_session,
380391
framework_version="1.2",
381392
py_version="py3",
382-
).deploy(1, GPU)
393+
).deploy(1, gpu_pytorch_instance_type)
383394

384395
repack_model.assert_called_with(
385396
dependencies=[],
@@ -428,6 +439,7 @@ def test_model_custom_serialization(
428439
sagemaker_session,
429440
pytorch_inference_version,
430441
pytorch_inference_py_version,
442+
gpu_pytorch_instance_type,
431443
):
432444
model = PyTorchModel(
433445
MODEL_DATA,
@@ -441,7 +453,7 @@ def test_model_custom_serialization(
441453
custom_deserializer = Mock()
442454
predictor = model.deploy(
443455
1,
444-
GPU,
456+
gpu_pytorch_instance_type,
445457
serializer=custom_serializer,
446458
deserializer=custom_deserializer,
447459
)

0 commit comments

Comments
 (0)