Skip to content

Commit 1f8e744

Browse files
committed
fix unit tests
1 parent 87be0fc commit 1f8e744

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sagemaker/fw_utils.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -804,14 +804,16 @@ def validate_pytorch_distribution(
804804
`py_version` is not python3 or
805805
`framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
806806
"""
807-
if framework_name != "pytorch":
807+
if framework_name and framework_name != "pytorch":
808808
# We need to validate only for PyTorch framework
809809
return
810+
811+
pytorch_ddp_enabled = False
810812
if "pytorchddp" in distribution:
811813
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
812-
if not pytorch_ddp_enabled:
813-
# Distribution strategy other than pytorchddp is selected
814-
return
814+
if not pytorch_ddp_enabled:
815+
# Distribution strategy other than pytorchddp is selected
816+
return
815817

816818
err_msg = ""
817819
if not image_uri:

0 commit comments

Comments
 (0)