@@ -23,16 +23,12 @@ class TrainingCompilerConfig(object):
23
23
"""The SageMaker Training Compiler configuration class."""
24
24
25
25
DEBUG_PATH = "/opt/ml/output/data/compiler/"
26
- SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3" , "g4dn" , "p4d" , "g5" ]
26
+ SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3" , "p3dn" , " g4dn" , "p4d" , "g5" ]
27
27
28
28
HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
29
29
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
30
30
31
- def __init__ (
32
- self ,
33
- enabled = True ,
34
- debug = False ,
35
- ):
31
+ def __init__ (self , enabled = True , debug = False ):
36
32
"""This class initializes a ``TrainingCompilerConfig`` instance.
37
33
38
34
`Amazon SageMaker Training Compiler
@@ -118,10 +114,7 @@ def _to_hyperparameter_dict(self):
118
114
return compiler_config_hyperparameters
119
115
120
116
@classmethod
121
- def validate (
122
- cls ,
123
- estimator ,
124
- ):
117
+ def validate (cls , estimator ):
125
118
"""Checks if SageMaker Training Compiler is configured correctly.
126
119
127
120
Args:
@@ -138,19 +131,20 @@ def validate(
138
131
warn_msg = (
139
132
"Estimator instance_type is a PipelineVariable (%s), "
140
133
"which has to be interpreted as one of the "
141
- "[p3, g4dn, p4d, g5] classes in execution time."
134
+ "%s classes in execution time."
135
+ )
136
+ logger .warning (
137
+ warn_msg ,
138
+ type (estimator .instance_type ),
139
+ str (cls .SUPPORTED_INSTANCE_CLASS_PREFIXES ).replace ("," , "" ),
142
140
)
143
- logger .warning (warn_msg , type (estimator .instance_type ))
144
141
elif estimator .instance_type :
145
142
if "local" not in estimator .instance_type :
146
143
requested_instance_class = estimator .instance_type .split ("." )[
147
144
1
148
145
] # Expecting ml.class.size
149
146
if not any (
150
- [
151
- requested_instance_class .startswith (i )
152
- for i in cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
153
- ]
147
+ [requested_instance_class == i for i in cls .SUPPORTED_INSTANCE_CLASS_PREFIXES ]
154
148
):
155
149
error_helper_string = (
156
150
"Unsupported Instance class {}."
0 commit comments