We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 42c58eb commit 9d202e6Copy full SHA for 9d202e6
src/sagemaker/image_uris.py
@@ -667,7 +667,7 @@ def get_training_image_uri(
667
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
668
framework = "pytorch-smp"
669
if "p5" in instance_type:
670
- container_version = "cu12"
+ container_version = "cu121"
671
else:
672
container_version = "cu118"
673
tests/unit/sagemaker/image_uris/test_smp_v2.py
@@ -16,7 +16,7 @@
16
from sagemaker import image_uris
17
from tests.unit.sagemaker.image_uris import expected_uris
18
19
-CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu12"}
+CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu121"}
20
21
22
@pytest.mark.parametrize("load_config", ["pytorch-smp.json"], indirect=True)
0 commit comments