@@ -1077,8 +1077,10 @@ def validate_torch_distributed_distribution(
1077
1077
if not image_uri :
1078
1078
# ignore framework_version and py_version if image_uri is set
1079
1079
# in case image_uri is not set, then both are mandatory
1080
- if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or \
1081
- framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS :
1080
+ if (
1081
+ framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1082
+ or framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS
1083
+ ):
1082
1084
err_msg += (
1083
1085
f"Provided framework_version { framework_version } is not supported by"
1084
1086
" torch_distributed.\n "
@@ -1093,16 +1095,16 @@ def validate_torch_distributed_distribution(
1093
1095
1094
1096
# Check instance compatibility
1095
1097
if not _is_gpu_instance (instance_type ):
1096
- err_msg += (
1097
- "torch_distributed is supported only for GPU instances.\n "
1098
- )
1098
+ err_msg += "torch_distributed is supported only for GPU instances.\n "
1099
1099
1100
1100
# Check version compatibility for GPU instance
1101
1101
match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1102
1102
if match :
1103
1103
# Non-Trainium GPU instance but version earlier than 1.13.1
1104
- if not match [1 ].startswith ("trn" ) and \
1105
- framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS :
1104
+ if (
1105
+ not match [1 ].startswith ("trn" )
1106
+ and framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1107
+ ):
1106
1108
err_msg += (
1107
1109
f"Provided framework_version { framework_version } is not supported by"
1108
1110
f" torch_distributed for instance { instance_type } .\n "
@@ -1133,7 +1135,7 @@ def _is_gpu_instance(instance_type):
1133
1135
if isinstance (instance_type , str ):
1134
1136
match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1135
1137
if match :
1136
- if ( match [1 ].startswith ("trn" ) or match [1 ].startswith ("p" ) or match [1 ].startswith ("g" ) ):
1138
+ if match [1 ].startswith ("trn" ) or match [1 ].startswith ("p" ) or match [1 ].startswith ("g" ):
1137
1139
return True
1138
1140
if instance_type == "local_gpu" :
1139
1141
return True
0 commit comments