-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 #3307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
336204d
3b8738f
352c5f6
5a21b26
883cabb
66a42f6
9098c89
a0aa2fc
1db9ed0
2ea39c4
de3077c
192437c
33b3174
10b7c4e
5c589d5
e9be4c1
8a7827d
0b2a16c
37223ec
8ce021c
d696653
6688892
e0580b5
65086e5
26f47d4
280f625
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
from __future__ import absolute_import | ||
import logging | ||
from typing import Union | ||
from packaging.specifiers import SpecifierSet | ||
from packaging.version import Version | ||
|
||
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig | ||
from sagemaker.workflow.entities import PipelineVariable | ||
|
@@ -24,7 +26,14 @@ | |
class TrainingCompilerConfig(BaseConfig): | ||
"""The SageMaker Training Compiler configuration class.""" | ||
|
||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"] | ||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you mean to remove p4? Is it intentional? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ | ||
"ml.g4dn.8xlarge", | ||
Lokiiiiii marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"ml.g4dn.12xlarge", | ||
"ml.g5.48xlarge", | ||
"ml.p3dn.24xlarge", | ||
"ml.p4d.24xlarge", | ||
] | ||
|
||
def __init__( | ||
self, | ||
|
@@ -85,7 +94,7 @@ def validate( | |
"""Checks if SageMaker Training Compiler is configured correctly. | ||
|
||
Args: | ||
estimator (str): A estimator object | ||
estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object. | ||
If SageMaker Training Compiler is enabled, it will validate whether | ||
the estimator is configured to be compatible with Training Compiler. | ||
|
||
|
@@ -105,3 +114,46 @@ def validate( | |
"transformer_version, tensorflow_version or pytorch_version, and compiler_config." | ||
) | ||
raise ValueError(error_helper_string) | ||
|
||
if estimator.distribution: | ||
pt_xla_present = "pytorchxla" in estimator.distribution | ||
pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False) | ||
if pt_xla_enabled: | ||
if estimator.tensorflow_version: | ||
error_helper_string = ( | ||
"Distribution mechanism 'pytorchxla' is currently only supported for " | ||
"PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received " | ||
"tensorflow_version={} which is unsupported." | ||
) | ||
raise ValueError(error_helper_string.format(estimator.tensorflow_version)) | ||
if estimator.pytorch_version: | ||
if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"): | ||
error_helper_string = ( | ||
"Distribution mechanism 'pytorchxla' is currently only supported for " | ||
"PyTorch >= 1.11 when SageMaker Training Compiler is enabled." | ||
" Received pytorch_version={} which is unsupported." | ||
) | ||
raise ValueError(error_helper_string.format(estimator.pytorch_version)) | ||
if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA: | ||
logger.warning( | ||
"Consider using instances with EFA support when " | ||
"training with PyTorch >= 1.11 and SageMaker Training Compiler " | ||
"enabled. SageMaker Training Compiler leverages EFA to provide better " | ||
"performance for distributed training." | ||
) | ||
if not pt_xla_present: | ||
if estimator.pytorch_version: | ||
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"): | ||
error_helper_string = ( | ||
"'pytorchxla' is the only distribution mechanism currently supported " | ||
"for PyTorch >= 1.11 when SageMaker Training Compiler is enabled." | ||
" Received distribution={} which is unsupported." | ||
) | ||
raise ValueError(error_helper_string.format(estimator.distribution)) | ||
elif estimator.instance_count and estimator.instance_count > 1: | ||
if estimator.pytorch_version: | ||
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"): | ||
logger.warning( | ||
"Consider setting 'distribution' to 'pytorchxla' for distributed " | ||
"training with PyTorch >= 1.11 and SageMaker Training Compiler enabled." | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ class TrainingCompilerConfig(object): | |
"""The SageMaker Training Compiler configuration class.""" | ||
|
||
DEBUG_PATH = "/opt/ml/output/data/compiler/" | ||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"] | ||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"] | ||
pinaraws marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled" | ||
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode" | ||
|
@@ -123,7 +123,7 @@ def validate( | |
"""Checks if SageMaker Training Compiler is configured correctly. | ||
|
||
Args: | ||
estimator (str): A estimator object | ||
estimator (:class:`sagemaker.estimator.Estimator`): An estimator object. | ||
When SageMaker Training Compiler is enabled, it validates if | ||
the estimator is configured to be compatible with Training Compiler. | ||
|
||
|
@@ -132,31 +132,34 @@ def validate( | |
ValueError: Raised if the requested configuration is not compatible | ||
with SageMaker Training Compiler. | ||
""" | ||
|
||
if "local" not in estimator.instance_type: | ||
requested_instance_class = estimator.instance_type.split(".")[ | ||
1 | ||
] # Expecting ml.class.size | ||
if not any( | ||
[ | ||
requested_instance_class.startswith(i) | ||
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
] | ||
): | ||
if estimator.instance_type: | ||
if "local" not in estimator.instance_type: | ||
requested_instance_class = estimator.instance_type.split(".")[ | ||
1 | ||
] # Expecting ml.class.size | ||
if not any( | ||
[ | ||
requested_instance_class.startswith(i) | ||
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
] | ||
): | ||
error_helper_string = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can we add There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we just want to generalize to EFA-enabled instances, or what is special about these 2 instance types? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a good idea; we would eventually like to recommend all EFA instances that improve performance. We should start by recommending the EFA instances we have benchmarked thoroughly; we can add other instances to the list once we benchmark them internally. |
||
"Unsupported Instance class {}." | ||
"SageMaker Training Compiler only supports {}" | ||
) | ||
error_helper_string = error_helper_string.format( | ||
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
) | ||
raise ValueError(error_helper_string) | ||
elif estimator.instance_type == "local": | ||
error_helper_string = ( | ||
"Unsupported Instance class {}. SageMaker Training Compiler only supports {}" | ||
"The local mode is not supported by SageMaker Training Compiler." | ||
Lokiiiiii marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"It only supports the following GPU instances: {}" | ||
) | ||
error_helper_string = error_helper_string.format( | ||
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
) | ||
raise ValueError(error_helper_string) | ||
elif estimator.instance_type == "local": | ||
error_helper_string = ( | ||
"The local mode is not supported by SageMaker Training Compiler." | ||
"It only supports the following GPU instances: {}" | ||
) | ||
error_helper_string = error_helper_string.format(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES) | ||
raise ValueError(error_helper_string) | ||
|
||
if estimator.distribution and "smdistributed" in estimator.distribution: | ||
raise ValueError( | ||
|
@@ -180,3 +183,12 @@ def validate( | |
estimator.debugger_hook_config, estimator.disable_profiler | ||
) | ||
logger.warning(helper_string) | ||
|
||
if estimator.instance_groups: | ||
raise ValueError( | ||
"SageMaker Training Compiler currently only supports homogeneous clusters of " | ||
"the following GPU instance families: {}. Please use the 'instance_type' " | ||
"and 'instance_count' parameters instead of 'instance_groups'".format( | ||
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES | ||
) | ||
) |
Uh oh!
There was an error while loading. Please reload this page.