feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 #3307
Merged
Commits: 26 (showing changes from 14 commits)
336204d feature: Adding support in HuggingFace estimator for Training Compile… (Lokiiiiii)
3b8738f Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111 (Lokiiiiii)
352c5f6 Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
5a21b26 Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
883cabb Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
66a42f6 Update src/sagemaker/training_compiler/config.py (Lokiiiiii)
9098c89 fix: renaming distribution parameters pytorch_xla -> pytorchxla (Lokiiiiii)
a0aa2fc Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111 (Lokiiiiii)
1db9ed0 Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
2ea39c4 Update src/sagemaker/huggingface/estimator.py (Lokiiiiii)
de3077c Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
192437c Update src/sagemaker/huggingface/training_compiler/config.py (Lokiiiiii)
33b3174 Update src/sagemaker/huggingface/estimator.py (Lokiiiiii)
10b7c4e Update src/sagemaker/huggingface/estimator.py (Lokiiiiii)
5c589d5 Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111 (Lokiiiiii)
e9be4c1 Fix: syntax error in trcomp tests (Lokiiiiii)
8a7827d fix: linting (Lokiiiiii)
0b2a16c Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111 (Lokiiiiii)
37223ec fix: linting to break up long lines (Lokiiiiii)
8ce021c fix: fixture scoping issue in integ test (Lokiiiiii)
d696653 fix: broken unit tests for trcomp (Lokiiiiii)
6688892 fix: broken skip logic in version fixtures (Lokiiiiii)
e0580b5 fix: update test and version compatibility (Lokiiiiii)
65086e5 feature: added warning recommending EFA instances with training compiler (Lokiiiiii)
26f47d4 Update src/sagemaker/huggingface/estimator.py (Lokiiiiii)
280f625 Update src/sagemaker/training_compiler/config.py (Lokiiiiii)
@@ -13,6 +13,8 @@
 """Configuration for the SageMaker Training Compiler."""
 from __future__ import absolute_import
 import logging
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version
 from typing import Union

 from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
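
The two added imports are used later in this diff for checks such as `Version(estimator.pytorch_version) in SpecifierSet("< 1.11")`. As a rough illustration of why a version-aware comparison is needed rather than a plain string comparison, here is a stdlib-only sketch. `version_tuple` is a hypothetical stand-in for `packaging.version.Version` (it is not part of this PR) and only handles dotted numeric versions.

```python
def version_tuple(version: str) -> tuple:
    """Hypothetical helper: parse a dotted numeric version such as "1.11" into (1, 11)."""
    return tuple(int(part) for part in version.split("."))


# Plain string comparison is lexicographic, so "1.9" sorts after "1.11"
# because the character "9" is greater than the character "1".
naive_is_older = "1.9" < "1.11"

# Comparing numeric components gives the intended ordering: 9 < 11.
parsed_is_older = version_tuple("1.9") < version_tuple("1.11")

print(naive_is_older, parsed_is_older)
```

The PR itself relies on `packaging` for this, which also handles pre-releases and other version forms that this simplified helper ignores.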
@@ -24,7 +26,7 @@
 class TrainingCompilerConfig(BaseConfig):
     """The SageMaker Training Compiler configuration class."""

-    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
+    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

Review comment: Did you mean to remove p4? Is it intentional?

     def __init__(
         self,
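
This hunk does not show how `SUPPORTED_INSTANCE_CLASS_PREFIXES` is consumed. The sketch below assumes instance types follow the `ml.<class>.<size>` pattern and that the class is matched with `str.startswith`; both are assumptions, and `is_supported_instance` is a hypothetical helper, not code from this PR. Under those assumptions it shows what narrowing `"p4"` to `"p4d"` would and would not change.

```python
OLD_PREFIXES = ["p3", "g4dn", "p4"]
NEW_PREFIXES = ["p3", "g4dn", "p4d", "g5"]


def is_supported_instance(instance_type: str, prefixes: list) -> bool:
    """Hypothetical check: assumes 'ml.<class>.<size>' naming and prefix matching."""
    parts = instance_type.split(".")
    if len(parts) != 3:
        return False
    instance_class = parts[1]
    return any(instance_class.startswith(prefix) for prefix in prefixes)


# A p4d instance matches both the old "p4" prefix and the new "p4d" prefix.
print(is_supported_instance("ml.p4d.24xlarge", OLD_PREFIXES))
print(is_supported_instance("ml.p4d.24xlarge", NEW_PREFIXES))
# g5 instances are only accepted by the new list.
print(is_supported_instance("ml.g5.xlarge", OLD_PREFIXES))
print(is_supported_instance("ml.g5.xlarge", NEW_PREFIXES))
```

If matching works this way, the change only excludes instance classes that start with `p4` but not with `p4d`.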
@@ -85,7 +87,7 @@ def validate(
         """Checks if SageMaker Training Compiler is configured correctly.

         Args:
-            estimator (str): A estimator object
+            estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
                 If SageMaker Training Compiler is enabled, it will validate whether
                 the estimator is configured to be compatible with Training Compiler.

@@ -105,3 +107,39 @@ def validate(
                 "transformer_version, tensorflow_version or pytorch_version, and compiler_config."
             )
             raise ValueError(error_helper_string)
+
+        if estimator.distribution:
+            pt_xla_present = "pytorchxla" in estimator.distribution
+            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
+            if pt_xla_enabled:
+                if estimator.tensorflow_version:
+                    error_helper_string = (
+                        "Distribution mechanism 'pytorchxla' is currently only supported for "
+                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
+                        "tensorflow_version={} which is unsupported."
+                    )
+                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
+                elif estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"):
+                        error_helper_string = (
+                            "Distribution mechanism 'pytorchxla' is currently only supported for "
+                            "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
+                            "pytorch_version={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.pytorch_version))
+            if not pt_xla_present:
+                if estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                        error_helper_string = (
+                            "'pytorchxla' is the only distribution mechanism currently supported "
+                            "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
+                            "distribution={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.distribution))
+        elif estimator.instance_count and estimator.instance_count > 1:
+            if estimator.pytorch_version:
+                if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                    logger.warning(
+                        "Consider setting 'distribution' to 'pytorchxla' for distributed "
+                        "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
+                    )
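
To make the branches above easier to review, here is a minimal standalone sketch of the same decision logic. It is not the PR's code: `check_distribution` and `_version_tuple` are hypothetical names, the estimator attributes are passed as plain arguments, and a simple tuple comparison stands in for `packaging` so the sketch has no third-party dependencies. The error messages are shortened.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

MIN_XLA_PYTORCH = (1, 11)


def _version_tuple(version: str) -> tuple:
    # Stand-in for packaging.version.Version; dotted numeric versions only.
    return tuple(int(part) for part in version.split("."))


def check_distribution(
    distribution: Optional[dict] = None,
    pytorch_version: Optional[str] = None,
    tensorflow_version: Optional[str] = None,
    instance_count: int = 1,
) -> None:
    """Hypothetical mirror of the distribution checks added to validate()."""
    pt_supports_xla = (
        pytorch_version is not None
        and _version_tuple(pytorch_version) >= MIN_XLA_PYTORCH
    )
    if distribution:
        xla_present = "pytorchxla" in distribution
        xla_enabled = distribution.get("pytorchxla", {}).get("enabled", False)
        if xla_enabled:
            # 'pytorchxla' is rejected for TensorFlow and for PyTorch < 1.11.
            if tensorflow_version:
                raise ValueError(
                    "'pytorchxla' requires PyTorch >= 1.11; received "
                    "tensorflow_version={}".format(tensorflow_version)
                )
            if pytorch_version and not pt_supports_xla:
                raise ValueError(
                    "'pytorchxla' requires PyTorch >= 1.11; received "
                    "pytorch_version={}".format(pytorch_version)
                )
        if not xla_present and pt_supports_xla:
            # Any other distribution mechanism is rejected for PyTorch >= 1.11.
            raise ValueError(
                "'pytorchxla' is the only supported distribution for "
                "PyTorch >= 1.11; received distribution={}".format(distribution)
            )
    elif instance_count and instance_count > 1 and pt_supports_xla:
        # Multi-instance job without a distribution setting: warn only.
        logger.warning("Consider setting 'distribution' to 'pytorchxla'.")


# Accepted: pytorchxla enabled with PyTorch 1.11.
check_distribution({"pytorchxla": {"enabled": True}}, pytorch_version="1.11")

# Rejected: pytorchxla enabled with PyTorch 1.10.
try:
    check_distribution({"pytorchxla": {"enabled": True}}, pytorch_version="1.10")
except ValueError as err:
    print("rejected:", err)
```

Note that, as in the diff, a `distribution` that contains `pytorchxla` with `enabled` set to `False` passes without an error or a warning.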