Commit 01fd66a

Authored by Lokiiiiii and mchoi8739
feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 (#3307)
* feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* fix: renaming distribution parameters pytorch_xla -> pytorchxla
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/estimator.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/estimator.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/huggingface/estimator.py (Co-authored-by: Miyoung <[email protected]>)
* Fix: syntax error in trcomp tests
* fix: linting
* fix: linting to break up long lines
* fix: fixture scoping issue in integ test
* fix: broken unit tests for trcomp
* fix: broken skip logic in version fixtures
* fix: update test and version compatibility
* feature: added warning recommending EFA instances with training compiler
* Update src/sagemaker/huggingface/estimator.py (Co-authored-by: Miyoung <[email protected]>)
* Update src/sagemaker/training_compiler/config.py (Co-authored-by: Miyoung <[email protected]>)

Co-authored-by: Miyoung <[email protected]>
1 parent 29fc70e commit 01fd66a

File tree: 11 files changed (+411, -70 lines)

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):  # pylint: disable=too-man
         instance.
     """

+    LAUNCH_PT_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
     LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
     LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
     LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
@@ -3316,6 +3317,10 @@ def _distribution_configuration(self, distribution):
                 "instance_groups"
             ]

+        if "pytorchxla" in distribution:
+            pt_xla_enabled = distribution.get("pytorchxla").get("enabled", False)
+            distribution_config[self.LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled
+
         if "parameter_server" in distribution:
             ps_enabled = distribution.get("parameter_server").get("enabled", False)
             distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
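
Aside (not part of the commit): a minimal sketch of what the new "pytorchxla" distribution key is meant to do, namely toggle the "sagemaker_pytorch_xla_multi_worker_enabled" hyperparameter. The standalone helper name below is hypothetical; in the SDK this logic lives inside EstimatorBase._distribution_configuration.

# Hypothetical standalone mirror of the mapping added above.
LAUNCH_PT_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"

def distribution_to_hyperparameters(distribution: dict) -> dict:
    """Translate the user-facing distribution dict into launcher hyperparameters."""
    config = {}
    if "pytorchxla" in distribution:
        config[LAUNCH_PT_XLA_ENV_NAME] = distribution.get("pytorchxla").get("enabled", False)
    return config

print(distribution_to_hyperparameters({"pytorchxla": {"enabled": True}}))
# -> {'sagemaker_pytorch_xla_multi_worker_enabled': True}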

src/sagemaker/huggingface/estimator.py

Lines changed: 29 additions & 0 deletions
@@ -141,6 +141,28 @@ def __init__(
                         }
                     }
                 }
+
+            To enable distributed training with
+            `SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+            for Hugging Face Transformers with PyTorch:
+
+            .. code:: python
+
+                {
+                    "pytorchxla": {
+                        "enabled": True
+                    }
+                }
+
+            To learn more, see `SageMaker Training Compiler
+            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+            in the *Amazon SageMaker Developer Guide*.
+
+            .. note::
+
+                When you use this PyTorch XLA option for distributed training strategy,
+                you must add the ``compiler_config`` parameter and activate SageMaker
+                Training Compiler.
         compiler_config (:class:`~sagemaker.huggingface.TrainingCompilerConfig`):
             Configures SageMaker Training Compiler to accelerate training.

@@ -204,6 +226,13 @@ def __init__(
             raise ValueError(error_string)
         if compiler_config:
             compiler_config.validate(self)
+        elif distribution is not None and "pytorchxla" in distribution:
+            raise ValueError(
+                "Distributed training through PyTorch XLA is currently only supported "
+                "when SageMaker Training Compiler is enabled. To learn more, "
+                "see Enable SageMaker Training Compiler at "
+                "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
+            )
         self.compiler_config = compiler_config

     def _validate_args(self, image_uri):
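
Taken together with the compiler_config requirement enforced above, a hedged usage sketch of the new API surface looks like the following. The entry point, role ARN, S3 channel, and version values are placeholder assumptions, not values taken from this commit.

from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

# Hedged usage sketch; script name, role, S3 path, and versions are placeholders.
estimator = HuggingFace(
    entry_point="train.py",                               # hypothetical training script
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder IAM role
    instance_type="ml.p4d.24xlarge",                      # EFA-capable type, per the warning added in this commit
    instance_count=2,
    transformers_version="4.21",                          # assumed release pairing
    pytorch_version="1.11",
    py_version="py38",
    compiler_config=TrainingCompilerConfig(),             # required when pytorchxla is requested
    distribution={"pytorchxla": {"enabled": True}},
)
# estimator.fit({"train": "s3://example-bucket/train"})   # placeholder input channel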

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 54 additions & 2 deletions
@@ -14,6 +14,8 @@
 from __future__ import absolute_import
 import logging
 from typing import Union
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version

 from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
 from sagemaker.workflow.entities import PipelineVariable
@@ -24,7 +26,14 @@
 class TrainingCompilerConfig(BaseConfig):
     """The SageMaker Training Compiler configuration class."""

-    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
+    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
+    SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
+        "ml.g4dn.8xlarge",
+        "ml.g4dn.12xlarge",
+        "ml.g5.48xlarge",
+        "ml.p3dn.24xlarge",
+        "ml.p4d.24xlarge",
+    ]

     def __init__(
         self,
@@ -85,7 +94,7 @@ def validate(
         """Checks if SageMaker Training Compiler is configured correctly.

         Args:
-            estimator (str): A estimator object
+            estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
                 If SageMaker Training Compiler is enabled, it will validate whether
                 the estimator is configured to be compatible with Training Compiler.

@@ -105,3 +114,46 @@ def validate(
                 "transformer_version, tensorflow_version or pytorch_version, and compiler_config."
             )
             raise ValueError(error_helper_string)
+
+        if estimator.distribution:
+            pt_xla_present = "pytorchxla" in estimator.distribution
+            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
+            if pt_xla_enabled:
+                if estimator.tensorflow_version:
+                    error_helper_string = (
+                        "Distribution mechanism 'pytorchxla' is currently only supported for "
+                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
+                        "tensorflow_version={} which is unsupported."
+                    )
+                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
+                if estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"):
+                        error_helper_string = (
+                            "Distribution mechanism 'pytorchxla' is currently only supported for "
+                            "PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
+                            " Received pytorch_version={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.pytorch_version))
+                if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
+                    logger.warning(
+                        "Consider using instances with EFA support when "
+                        "training with PyTorch >= 1.11 and SageMaker Training Compiler "
+                        "enabled. SageMaker Training Compiler leverages EFA to provide better "
+                        "performance for distributed training."
+                    )
+            if not pt_xla_present:
+                if estimator.pytorch_version:
+                    if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                        error_helper_string = (
+                            "'pytorchxla' is the only distribution mechanism currently supported "
+                            "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
+                            " Received distribution={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.distribution))
+        elif estimator.instance_count and estimator.instance_count > 1:
+            if estimator.pytorch_version:
+                if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
+                    logger.warning(
+                        "Consider setting 'distribution' to 'pytorchxla' for distributed "
+                        "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
+                    )
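
For reference, the version gate above relies on the packaging library's ability to test a Version against a SpecifierSet with the `in` operator; a tiny self-contained illustration:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The same comparison style used in validate() above.
print(Version("1.10.2") in SpecifierSet("< 1.11"))   # True  -> 'pytorchxla' rejected
print(Version("1.11.0") in SpecifierSet(">= 1.11"))  # True  -> 'pytorchxla' expected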

src/sagemaker/image_uri_config/huggingface-training-compiler.json

Lines changed: 36 additions & 1 deletion
@@ -3,7 +3,8 @@
         "processors": ["gpu"],
         "version_aliases": {
             "4.11": "4.11.0",
-            "4.17": "4.17.0"
+            "4.17": "4.17.0",
+            "4.21": "4.21.1"
         },
         "versions": {
             "4.11.0": {
@@ -97,6 +98,40 @@
                     "repository": "huggingface-tensorflow-trcomp-training",
                     "container_version": {"gpu":"cu112-ubuntu20.04"}
                 }
+            },
+            "4.21.1": {
+                "version_aliases": {
+                    "pytorch1.11": "pytorch1.11.0"
+                },
+                "pytorch1.11.0": {
+                    "py_versions": ["py38"],
+                    "registries": {
+                        "af-south-1": "626614931356",
+                        "ap-east-1": "871362719292",
+                        "ap-northeast-1": "763104351884",
+                        "ap-northeast-2": "763104351884",
+                        "ap-northeast-3": "364406365360",
+                        "ap-south-1": "763104351884",
+                        "ap-southeast-1": "763104351884",
+                        "ap-southeast-2": "763104351884",
+                        "ap-southeast-3": "907027046896",
+                        "ca-central-1": "763104351884",
+                        "eu-central-1": "763104351884",
+                        "eu-north-1": "763104351884",
+                        "eu-south-1": "692866216735",
+                        "eu-west-1": "763104351884",
+                        "eu-west-2": "763104351884",
+                        "eu-west-3": "763104351884",
+                        "me-south-1": "217643126080",
+                        "sa-east-1": "763104351884",
+                        "us-east-1": "763104351884",
+                        "us-east-2": "763104351884",
+                        "us-west-1": "763104351884",
+                        "us-west-2": "763104351884"
+                    },
+                    "repository": "huggingface-pytorch-trcomp-training",
+                    "container_version": {"gpu":"cu113-ubuntu20.04"}
+                }
             }
         }
     }
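
Assuming these registry entries are what image_uris uses for lookup, the new container could presumably be resolved as follows; the argument combination and region/instance values are illustrative assumptions, not output captured from this change.

from sagemaker import image_uris

# Assumed lookup for the PyTorch 1.11 Training Compiler container (values illustrative).
uri = image_uris.retrieve(
    framework="huggingface-training-compiler",
    region="us-west-2",
    version="4.21.1",
    base_framework_version="pytorch1.11.0",
    py_version="py38",
    instance_type="ml.p4d.24xlarge",
    image_scope="training",
)
print(uri)  # expected to point at the huggingface-pytorch-trcomp-training repository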

src/sagemaker/training_compiler/config.py

Lines changed: 34 additions & 22 deletions
@@ -21,7 +21,7 @@ class TrainingCompilerConfig(object):
     """The SageMaker Training Compiler configuration class."""

     DEBUG_PATH = "/opt/ml/output/data/compiler/"
-    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
+    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

     HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
     HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
@@ -123,7 +123,7 @@ def validate(
         """Checks if SageMaker Training Compiler is configured correctly.

         Args:
-            estimator (str): A estimator object
+            estimator (:class:`sagemaker.estimator.Estimator`): An estimator object.
                 When SageMaker Training Compiler is enabled, it validates if
                 the estimator is configured to be compatible with Training Compiler.

@@ -132,31 +132,34 @@ def validate(
             ValueError: Raised if the requested configuration is not compatible
                 with SageMaker Training Compiler.
         """
-
-        if "local" not in estimator.instance_type:
-            requested_instance_class = estimator.instance_type.split(".")[
-                1
-            ]  # Expecting ml.class.size
-            if not any(
-                [
-                    requested_instance_class.startswith(i)
-                    for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
-                ]
-            ):
+        if estimator.instance_type:
+            if "local" not in estimator.instance_type:
+                requested_instance_class = estimator.instance_type.split(".")[
+                    1
+                ]  # Expecting ml.class.size
+                if not any(
+                    [
+                        requested_instance_class.startswith(i)
+                        for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
+                    ]
+                ):
+                    error_helper_string = (
+                        "Unsupported Instance class {}."
+                        "SageMaker Training Compiler only supports {}"
+                    )
+                    error_helper_string = error_helper_string.format(
+                        requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
+                    )
+                    raise ValueError(error_helper_string)
+            elif estimator.instance_type == "local":
                 error_helper_string = (
-                    "Unsupported Instance class {}. SageMaker Training Compiler only supports {}"
+                    "SageMaker Training Compiler doesn't support local mode."
+                    "It only supports the following GPU instances: {}"
                 )
                 error_helper_string = error_helper_string.format(
-                    requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
+                    cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
                 )
                 raise ValueError(error_helper_string)
-        elif estimator.instance_type == "local":
-            error_helper_string = (
-                "The local mode is not supported by SageMaker Training Compiler."
-                "It only supports the following GPU instances: {}"
-            )
-            error_helper_string = error_helper_string.format(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES)
-            raise ValueError(error_helper_string)

         if estimator.distribution and "smdistributed" in estimator.distribution:
             raise ValueError(
@@ -180,3 +183,12 @@ def validate(
                 estimator.debugger_hook_config, estimator.disable_profiler
             )
             logger.warning(helper_string)
+
+        if estimator.instance_groups:
+            raise ValueError(
+                "SageMaker Training Compiler currently only supports homogeneous clusters of "
+                "the following GPU instance families: {}. Please use the 'instance_type' "
+                "and 'instance_count' parameters instead of 'instance_groups'".format(
+                    cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
+                )
+            )
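
A self-contained sketch of the instance-class check that validate() performs above; the instance types used here are examples only.

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

def instance_class_supported(instance_type: str) -> bool:
    # "ml.p4d.24xlarge" -> "p4d"; local mode is handled by a separate branch in validate().
    instance_class = instance_type.split(".")[1]
    return any(instance_class.startswith(prefix) for prefix in SUPPORTED_INSTANCE_CLASS_PREFIXES)

print(instance_class_supported("ml.p4d.24xlarge"))  # True
print(instance_class_supported("ml.g5.48xlarge"))   # True
print(instance_class_supported("ml.c5.2xlarge"))    # False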

tests/conftest.py

Lines changed: 24 additions & 5 deletions
@@ -252,27 +252,46 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version

 @pytest.fixture(scope="module")
 def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_version):
-    return _huggingface_base_fm_version(
+    versions = _huggingface_base_fm_version(
         huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
-    )[0]
+    )
+    if not versions:
+        pytest.skip(
+            f"Hugging Face Training Compiler version {huggingface_training_compiler_version} does "
+            f"not have a PyTorch release."
+        )
+    return versions[0]


 @pytest.fixture(scope="module")
 def huggingface_training_compiler_tensorflow_version(huggingface_training_compiler_version):
-    return _huggingface_base_fm_version(
+    versions = _huggingface_base_fm_version(
         huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
-    )[0]
+    )
+    if not versions:
+        pytest.skip(
+            f"Hugging Face Training Compiler version {huggingface_training_compiler_version} "
+            f"does not have a TensorFlow release."
+        )
+    return versions[0]


 @pytest.fixture(scope="module")
-def huggingface_training_compiler_py_version(huggingface_training_compiler_tensorflow_version):
+def huggingface_training_compiler_tensorflow_py_version(
+    huggingface_training_compiler_tensorflow_version,
+):
     return (
         "py37"
         if Version(huggingface_training_compiler_tensorflow_version) < Version("2.6")
         else "py38"
     )


+@pytest.fixture(scope="module")
+def huggingface_training_compiler_pytorch_py_version(huggingface_training_compiler_pytorch_version):
+    return "py38"
+
+
 @pytest.fixture(scope="module")
 def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version):
     return (
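
The skip pattern introduced in these fixtures, reduced to a generic sketch; the fixture and parameter names here are hypothetical, not the SDK's own.

import pytest

@pytest.fixture(scope="module")
def framework_base_version(available_versions):  # hypothetical upstream fixture
    # Skip dependent tests instead of raising IndexError when no release exists
    # for this framework combination, mirroring the change above.
    if not available_versions:
        pytest.skip("no release available for this framework combination")
    return available_versions[0]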
