Skip to content

feature: Adding Training Compiler support for TensorFlow estimator starting TF 2.9 #3156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b4adec2
change: Restructuring Training Compiler UI implementation
Lokiiiiii Jun 2, 2022
d818176
feature: Adding Training Compiler support for TensorFlow estimator
Lokiiiiii Jun 2, 2022
5f11033
change: Tests targetting Training Compiler in TensorFlow estimator
Lokiiiiii Jun 2, 2022
846bb82
fix: linting in training compiler files
Lokiiiiii Jun 2, 2022
5a9db84
fix: liniting in tensorflow estimator
Lokiiiiii Jun 2, 2022
2912b47
fix: syntax error in trcomp tests
Lokiiiiii Jun 3, 2022
6b8007e
fix: logic error in trcomp initialization in HF estimator
Lokiiiiii Jun 3, 2022
85af72d
fix: logic error in trcomp test for TF estimator
Lokiiiiii Jun 3, 2022
79efc4f
fix: logic error in version comparison
Lokiiiiii Jun 6, 2022
239d639
fix: syntax error in TF trcomp
Lokiiiiii Jun 7, 2022
2b8de05
change: documentation updates for trcomp
Lokiiiiii Jun 7, 2022
9ea141c
Apply documentation suggestions from code review for trcomp
Lokiiiiii Jun 7, 2022
26729be
update: documentation update for trcomp
Lokiiiiii Jun 7, 2022
02697a0
Adding tests for the TF trcomp BYOC path
Lokiiiiii Jun 9, 2022
84a6b00
linting trcomp config
Lokiiiiii Jun 15, 2022
9f4cf52
Adding logic to convert compiler_config to hyperparameters
Lokiiiiii Jun 15, 2022
a4d86c1
Fixing trcomp tensorflow tests
Lokiiiiii Jun 15, 2022
80a223a
Fixing logic error in training_compiler supported version
Lokiiiiii Jun 15, 2022
df4d6d9
Fixing logic error in training_compiler
Lokiiiiii Jun 15, 2022
56ca880
Fixing logic error in training_compiler
Lokiiiiii Jun 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/frameworks/tensorflow/sagemaker.tensorflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ TensorFlow Estimator
:undoc-members:
:show-inheritance:

TensorFlow Training Compiler Configuration
------------------------------------------

.. autoclass:: sagemaker.tensorflow.TrainingCompilerConfig
:members:
:undoc-members:
:show-inheritance:

TensorFlow Serving Model
------------------------

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/huggingface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401

from sagemaker.training_compiler.config import TrainingCompilerConfig # noqa: F401
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401
12 changes: 4 additions & 8 deletions src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.training_compiler.config import TrainingCompilerConfig
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -190,6 +190,8 @@ def __init__(
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)

self.distribution = distribution or {}

