148
148
]
149
149
150
150
151
- TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11" , "1.11.0" ]
151
+ TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS = [
152
+ "1.11" ,
153
+ "1.11.0" ,
154
+ "1.12" ,
155
+ "1.12.0" ,
156
+ "1.12.1" ,
157
+ "1.13.1" ,
158
+ ]
159
+ TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1" ]
152
160
153
161
154
162
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed" ]
@@ -1069,12 +1077,13 @@ def validate_torch_distributed_distribution(
1069
1077
if not image_uri :
1070
1078
# ignore framework_version and py_version if image_uri is set
1071
1079
# in case image_uri is not set, then both are mandatory
1072
- if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS :
1080
+ if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or \
1081
+ framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS :
1073
1082
err_msg += (
1074
1083
f"Provided framework_version { framework_version } is not supported by"
1075
1084
" torch_distributed.\n "
1076
1085
"Please specify one of the supported framework versions:"
1077
- f" { TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1086
+ f"{ TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1078
1087
)
1079
1088
if "py3" not in py_version :
1080
1089
err_msg += (
@@ -1083,13 +1092,22 @@ def validate_torch_distributed_distribution(
1083
1092
)
1084
1093
1085
1094
# Check instance compatibility
1095
+ if not _is_gpu_instance (instance_type ):
1096
+ err_msg += (
1097
+ "torch_distributed is supported only for GPU instances.\n "
1098
+ )
1099
+
1100
+ # Check version compatibility for GPU instance
1086
1101
match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1087
1102
if match :
1088
- if not match [1 ].startswith ("trn" ):
1103
+ # Non-Trainium GPU instance but version earlier than 1.13.1
1104
+ if not match [1 ].startswith ("trn" ) and \
1105
+ framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS :
1089
1106
err_msg += (
1090
- "torch_distributed is currently supported only for trainium instances.\n "
1091
- " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n " # noqa E501 # pylint: disable=c0301
1092
- "for information regarding distributed training on non-trainium instances"
1107
+ f"Provided framework_version { framework_version } is not supported by"
1108
+ f" torch_distributed for instance { instance_type } .\n "
1109
+ "Please specify one of the supported framework versions:"
1110
+ f"{ TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1093
1111
)
1094
1112
1095
1113
# Check entry point type
@@ -1103,6 +1121,25 @@ def validate_torch_distributed_distribution(
1103
1121
raise ValueError (err_msg )
1104
1122
1105
1123
1124
def _is_gpu_instance(instance_type):
    """Return whether ``instance_type`` is accepted as an accelerated instance.

    Despite the name, this helper also accepts Trainium (``trn``) instance
    families in addition to families whose name starts with ``p`` or ``g``,
    as well as the literal ``"local_gpu"`` value.

    Args:
        instance_type (str): Name of the instance_type to check against.

    Returns:
        bool: True when ``instance_type`` is ``"local_gpu"`` or has the form
            ``ml.<family>...`` (or ``ml_<family>...``) with a family starting
            with ``trn``, ``p`` or ``g``. False otherwise, including when
            ``instance_type`` is not a string.
    """
    # Non-string values are never treated as accelerated instances.
    if not isinstance(instance_type, str):
        return False

    # "local_gpu" never matches the "ml." pattern below, so checking it first
    # does not change which inputs are accepted.
    if instance_type == "local_gpu":
        return True

    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    # match[1] is the instance family, e.g. "p3" for "ml.p3.2xlarge".
    return bool(match) and match[1].startswith(("trn", "p", "g"))
1141
+
1142
+
1106
1143
def python_deprecation_warning (framework , latest_supported_version ):
1107
1144
"""Placeholder docstring"""
1108
1145
return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments