diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index bb5190cef2..33f38fccd5 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -93,6 +93,9 @@ "1.9.1", "1.10", "1.10.0", + "1.10.2", + "1.11", + "1.11.0", ], } SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 9c96858efe..b239d9e00c 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -534,6 +534,72 @@ "us-west-2": "763104351884" }, "repository": "pytorch-inference" + }, + "1.10.2": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" + }, + "1.11.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" } } }, @@ -1025,6 +1091,72 @@ "us-west-2": "763104351884" }, "repository": "pytorch-training" + }, + "1.10.2": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "1.11.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" } } } diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index f1ffabef5e..14338b76f9 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -700,7 +700,11 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.9", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.10.0", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.11", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi), @@ -713,6 +717,8 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled_custom_mpi), ] for instance_type, framework_name, framework_version, py_version, distribution in good_args: fw_utils._validate_smdataparallel_args(