Skip to content

Commit cce46ad

Browse files
committed
fix: format fix
1 parent c56d79d commit cce46ad

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/sagemaker/fw_utils.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1077,8 +1077,10 @@ def validate_torch_distributed_distribution(
10771077
if not image_uri:
10781078
# ignore framework_version and py_version if image_uri is set
10791079
# in case image_uri is not set, then both are mandatory
1080-
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or \
1081-
framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS:
1080+
if (
1081+
framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1082+
or framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS
1083+
):
10821084
err_msg += (
10831085
f"Provided framework_version {framework_version} is not supported by"
10841086
" torch_distributed.\n"
@@ -1093,16 +1095,16 @@ def validate_torch_distributed_distribution(
10931095

10941096
# Check instance compatibility
10951097
if not _is_gpu_instance(instance_type):
1096-
err_msg += (
1097-
"torch_distributed is supported only for GPU instances.\n"
1098-
)
1098+
err_msg += "torch_distributed is supported only for GPU instances.\n"
10991099

11001100
# Check version compatibility for GPU instance
11011101
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
11021102
if match:
11031103
# Non-Trainium GPU instance but version earlier than 1.13.1
1104-
if not match[1].startswith("trn") and \
1105-
framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
1104+
if (
1105+
not match[1].startswith("trn")
1106+
and framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1107+
):
11061108
err_msg += (
11071109
f"Provided framework_version {framework_version} is not supported by"
11081110
f" torch_distributed for instance {instance_type}.\n"
@@ -1133,7 +1135,7 @@ def _is_gpu_instance(instance_type):
11331135
if isinstance(instance_type, str):
11341136
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
11351137
if match:
1136-
if (match[1].startswith("trn") or match[1].startswith("p") or match[1].startswith("g")):
1138+
if match[1].startswith("trn") or match[1].startswith("p") or match[1].startswith("g"):
11371139
return True
11381140
if instance_type == "local_gpu":
11391141
return True

tests/unit/test_fw_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,7 @@ def test_instance_type_supports_profiler():
11031103
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
11041104
assert fw_utils._instance_type_supports_profiler("local") is False
11051105

1106+
11061107
def test_is_gpu_instance():
11071108
gpu_instance_types = [
11081109
"ml.p3.2xlarge",
@@ -1114,7 +1115,7 @@ def test_is_gpu_instance():
11141115
"ml.g4dn.xlarge",
11151116
"ml.g5.xlarge",
11161117
"ml.g5.48xlarge",
1117-
"local_gpu"
1118+
"local_gpu",
11181119
]
11191120
non_gpu_instance_types = [
11201121
"ml.t3.xlarge",

0 commit comments

Comments
 (0)