Skip to content

Commit d9463d3

Browse files
Lokiiiiii and mchoi8739 authored
feature: Adding Training Compiler support for TensorFlow estimator starting TF 2.9 (#3156)
Co-authored-by: Miyoung <[email protected]>
1 parent 325214f commit d9463d3

File tree

13 files changed

+879
-74
lines changed

13 files changed

+879
-74
lines changed

doc/frameworks/tensorflow/sagemaker.tensorflow.rst

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ TensorFlow Estimator
1010
:undoc-members:
1111
:show-inheritance:
1212

13+
TensorFlow Training Compiler Configuration
14+
------------------------------------------
15+
16+
.. autoclass:: sagemaker.tensorflow.TrainingCompilerConfig
17+
:members:
18+
:undoc-members:
19+
:show-inheritance:
20+
1321
TensorFlow Serving Model
1422
------------------------
1523

src/sagemaker/huggingface/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
1818
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401
1919

20-
from sagemaker.training_compiler.config import TrainingCompilerConfig # noqa: F401
20+
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401

src/sagemaker/huggingface/estimator.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.huggingface.model import HuggingFaceModel
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

29-
from sagemaker.training_compiler.config import TrainingCompilerConfig
29+
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
3030

3131
logger = logging.getLogger("sagemaker")
3232

@@ -190,6 +190,8 @@ def __init__(
190190
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
191191
)
192192

193+
self.distribution = distribution or {}
194+
193195
if compiler_config is not None:
194196
if not isinstance(compiler_config, TrainingCompilerConfig):
195197
error_string = (
@@ -199,13 +201,7 @@ def __init__(
199201
)
200202
raise ValueError(error_string)
201203
if compiler_config:
202-
compiler_config.validate(
203-
image_uri=image_uri,
204-
instance_type=instance_type,
205-
distribution=distribution,
206-
)
207-
208-
self.distribution = distribution or {}
204+
compiler_config.validate(self)
209205
self.compiler_config = compiler_config
210206

211207
def _validate_args(self, image_uri):

src/sagemaker/huggingface/training_compiler/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Configuration for the SageMaker Training Compiler."""
14+
from __future__ import absolute_import
15+
import logging
16+
17+
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class TrainingCompilerConfig(BaseConfig):
    """The SageMaker Training Compiler configuration class."""

    # Instance-family prefixes declared for this framework.
    # NOTE(review): presumably consumed by the base-class validation -- confirm
    # against ``sagemaker.training_compiler.config``.
    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]

    def __init__(
        self,
        enabled=True,
        debug=False,
    ):
        """This class initializes a ``TrainingCompilerConfig`` instance.

        `Amazon SageMaker Training Compiler
        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
        is a feature of SageMaker Training
        and speeds up training jobs by optimizing model execution graphs.

        You can compile Hugging Face models
        by passing the object of this configuration class to the ``compiler_config``
        parameter of the :class:`~sagemaker.huggingface.HuggingFace`
        estimator.

        Args:
            enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
                The default is ``True``.
            debug (bool): Optional. Whether to dump detailed logs for debugging.
                This comes with a potential performance slowdown.
                The default is ``False``.

        **Example**: The following code shows the basic usage of the
        :class:`sagemaker.huggingface.TrainingCompilerConfig` class
        to run a HuggingFace training job with the compiler.

        .. code-block:: python

            from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

            huggingface_estimator=HuggingFace(
                ...
                compiler_config=TrainingCompilerConfig()
            )

        .. seealso::

            For more information about how to enable SageMaker Training Compiler
            for various training settings such as using TensorFlow-based models,
            PyTorch-based models, and distributed training,
            see `Enable SageMaker Training Compiler
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
            in the `Amazon SageMaker Training Compiler developer guide
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

        """

        super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

    @classmethod
    def validate(
        cls,
        estimator,
    ):
        """Checks if SageMaker Training Compiler is configured correctly.

        Runs the base-class validation first, then rejects estimators that
        override the training image URI.

        Args:
            estimator (:class:`~sagemaker.huggingface.HuggingFace`): An estimator
                object. If SageMaker Training Compiler is enabled, it will validate
                whether the estimator is configured to be compatible with
                Training Compiler.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                with SageMaker Training Compiler.
        """

        super(TrainingCompilerConfig, cls).validate(estimator)

        if estimator.image_uri:
            # A user-supplied image bypasses the framework-version based image
            # lookup, so it is refused outright.
            error_helper_string = (
                "Overriding the image URI is currently not supported "
                "for SageMaker Training Compiler. "
                "Specify the following parameters to run the Hugging Face training job "
                "with SageMaker Training Compiler enabled: "
                "transformers_version, tensorflow_version or pytorch_version, and compiler_config."
            )
            raise ValueError(error_helper_string)

src/sagemaker/image_uris.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,18 @@ def retrieve(
134134
tolerate_vulnerable_model,
135135
tolerate_deprecated_model,
136136
)
137-
if training_compiler_config is None:
137+
138+
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
139+
config = _config_for_framework_and_scope(
140+
framework + "-training-compiler", image_scope, accelerator_type
141+
)
142+
else:
138143
_framework = framework
139144
if framework == HUGGING_FACE_FRAMEWORK:
140145
inference_tool = _get_inference_tool(inference_tool, instance_type)
141146
if inference_tool == "neuron":
142147
_framework = f"{framework}-{inference_tool}"
143148
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
144-
elif framework == HUGGING_FACE_FRAMEWORK:
145-
config = _config_for_framework_and_scope(
146-
framework + "-training-compiler", image_scope, accelerator_type
147-
)
148-
else:
149-
raise ValueError(
150-
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
151-
)
152149

153150
original_version = version
154151
version = _validate_version_and_set_if_needed(version, config, framework)

src/sagemaker/tensorflow/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@
1616
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
1717
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: F401
1818
from sagemaker.tensorflow.processing import TensorFlowProcessor # noqa: F401
19+
20+
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig # noqa: F401

src/sagemaker/tensorflow/estimator.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.transformer import Transformer
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828
from sagemaker.workflow import is_pipeline_variable
29+
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -45,7 +46,8 @@ def __init__(
4546
model_dir=None,
4647
image_uri=None,
4748
distribution=None,
48-
**kwargs
49+
compiler_config=None,
50+
**kwargs,
4951
):
5052
"""Initialize a ``TensorFlow`` estimator.
5153
@@ -157,6 +159,8 @@ def __init__(
157159
158160
To learn more, see `Training with parameter servers
159161
<https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.
162+
compiler_config (:class:`~sagemaker.tensorflow.TrainingCompilerConfig`):
163+
Configures SageMaker Training Compiler to accelerate training.
160164
161165
**kwargs: Additional kwargs passed to the Framework constructor.
162166
@@ -202,6 +206,17 @@ def __init__(
202206
self.distribution = distribution or {}
203207

204208
self._validate_args(py_version=py_version)
209+
if compiler_config is not None:
210+
if not isinstance(compiler_config, TrainingCompilerConfig):
211+
error_string = (
212+
f"Expected instance of type {TrainingCompilerConfig}"
213+
f"for argument compiler_config. "
214+
f"Instead got {type(compiler_config)}"
215+
)
216+
raise ValueError(error_string)
217+
if compiler_config:
218+
compiler_config.validate(self)
219+
self.compiler_config = compiler_config
205220

206221
def _validate_args(self, py_version):
207222
"""Placeholder docstring"""
@@ -301,7 +316,7 @@ def create_model(
301316
entry_point=None,
302317
source_dir=None,
303318
dependencies=None,
304-
**kwargs
319+
**kwargs,
305320
):
306321
"""Creates ``TensorFlowModel`` object to be used for creating SageMaker model entities.
307322
@@ -352,7 +367,7 @@ def create_model(
352367
entry_point=entry_point,
353368
source_dir=source_dir,
354369
dependencies=dependencies,
355-
**kwargs
370+
**kwargs,
356371
)
357372

358373
def hyperparameters(self):
@@ -369,6 +384,13 @@ def hyperparameters(self):
369384
hyperparameters.update(
370385
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
371386
)
387+
388+
if self.compiler_config:
389+
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
390+
hyperparameters.update(
391+
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
392+
)
393+
372394
return hyperparameters
373395

374396
def _default_s3_path(self, directory, mpi=False):

src/sagemaker/tensorflow/training_compiler/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Configuration for the SageMaker Training Compiler."""
14+
from __future__ import absolute_import
15+
import logging
16+
from packaging.specifiers import SpecifierSet
17+
from packaging.version import Version
18+
19+
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class TrainingCompilerConfig(BaseConfig):
    """The SageMaker Training Compiler configuration class."""

    # Instance-family prefixes declared for this framework.
    # NOTE(review): presumably consumed by the base-class validation -- confirm
    # against ``sagemaker.training_compiler.config``.
    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"]
    # Oldest TensorFlow release accepted by ``validate``.
    MIN_SUPPORTED_VERSION = "2.9"

    def __init__(
        self,
        enabled=True,
        debug=False,
    ):
        """This class initializes a ``TrainingCompilerConfig`` instance.

        `Amazon SageMaker Training Compiler
        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
        is a feature of SageMaker Training
        and speeds up training jobs by optimizing model execution graphs.

        You can compile TensorFlow models
        by passing the object of this configuration class to the ``compiler_config``
        parameter of the :class:`~sagemaker.tensorflow.TensorFlow`
        estimator.

        Args:
            enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
                The default is ``True``.
            debug (bool): Optional. Whether to dump detailed logs for debugging.
                This comes with a potential performance slowdown.
                The default is ``False``.

        **Example**: The following code shows the basic usage of the
        :class:`sagemaker.tensorflow.TrainingCompilerConfig` class
        to run a TensorFlow training job with the compiler.

        .. code-block:: python

            from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

            tensorflow_estimator=TensorFlow(
                ...
                compiler_config=TrainingCompilerConfig()
            )

        .. seealso::

            For more information about how to enable SageMaker Training Compiler
            for various training settings such as using TensorFlow-based models,
            PyTorch-based models, and distributed training,
            see `Enable SageMaker Training Compiler
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
            in the `Amazon SageMaker Training Compiler developer guide
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

        """

        super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

    @classmethod
    def validate(
        cls,
        estimator,
    ):
        """Checks if SageMaker Training Compiler is configured correctly.

        Runs the base-class validation first, then checks the estimator's
        TensorFlow version against ``MIN_SUPPORTED_VERSION``.

        Args:
            estimator (:class:`~sagemaker.tensorflow.TensorFlow`): An estimator
                object. If SageMaker Training Compiler is enabled, it will validate
                whether the estimator is configured to be compatible with
                Training Compiler.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                with SageMaker Training Compiler.
        """

        super(TrainingCompilerConfig, cls).validate(estimator)

        framework_version = estimator.framework_version
        if not framework_version:
            # No framework version on the estimator: there is nothing to compare.
            return

        if Version(framework_version) in SpecifierSet(f"< {cls.MIN_SUPPORTED_VERSION}"):
            raise ValueError(
                "SageMaker Training Compiler only supports TensorFlow version "
                ">= {} but received {}".format(cls.MIN_SUPPORTED_VERSION, framework_version)
            )

0 commit comments

Comments
 (0)