feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 #3307
```diff
@@ -13,6 +13,8 @@
 """Configuration for the SageMaker Training Compiler."""
 from __future__ import absolute_import
 import logging
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version
 from typing import Union

 from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
```
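The two new imports come from the `packaging` library, whose `Version`/`SpecifierSet` pair provides the version-gating idiom used by the validation logic added further down. A minimal sketch of that idiom (illustrative only, not part of this diff):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# A Version can be tested for membership in a SpecifierSet,
# which is how the new validate() checks gate on PyTorch >= 1.11.
print(Version("1.10.2") in SpecifierSet("< 1.11"))   # True: 1.10.2 satisfies "< 1.11"
print(Version("1.11.0") in SpecifierSet(">= 1.11"))  # True
print(Version("1.9") in SpecifierSet(">= 1.11"))     # False
```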
```diff
@@ -24,7 +26,7 @@
 class TrainingCompilerConfig(BaseConfig):
     """The SageMaker Training Compiler configuration class."""

-    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
+    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

     def __init__(
         self,
```

Review comment: did you mean to remove p4? is it intentional?
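Whether dropping `"p4"` changes behavior depends on how the base class matches the prefix list. Assuming it matches by string prefix on the instance class (a sketch of the presumed mechanism, not the actual `BaseConfig` code), `"p4"` already matched `"p4d"`, so the new entry is a no-op for p4d instances while `"g5"` genuinely extends support:

```python
# Illustrative sketch only: assumes BaseConfig matches instance classes by prefix.
def _instance_class_supported(instance_type: str, prefixes: list) -> bool:
    instance_class = instance_type.split(".")[1]  # "ml.p4d.24xlarge" -> "p4d"
    return any(instance_class.startswith(prefix) for prefix in prefixes)

assert _instance_class_supported("ml.p4d.24xlarge", ["p3", "g4dn", "p4"])        # old "p4" already matched p4d
assert _instance_class_supported("ml.g5.12xlarge", ["p3", "g4dn", "p4d", "g5"])  # g5 is newly supported
```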
```diff
@@ -85,7 +87,7 @@ def validate(
         """Checks if SageMaker Training Compiler is configured correctly.

         Args:
-            estimator (str): A estimator object
+            estimator (sagemaker.huggingface.HuggingFace): A estimator object
                 If SageMaker Training Compiler is enabled, it will validate whether
                 the estimator is configured to be compatible with Training Compiler.

```
```diff
@@ -105,3 +107,40 @@ def validate(
                 "transformer_version, tensorflow_version or pytorch_version, and compiler_config."
             )
             raise ValueError(error_helper_string)
+
+        if estimator.distribution:
+            pt_xla_present = "pytorch_xla" in estimator.distribution
+            pt_xla_enabled = estimator.distribution.get("pytorch_xla", {}).get("enabled", False)
+            if pt_xla_enabled:
+                if estimator.tensorflow_version:
+                    error_helper_string = (
+                        "Distribution mechanism 'pytorch_xla' is currently only supported for "
+                        "PyTorch >= 1.11 when Training Compiler is enabled. Received "
+                        "tensorflow_version={} which is unsupported."
+                    )
+                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
+                elif estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"):
+                        error_helper_string = (
+                            "Distribution mechanism 'pytorch_xla' is currently only supported for "
+                            "PyTorch >= 1.11 when Training Compiler is enabled. Received "
+                            "pytorch_version={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.pytorch_version))
+            if not pt_xla_present:
+                if estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                        error_helper_string = (
+                            "'pytorch_xla' is the only distribution mechanism currently supported "
+                            "for PyTorch >= 1.11 when Training Compiler is enabled. Received "
+                            "distribution={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.distribution))
+        elif estimator.instance_count and estimator.instance_count > 1:
+            if estimator.pytorch_version:
+                if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                    logger.warning(
+                        "Consider setting 'distribution' to 'pytorch_xla' for distributed "
+                        "training with PyTorch >= 1.11 with Training Compiler enabled. This "
+                        "will become the default configuration in the future."
+                    )
```

Review comment (on the `'pytorch_xla' is the only distribution mechanism currently supported` error): you can still run without specifying distribution, correct?

Reply: Yes, this is just blocking customers from using other built-in distribution mechanisms like mpi or pytorch ddp.
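Taken together, the new checks leave exactly one supported multi-node path for PyTorch >= 1.11 with Training Compiler enabled: `distribution={"pytorch_xla": {"enabled": True}}`. A sketch of an estimator configuration that should pass the new `validate()` (the entry point, role ARN, and library versions below are placeholders, not values taken from this PR):

```python
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

estimator = HuggingFace(
    entry_point="train.py",                               # placeholder training script
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    instance_type="ml.p4d.24xlarge",                      # in SUPPORTED_INSTANCE_CLASS_PREFIXES
    instance_count=2,
    transformers_version="4.21",  # assumed: a release paired with PyTorch 1.11 images
    pytorch_version="1.11",       # >= 1.11, so 'pytorch_xla' is the required mechanism
    compiler_config=TrainingCompilerConfig(),
    distribution={"pytorch_xla": {"enabled": True}},      # the mechanism validated above
)
```

Per the reviewer exchange, omitting `distribution` entirely still works (a multi-node job then only triggers the advisory warning), whereas requesting MPI or the built-in PyTorch DDP mechanism with PyTorch >= 1.11 now raises a `ValueError`.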