148
148
]
149
149
150
150
151
- TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11" , "1.11.0" ]
152
-
151
+ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1" ]
153
152
154
153
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed" ]
155
-
154
+ TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
155
+ "1.11" ,
156
+ "1.11.0" ,
157
+ "1.12" ,
158
+ "1.12.0" ,
159
+ "1.12.1" ,
160
+ "1.13.1" ,
161
+ ]
156
162
157
163
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel" , "modelparallel" ]
158
164
@@ -1055,9 +1061,8 @@ def validate_torch_distributed_distribution(
1055
1061
Raises:
1056
1062
ValueError: if
1057
1063
`py_version` is not python3 or
1058
- `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1064
+ `framework_version` is not compatible with instance types
1059
1065
"""
1060
-
1061
1066
torch_distributed_enabled = False
1062
1067
if "torch_distributed" in distribution :
1063
1068
torch_distributed_enabled = distribution .get ("torch_distributed" ).get ("enabled" , False )
@@ -1066,30 +1071,36 @@ def validate_torch_distributed_distribution(
1066
1071
return
1067
1072
1068
1073
err_msg = ""
1074
+
1069
1075
if not image_uri :
1070
1076
# ignore framework_version and py_version if image_uri is set
1071
1077
# in case image_uri is not set, then both are mandatory
1072
- if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS :
1073
- err_msg += (
1074
- f"Provided framework_version { framework_version } is not supported by"
1075
- " torch_distributed.\n "
1076
- "Please specify one of the supported framework versions:"
1077
- f" { TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1078
- )
1079
1078
if "py3" not in py_version :
1080
1079
err_msg += (
1081
1080
f"Provided py_version { py_version } is not supported by torch_distributed.\n "
1082
- "Please specify py_version>=py3"
1081
+ "Please specify py_version>=py3\n "
1083
1082
)
1084
1083
1085
- # Check instance compatibility
1086
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1087
- if match :
1088
- if not match [1 ].startswith ("trn" ):
1084
+ # Check instance and framework_version compatibility
1085
+ if _is_gpu_instance (instance_type ):
1086
+ if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS :
1087
+ err_msg += (
1088
+ f"Provided framework_version { framework_version } is not supported by"
1089
+ f" torch_distributed for instance { instance_type } .\n "
1090
+ "Please specify one of the supported framework versions:"
1091
+ f"{ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1092
+ )
1093
+ elif _is_trainium_instance (instance_type ):
1094
+ if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS :
1095
+ err_msg += (
1096
+ f"Provided framework_version { framework_version } is not supported by"
1097
+ f" torch_distributed for instance { instance_type } .\n "
1098
+ "Please specify one of the supported framework versions:"
1099
+ f"{ TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS } \n "
1100
+ )
1101
+ else :
1089
1102
err_msg += (
1090
- "torch_distributed is currently supported only for trainium instances.\n "
1091
- " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n " # noqa E501 # pylint: disable=c0301
1092
- "for information regarding distributed training on non-trainium instances"
1103
+ "Currently torch_distributed is supported only for GPU and Trainium instances.\n "
1093
1104
)
1094
1105
1095
1106
# Check entry point type
@@ -1103,6 +1114,41 @@ def validate_torch_distributed_distribution(
1103
1114
raise ValueError (err_msg )
1104
1115
1105
1116
1117
+ def _is_gpu_instance (instance_type ):
1118
+ """Returns bool indicating whether instance_type supports GPU
1119
+
1120
+ Args:
1121
+ instance_type (str): Name of the instance_type to check against.
1122
+
1123
+ Returns:
1124
+ bool: Whether or not the instance_type supports GPU
1125
+ """
1126
+ if isinstance (instance_type , str ):
1127
+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1128
+ if match :
1129
+ if match [1 ].startswith ("p" ) or match [1 ].startswith ("g" ):
1130
+ return True
1131
+ if instance_type == "local_gpu" :
1132
+ return True
1133
+ return False
1134
+
1135
+
1136
+ def _is_trainium_instance (instance_type ):
1137
+ """Returns bool indicating whether instance_type is a Trainium instance
1138
+
1139
+ Args:
1140
+ instance_type (str): Name of the instance_type to check against.
1141
+
1142
+ Returns:
1143
+ bool: Whether or not the instance_type is a Trainium instance
1144
+ """
1145
+ if isinstance (instance_type , str ):
1146
+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1147
+ if match and match [1 ].startswith ("trn" ):
1148
+ return True
1149
+ return False
1150
+
1151
+
1106
1152
def python_deprecation_warning (framework , latest_supported_version ):
1107
1153
"""Placeholder docstring"""
1108
1154
return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments