Skip to content

Commit 65086e5

Browse files
committed
feature: added warning recommending EFA instances with training compiler
1 parent e0580b5 commit 65086e5

File tree

1 file changed

+14
-0
lines changed
  • src/sagemaker/huggingface/training_compiler

1 file changed

+14
-0
lines changed

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ class TrainingCompilerConfig(BaseConfig):
2727
"""The SageMaker Training Compiler configuration class."""
2828

2929
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
30+
SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
31+
"ml.g4dn.8xlarge",
32+
"ml.g4dn.12xlarge",
33+
"ml.g5.48xlarge",
34+
"ml.p3dn.24xlarge",
35+
"ml.p4d.24xlarge",
36+
]
3037

3138
def __init__(
3239
self,
@@ -127,6 +134,13 @@ def validate(
127134
" Received pytorch_version={} which is unsupported."
128135
)
129136
raise ValueError(error_helper_string.format(estimator.pytorch_version))
137+
if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
138+
logger.warning(
139+
"Consider using instances with EFA support when "
140+
"training with PyTorch >= 1.11 and SageMaker Training Compiler "
141+
"enabled. SageMaker Training Compiler leverages EFA to provide better "
142+
"performance for distributed training."
143+
)
130144
if not pt_xla_present:
131145
if estimator.pytorch_version:
132146
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):

0 commit comments

Comments
 (0)