if compiler_config is not None:
if not isinstance(compiler_config, TrainingCompilerConfig):
error_string = (
Expand All @@ -199,13 +201,7 @@ def __init__(
)
raise ValueError(error_string)
if compiler_config:
compiler_config.validate(
image_uri=image_uri,
instance_type=instance_type,
distribution=distribution,
)

self.distribution = distribution or {}
compiler_config.validate(self)
self.compiler_config = compiler_config

def _validate_args(self, image_uri):
Expand Down
Empty file.
105 changes: 105 additions & 0 deletions src/sagemaker/huggingface/training_compiler/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Configuration for the SageMaker Training Compiler."""
from __future__ import absolute_import
import logging

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig

logger = logging.getLogger(__name__)


class TrainingCompilerConfig(BaseConfig):
"""The SageMaker Training Compiler configuration class."""

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]

def __init__(
self,
enabled=True,
debug=False,
):
"""This class initializes a ``TrainingCompilerConfig`` instance.

`Amazon SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
is a feature of SageMaker Training
and speeds up training jobs by optimizing model execution graphs.

You can compile Hugging Face models
by passing the object of this configuration class to the ``compiler_config``
parameter of the :class:`~sagemaker.huggingface.HuggingFace`
estimator.

Args:
enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
The default is ``True``.
debug (bool): Optional. Whether to dump detailed logs for debugging.
This comes with a potential performance slowdown.
The default is ``False``.

**Example**: The following code shows the basic usage of the
:class:`sagemaker.huggingface.TrainingCompilerConfig()` class
to run a HuggingFace training job with the compiler.

.. code-block:: python

from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

huggingface_estimator=HuggingFace(
...
compiler_config=TrainingCompilerConfig()
)

.. seealso::

For more information about how to enable SageMaker Training Compiler
for various training settings such as using TensorFlow-based models,
PyTorch-based models, and distributed training,
see `Enable SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
in the `Amazon SageMaker Training Compiler developer guide
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

"""

super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

@classmethod
def validate(
cls,
estimator,
):
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

Raises:
ValueError: Raised if the requested configuration is not compatible
with SageMaker Training Compiler.
"""

super(TrainingCompilerConfig, cls).validate(estimator)

if estimator.image_uri:
error_helper_string = (
"Overriding the image URI is currently not supported "
"for SageMaker Training Compiler."
"Specify the following parameters to run the Hugging Face training job "
"with SageMaker Training Compiler enabled: "
"transformer_version, tensorflow_version or pytorch_version, and compiler_config."
)
raise ValueError(error_helper_string)
15 changes: 6 additions & 9 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,18 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
)
if training_compiler_config is None:

if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
config = _config_for_framework_and_scope(
framework + "-training-compiler", image_scope, accelerator_type
)
else:
_framework = framework
if framework == HUGGING_FACE_FRAMEWORK:
inference_tool = _get_inference_tool(inference_tool, instance_type)
if inference_tool == "neuron":
_framework = f"{framework}-{inference_tool}"
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
elif framework == HUGGING_FACE_FRAMEWORK:
config = _config_for_framework_and_scope(
framework + "-training-compiler", image_scope, accelerator_type
)
else:
raise ValueError(
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
)

original_version = version
version = _validate_version_and_set_if_needed(version, config, framework)
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: F401
from sagemaker.tensorflow.processing import TensorFlowProcessor # noqa: F401

from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig # noqa: F401
28 changes: 25 additions & 3 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sagemaker.transformer import Transformer
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow import is_pipeline_variable
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig

logger = logging.getLogger("sagemaker")

Expand All @@ -45,7 +46,8 @@ def __init__(
model_dir=None,
image_uri=None,
distribution=None,
**kwargs
compiler_config=None,
**kwargs,
):
"""Initialize a ``TensorFlow`` estimator.

Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(

To learn more, see `Training with parameter servers
<https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.
compiler_config (:class:`~sagemaker.tensorflow.TrainingCompilerConfig`):
Configures SageMaker Training Compiler to accelerate training.

**kwargs: Additional kwargs passed to the Framework constructor.

Expand Down Expand Up @@ -202,6 +206,17 @@ def __init__(
self.distribution = distribution or {}

self._validate_args(py_version=py_version)
if compiler_config is not None:
if not isinstance(compiler_config, TrainingCompilerConfig):
error_string = (
f"Expected instance of type {TrainingCompilerConfig}"
f"for argument compiler_config. "
f"Instead got {type(compiler_config)}"
)
raise ValueError(error_string)
if compiler_config:
compiler_config.validate(self)
self.compiler_config = compiler_config

def _validate_args(self, py_version):
"""Placeholder docstring"""
Expand Down Expand Up @@ -301,7 +316,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
**kwargs,
):
"""Creates ``TensorFlowModel`` object to be used for creating SageMaker model entities.

Expand Down Expand Up @@ -352,7 +367,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
**kwargs,
)

def hyperparameters(self):
Expand All @@ -369,6 +384,13 @@ def hyperparameters(self):
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)

if self.compiler_config:
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
)

return hyperparameters

def _default_s3_path(self, directory, mpi=False):
Expand Down
Empty file.
111 changes: 111 additions & 0 deletions src/sagemaker/tensorflow/training_compiler/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Configuration for the SageMaker Training Compiler."""
from __future__ import absolute_import
import logging
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig

logger = logging.getLogger(__name__)


class TrainingCompilerConfig(BaseConfig):
"""The SageMaker Training Compiler configuration class."""

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"]
MIN_SUPPORTED_VERSION = "2.9"

def __init__(
self,
enabled=True,
debug=False,
):
"""This class initializes a ``TrainingCompilerConfig`` instance.

`Amazon SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
is a feature of SageMaker Training
and speeds up training jobs by optimizing model execution graphs.

You can compile TensorFlow models
by passing the object of this configuration class to the ``compiler_config``
parameter of the :class:`~sagemaker.tensorflow.TensorFlow`
estimator.

Args:
enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
The default is ``True``.
debug (bool): Optional. Whether to dump detailed logs for debugging.
This comes with a potential performance slowdown.
The default is ``False``.

**Example**: The following code shows the basic usage of the
:class:`sagemaker.tensorflow.TrainingCompilerConfig()` class
to run a TensorFlow training job with the compiler.

.. code-block:: python

from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

tensorflow_estimator=TensorFlow(
...
compiler_config=TrainingCompilerConfig()
)

.. seealso::

For more information about how to enable SageMaker Training Compiler
for various training settings such as using TensorFlow-based models,
PyTorch-based models, and distributed training,
see `Enable SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
in the `Amazon SageMaker Training Compiler developer guide
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

"""

super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

@classmethod
def validate(
cls,
estimator,
):
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

Raises:
ValueError: Raised if the requested configuration is not compatible
with SageMaker Training Compiler.
"""

super(TrainingCompilerConfig, cls).validate(estimator)

if estimator.framework_version:
if Version(estimator.framework_version) in SpecifierSet(
f"< {cls.MIN_SUPPORTED_VERSION}"
):
error_helper_string = (
"SageMaker Training Compiler only supports TensorFlow version "
">= {} but received {}"
)
error_helper_string = error_helper_string.format(
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
)
raise ValueError(error_helper_string)
Loading