We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 87be0fc commit 1f8e744Copy full SHA for 1f8e744
src/sagemaker/fw_utils.py
@@ -804,14 +804,16 @@ def validate_pytorch_distribution(
804
`py_version` is not python3 or
805
`framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
806
"""
807
- if framework_name != "pytorch":
+ if framework_name and framework_name != "pytorch":
808
# We need to validate only for PyTorch framework
809
return
810
+
811
+ pytorch_ddp_enabled = False
812
if "pytorchddp" in distribution:
813
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
- if not pytorch_ddp_enabled:
- # Distribution strategy other than pytorchddp is selected
814
- return
+ if not pytorch_ddp_enabled:
815
+ # Distribution strategy other than pytorchddp is selected
816
+ return
817
818
err_msg = ""
819
if not image_uri:
0 commit comments