diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 1fb8c8eaaa..449beeb55d 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -141,6 +141,7 @@ "2.0.1", "2.1.0", "2.1.2", + "2.2.0", ], } @@ -160,7 +161,14 @@ ] -TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2"] +TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ + "1.13.1", + "2.0.0", + "2.0.1", + "2.1.0", + "2.1.2", + "2.2.0", +] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [ diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index d71c2df6ec..933c0fa437 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -5,7 +5,8 @@ ], "version_aliases": { "2.0": "2.0.1", - "2.1": "2.1.2" + "2.1": "2.1.2", + "2.2": "2.2.0" }, "versions": { "2.0.1": { @@ -57,6 +58,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.2.0": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 8498027079..143ecc9bdb 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -678,7 +678,11 @@ def get_training_image_uri( if "modelparallel" in distribution["smdistributed"]: if distribution["smdistributed"]["modelparallel"].get("enabled", True): framework = "pytorch-smp" - if "p5" in instance_type or "2.1" in framework_version: + if ( + "p5" in instance_type + or "2.1" in framework_version + or "2.2" in framework_version + ): container_version = "cu121" else: container_version = "cu118" diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index 36accdebbb..b53a45133e 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -35,7 +35,7 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version: + if "2.1" in version or "2.2" in version: cuda_vers = "cu121" uri = image_uris.get_training_image_uri(