From 590f3d8599c6d4a40c773398fbeb69c64e42e378 Mon Sep 17 00:00:00 2001 From: Andrew Tian Date: Mon, 14 Oct 2024 10:50:54 -0700 Subject: [PATCH] changes for PT 2.4 currency upgrade --- src/sagemaker/fw_utils.py | 1 + .../image_uri_config/pytorch-smp.json | 28 ++++++++++++++++++- src/sagemaker/image_uris.py | 1 + .../unit/sagemaker/image_uris/test_smp_v2.py | 11 +++++--- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 9a0e46d1a0..0ddb3cd255 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -153,6 +153,7 @@ "2.1.2", "2.2.0", "2.3.1", + "2.4.1", ] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 3a119c81d2..18eb0fb4c3 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -8,7 +8,8 @@ "2.1": "2.1.2", "2.2": "2.3.1", "2.2.0": "2.3.1", - "2.3.1": "2.5.0" + "2.3.1": "2.5.0", + "2.4.1": "2.6.0" }, "versions": { "2.0.1": { @@ -160,6 +161,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.6.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 95080b8406..dd7012b2f2 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -699,6 +699,7 @@ def get_training_image_uri( or "2.1" in framework_version or "2.2" in framework_version or "2.3" in framework_version + or "2.4" in framework_version ): container_version = "cu121" else: diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index 4fd1cc6179..b1297822f7 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -36,14 +36,17 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version or "2.2" in version or "2.3" in version: + if ( + "2.1" in version + or "2.2" in version + or "2.3" in version + or "2.4" in version + ): cuda_vers = "cu121" - if "2.3.1" == version: + if "2.3.1" == version or "2.4.1" == version: py_version = "py311" - print(version, py_version) - uri = image_uris.get_training_image_uri( region, framework="pytorch",