Skip to content

Commit 3d98b68

Browse files
LokiiiiiimufiAmazon
authored andcommitted
feature: SageMaker Training Compiler does not support p4de instances (aws#3478)
1 parent b4d8331 commit 3d98b68

File tree

3 files changed

+15
-31
lines changed

3 files changed

+15
-31
lines changed

src/sagemaker/huggingface/training_compiler/config.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
class TrainingCompilerConfig(BaseConfig):
2727
"""The SageMaker Training Compiler configuration class."""
2828

29-
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
29+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
3030
SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
3131
"ml.g4dn.8xlarge",
3232
"ml.g4dn.12xlarge",
@@ -87,10 +87,7 @@ def __init__(
8787
super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)
8888

8989
@classmethod
90-
def validate(
91-
cls,
92-
estimator,
93-
):
90+
def validate(cls, estimator):
9491
"""Checks if SageMaker Training Compiler is configured correctly.
9592
9693
Args:

src/sagemaker/tensorflow/training_compiler/config.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,10 @@
2424
class TrainingCompilerConfig(BaseConfig):
2525
"""The SageMaker Training Compiler configuration class."""
2626

27-
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"]
27+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
2828
MIN_SUPPORTED_VERSION = "2.9"
2929

30-
def __init__(
31-
self,
32-
enabled=True,
33-
debug=False,
34-
):
30+
def __init__(self, enabled=True, debug=False):
3531
"""This class initializes a ``TrainingCompilerConfig`` instance.
3632
3733
`Amazon SageMaker Training Compiler
@@ -79,10 +75,7 @@ def __init__(
7975
super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)
8076

8177
@classmethod
82-
def validate(
83-
cls,
84-
estimator,
85-
):
78+
def validate(cls, estimator):
8679
"""Checks if SageMaker Training Compiler is configured correctly.
8780
8881
Args:

src/sagemaker/training_compiler/config.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,12 @@ class TrainingCompilerConfig(object):
2323
"""The SageMaker Training Compiler configuration class."""
2424

2525
DEBUG_PATH = "/opt/ml/output/data/compiler/"
26-
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
26+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
2727

2828
HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
2929
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
3030

31-
def __init__(
32-
self,
33-
enabled=True,
34-
debug=False,
35-
):
31+
def __init__(self, enabled=True, debug=False):
3632
"""This class initializes a ``TrainingCompilerConfig`` instance.
3733
3834
`Amazon SageMaker Training Compiler
@@ -118,10 +114,7 @@ def _to_hyperparameter_dict(self):
118114
return compiler_config_hyperparameters
119115

120116
@classmethod
121-
def validate(
122-
cls,
123-
estimator,
124-
):
117+
def validate(cls, estimator):
125118
"""Checks if SageMaker Training Compiler is configured correctly.
126119
127120
Args:
@@ -138,19 +131,20 @@ def validate(
138131
warn_msg = (
139132
"Estimator instance_type is a PipelineVariable (%s), "
140133
"which has to be interpreted as one of the "
141-
"[p3, g4dn, p4d, g5] classes in execution time."
134+
"%s classes in execution time."
135+
)
136+
logger.warning(
137+
warn_msg,
138+
type(estimator.instance_type),
139+
str(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES).replace(",", ""),
142140
)
143-
logger.warning(warn_msg, type(estimator.instance_type))
144141
elif estimator.instance_type:
145142
if "local" not in estimator.instance_type:
146143
requested_instance_class = estimator.instance_type.split(".")[
147144
1
148145
] # Expecting ml.class.size
149146
if not any(
150-
[
151-
requested_instance_class.startswith(i)
152-
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
153-
]
147+
[requested_instance_class == i for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES]
154148
):
155149
error_helper_string = (
156150
"Unsupported Instance class {}."

0 commit comments

Comments
 (0)