
Commit 65d4103

Lokiiiiii authored and Ubuntu committed

feature: Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 (aws#3500)

Co-authored-by: Ubuntu <[email protected]>

1 parent 77fc6a7 · commit 65d4103

File tree: 14 files changed, +1467 −31 lines


src/sagemaker/fw_utils.py (+1 −1)
@@ -493,7 +493,7 @@ def framework_name_from_image(image_uri):
     # We must support both the legacy and current image name format.
     name_pattern = re.compile(
         r"""^(?:sagemaker(?:-rl)?-)?
-        (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
+        (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
         |huggingface-tensorflow|huggingface-pytorch
         |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
         (scriptmode|training)?
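
As a quick sanity check of the widened pattern, the sketch below calls the public helper `framework_name_from_image` (the same function this hunk modifies) on a compiler-enabled PyTorch image URI. The URI and tag are illustrative placeholders, not a verified published image:

    # Hedged sketch: exercising the updated regex via the public helper.
    from sagemaker.fw_utils import framework_name_from_image

    # Illustrative placeholder URI for a pytorch-trcomp training image.
    uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-trcomp-training:1.12.0-gpu-py38"
    framework, py_version, tag, _ = framework_name_from_image(uri)
    print(framework)   # expected: "pytorch-trcomp"
    print(py_version)  # expected: "py38"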
src/sagemaker/image_uri_config/pytorch-training-compiler.json (new file, +41; path inferred from the `framework + "-training-compiler"` lookup in image_uris.py below)
@@ -0,0 +1,41 @@
+{
+    "training": {
+        "processors": [
+            "gpu"
+        ],
+        "version_aliases": {
+            "1.12": "1.12.0"
+        },
+        "versions": {
+            "1.12.0": {
+                "py_versions": [
+                    "py38"
+                ],
+                "registries": {
+                    "af-south-1": "626614931356",
+                    "ap-east-1": "871362719292",
+                    "ap-northeast-1": "763104351884",
+                    "ap-northeast-2": "763104351884",
+                    "ap-northeast-3": "364406365360",
+                    "ap-south-1": "763104351884",
+                    "ap-southeast-1": "763104351884",
+                    "ap-southeast-2": "763104351884",
+                    "ca-central-1": "763104351884",
+                    "eu-central-1": "763104351884",
+                    "eu-north-1": "763104351884",
+                    "eu-west-1": "763104351884",
+                    "eu-west-2": "763104351884",
+                    "eu-west-3": "763104351884",
+                    "eu-south-1": "692866216735",
+                    "me-south-1": "217643126080",
+                    "sa-east-1": "763104351884",
+                    "us-east-1": "763104351884",
+                    "us-east-2": "763104351884",
+                    "us-west-1": "763104351884",
+                    "us-west-2": "763104351884"
+                },
+                "repository": "pytorch-training"
+            }
+        }
+    }
+}
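
This config maps a version alias to a concrete version, and each version to per-region ECR registry accounts plus a repository name. A rough sketch of how an image URI is assembled from it (assuming a hypothetical local copy of the JSON above; in practice `image_uris.retrieve` performs this lookup):

    import json

    # Hypothetical local copy of the config file shown above.
    with open("pytorch-training-compiler.json") as f:
        cfg = json.load(f)["training"]

    region = "us-west-2"
    version = cfg["version_aliases"]["1.12"]                  # -> "1.12.0"
    account = cfg["versions"][version]["registries"][region]  # -> "763104351884"
    repo = cfg["versions"][version]["repository"]             # -> "pytorch-training"
    print(f"{account}.dkr.ecr.{region}.amazonaws.com/{repo}:<tag>")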

src/sagemaker/image_uris.py (+1 −1)
@@ -146,7 +146,7 @@ def retrieve(
         tolerate_deprecated_model,
     )

-    if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
+    if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
         final_image_scope = image_scope
         config = _config_for_framework_and_scope(
             framework + "-training-compiler", final_image_scope, accelerator_type
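
With this one-line change, passing a `training_compiler_config` routes the PyTorch lookup to the `pytorch-training-compiler` config above. A hedged usage sketch (region, version, and instance type are illustrative):

    from sagemaker import image_uris
    from sagemaker.pytorch import TrainingCompilerConfig

    # Resolves to the compiler-enabled training image for the given settings.
    image_uri = image_uris.retrieve(
        framework="pytorch",
        region="us-west-2",
        version="1.12",
        py_version="py38",
        instance_type="ml.p3.2xlarge",
        image_scope="training",
        training_compiler_config=TrainingCompilerConfig(),
    )
    print(image_uri)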

src/sagemaker/pytorch/__init__.py (+2)
@@ -16,3 +16,5 @@
 from sagemaker.pytorch.estimator import PyTorch  # noqa: F401
 from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor  # noqa: F401
 from sagemaker.pytorch.processing import PyTorchProcessor  # noqa: F401
+
+from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig  # noqa: F401
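
The re-export makes the config class importable from the framework namespace, so after this change the following should work (parameter names per config.py later in this diff):

    from sagemaker.pytorch import TrainingCompilerConfig

    config = TrainingCompilerConfig(enabled=True, debug=False)  # defaults shown explicitly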

src/sagemaker/pytorch/estimator.py (+57 −3)
@@ -28,6 +28,7 @@
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
+from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable

@@ -51,7 +52,8 @@ def __init__(
         hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         distribution: Optional[Dict] = None,
-        **kwargs
+        compiler_config: Optional[TrainingCompilerConfig] = None,
+        **kwargs,
     ):
         """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -208,6 +210,31 @@ def __init__(
                 To learn more, see `Training with parameter servers
                 <https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.

+                **To enable distributed training with
+                `SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+                for PyTorch:**
+
+                    .. code:: python
+
+                        {
+                            "pytorchxla": {
+                                "enabled": True
+                            }
+                        }
+
+                To learn more, see `SageMaker Training Compiler
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+                in the *Amazon SageMaker Developer Guide*.
+
+                .. note::
+
+                    When you use this PyTorch XLA option for the distributed training strategy,
+                    you must add the ``compiler_config`` parameter and activate SageMaker
+                    Training Compiler.
+
+            compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
+                Configures SageMaker Training Compiler to accelerate training.
+
             **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                 constructor.
@@ -250,6 +277,25 @@ def __init__(

         self.distribution = distribution or {}

+        if compiler_config is not None:
+            if not isinstance(compiler_config, TrainingCompilerConfig):
+                error_string = (
+                    f"Expected instance of type {TrainingCompilerConfig} "
+                    f"for argument compiler_config. "
+                    f"Instead got {type(compiler_config)}"
+                )
+                raise ValueError(error_string)
+            if compiler_config:
+                compiler_config.validate(self)
+        elif distribution is not None and "pytorchxla" in distribution:
+            raise ValueError(
+                "Distributed training through PyTorch XLA is currently only supported "
+                "when SageMaker Training Compiler is enabled. To learn more, "
+                "see Enable SageMaker Training Compiler at "
+                "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
+            )
+        self.compiler_config = compiler_config
+
     def _pytorch_distribution_configuration(self, distribution):
         """Returns a dict of distribution config for PyTorch training
@@ -289,6 +335,12 @@ def hyperparameters(self):
         hyperparameters.update(
             EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
         )
+        if self.compiler_config:
+            training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
+            hyperparameters.update(
+                EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
+            )
+
         return hyperparameters

     def create_model(
@@ -299,7 +351,7 @@ def create_model(
         entry_point=None,
         source_dir=None,
         dependencies=None,
-        **kwargs
+        **kwargs,
     ):
         """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
@@ -350,7 +402,7 @@ def create_model(
             sagemaker_session=self.sagemaker_session,
             vpc_config=self.get_vpc_config(vpc_config_override),
             dependencies=(dependencies or self.dependencies),
-            **kwargs
+            **kwargs,
         )

     @classmethod
@@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_name
         )
         image_uri = init_params.pop("image_uri")
         framework, py_version, tag, _ = framework_name_from_image(image_uri)
+        if framework:
+            framework = framework.split("-")[0]

         if tag is None:
             framework_version = None
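
Putting the estimator changes together, a hedged end-to-end sketch; the entry point, role, bucket, and instance settings are placeholders, and the EFA-capable instance type follows the validation logic in config.py below:

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    estimator = PyTorch(
        entry_point="train.py",  # placeholder training script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        framework_version="1.12",
        py_version="py38",
        instance_type="ml.p4d.24xlarge",  # EFA-capable, per SUPPORTED_INSTANCE_TYPES_WITH_EFA
        instance_count=2,
        distribution={"pytorchxla": {"enabled": True}},  # requires compiler_config
        compiler_config=TrainingCompilerConfig(),
    )
    # estimator.fit("s3://my-bucket/train")  # placeholder S3 channel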

src/sagemaker/pytorch/training_compiler/__init__.py

Whitespace-only changes.
src/sagemaker/pytorch/training_compiler/config.py (new file, +151; path per the imports elsewhere in this diff)

@@ -0,0 +1,151 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Configuration for the SageMaker Training Compiler."""
+from __future__ import absolute_import
+import logging
+from typing import Union
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version
+
+from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
+from sagemaker.workflow.entities import PipelineVariable
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingCompilerConfig(BaseConfig):
+    """The SageMaker Training Compiler configuration class."""
+
+    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
+    SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
+        "ml.g4dn.8xlarge",
+        "ml.g4dn.12xlarge",
+        "ml.g5.48xlarge",
+        "ml.p3dn.24xlarge",
+        "ml.p4d.24xlarge",
+    ]
+
+    def __init__(
+        self,
+        enabled: Union[bool, PipelineVariable] = True,
+        debug: Union[bool, PipelineVariable] = False,
+    ):
+        """This class initializes a ``TrainingCompilerConfig`` instance.
+
+        `Amazon SageMaker Training Compiler
+        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+        is a feature of SageMaker Training
+        and speeds up training jobs by optimizing model execution graphs.
+
+        You can compile PyTorch models
+        by passing the object of this configuration class to the ``compiler_config``
+        parameter of the :class:`~sagemaker.pytorch.PyTorch`
+        estimator.
+
+        Args:
+            enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker
+                Training Compiler. The default is ``True``.
+            debug (bool or PipelineVariable): Optional. Whether to dump detailed logs
+                for debugging. This comes with a potential performance slowdown.
+                The default is ``False``.
+
+        **Example**: The following code shows the basic usage of the
+        :class:`sagemaker.pytorch.TrainingCompilerConfig()` class
+        to run a PyTorch training job with the compiler.
+
+        .. code-block:: python
+
+            from sagemaker.pytorch import PyTorch, TrainingCompilerConfig
+
+            pytorch_estimator=PyTorch(
+                ...
+                compiler_config=TrainingCompilerConfig()
+            )
+
+        .. seealso::
+
+            For more information about how to enable SageMaker Training Compiler
+            for various training settings such as distributed training,
+            see `Enable SageMaker Training Compiler
+            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
+            in the `Amazon SageMaker Training Compiler developer guide
+            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.
+
+        """
+
+        super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)
+
+    @classmethod
+    def validate(
+        cls,
+        estimator,
+    ):
+        """Checks if SageMaker Training Compiler is configured correctly.
+
+        Args:
+            estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object.
+                If SageMaker Training Compiler is enabled, it will validate whether
+                the estimator is configured to be compatible with Training Compiler.
+
+        Raises:
+            ValueError: Raised if the requested configuration is not compatible
+                with SageMaker Training Compiler.
+        """
+
+        super(TrainingCompilerConfig, cls).validate(estimator)
+
+        if estimator.image_uri:
+            error_helper_string = (
+                "Overriding the image URI is currently not supported "
+                "for SageMaker Training Compiler. "
+                "Specify the following parameters to run the PyTorch training job "
+                "with SageMaker Training Compiler enabled: "
+                "framework_version, and compiler_config."
+            )
+            raise ValueError(error_helper_string)
+
+        if estimator.distribution:
+            pt_xla_present = "pytorchxla" in estimator.distribution
+            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
+            if pt_xla_enabled:
+                if estimator.framework_version:
+                    if Version(estimator.framework_version) in SpecifierSet("< 1.12"):
+                        error_helper_string = (
+                            "Distribution mechanism 'pytorchxla' is currently only supported for "
+                            "PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
+                            " Received framework_version={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.framework_version))
+                if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
+                    logger.warning(
+                        "Consider using instances with EFA support when "
+                        "training with PyTorch >= 1.12 and SageMaker Training Compiler "
+                        "enabled. SageMaker Training Compiler leverages EFA to provide better "
+                        "performance for distributed training."
+                    )
+            if not pt_xla_present:
+                if estimator.framework_version:
+                    if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
+                        error_helper_string = (
+                            "'pytorchxla' is the only distribution mechanism currently supported "
+                            "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
+                            " Received distribution={} which is unsupported."
+                        )
+                        raise ValueError(error_helper_string.format(estimator.distribution))
+        elif estimator.instance_count and estimator.instance_count > 1:
+            if estimator.framework_version:
+                if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
+                    logger.warning(
+                        "Consider setting 'distribution' to 'pytorchxla' for distributed "
+                        "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled."
+                    )
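
The version gates above hinge on `packaging`'s containment test between a `Version` and a `SpecifierSet`; a minimal standalone check of that idiom:

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    print(Version("1.11.0") in SpecifierSet("< 1.12"))   # True: 'pytorchxla' rejected
    print(Version("1.12.0") in SpecifierSet(">= 1.12"))  # True: 'pytorchxla' required for multi-node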

tests/conftest.py (+1)
@@ -73,6 +73,7 @@
     "neo_pytorch",
     "neo_tensorflow",
     "pytorch",
+    "pytorch_training_compiler",
     "ray_pytorch",
     "ray_tensorflow",
     "sklearn",
(unnamed new file, +2)

@@ -0,0 +1,2 @@
+transformers
+datasets
