
feature: Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 #3500


Merged · 15 commits · Dec 7, 2022
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
@@ -493,7 +493,7 @@ def framework_name_from_image(image_uri):
# We must support both the legacy and current image name format.
name_pattern = re.compile(
r"""^(?:sagemaker(?:-rl)?-)?
-(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
+(tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
|huggingface-tensorflow|huggingface-pytorch
|huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
(scriptmode|training)?
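The new ``pytorch-trcomp`` alternative lets ``framework_name_from_image`` recognize Training Compiler image names. A minimal sketch of the effect, assuming a hypothetical image name and a simplified tail for the truncated pattern above:

    import re

    # Simplified reconstruction of name_pattern; the tag-matching tail
    # (":(.*?)-(.*?)-(py\d+)") is an assumption, since the hunk is truncated.
    name_pattern = re.compile(
        r"""^(?:sagemaker(?:-rl)?-)?
        (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
        |huggingface-tensorflow|huggingface-pytorch
        |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
        (scriptmode|training)?
        :(.*?)-(.*?)-(py\d+)""",
        re.VERBOSE,
    )

    m = name_pattern.match("pytorch-trcomp-training:1.12.0-gpu-py38")  # hypothetical name
    print(m.group(1), m.group(2))  # pytorch-trcomp training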
41 changes: 41 additions & 0 deletions src/sagemaker/image_uri_config/pytorch-training-compiler.json
@@ -0,0 +1,41 @@
{
"training": {
"processors": [
"gpu"
],
"version_aliases": {
"1.12": "1.12.0"
},
"versions": {
"1.12.0": {
"py_versions": [
"py38"
],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ca-central-1": "763104351884",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "pytorch-training"
}
}
}
}
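For reference, the lookup this file supports is a plain alias-then-region resolution. A minimal sketch (not SDK code; assumes it runs from the repository root):

    import json

    with open("src/sagemaker/image_uri_config/pytorch-training-compiler.json") as f:
        training = json.load(f)["training"]

    version = training["version_aliases"]["1.12"]     # -> "1.12.0"
    entry = training["versions"][version]
    print(entry["registries"]["us-west-2"])           # -> "763104351884"
    print(entry["repository"], entry["py_versions"])  # -> pytorch-training ['py38']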
2 changes: 1 addition & 1 deletion src/sagemaker/image_uris.py
@@ -146,7 +146,7 @@ def retrieve(
tolerate_deprecated_model,
)

-if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
+if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
final_image_scope = image_scope
config = _config_for_framework_and_scope(
framework + "-training-compiler", final_image_scope, accelerator_type
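With ``pytorch`` added to this branch, ``image_uris.retrieve`` resolves compiler-enabled PyTorch images through the ``pytorch-training-compiler`` config above. A usage sketch (region and instance type are arbitrary examples):

    from sagemaker import image_uris
    from sagemaker.pytorch import TrainingCompilerConfig

    uri = image_uris.retrieve(
        framework="pytorch",
        region="us-west-2",
        version="1.12",                 # alias, resolves to 1.12.0
        py_version="py38",
        instance_type="ml.p3.2xlarge",  # GPU-only, per "processors": ["gpu"]
        image_scope="training",
        training_compiler_config=TrainingCompilerConfig(),
    )
    print(uri)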
2 changes: 2 additions & 0 deletions src/sagemaker/pytorch/__init__.py
@@ -16,3 +16,5 @@
from sagemaker.pytorch.estimator import PyTorch # noqa: F401
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401
from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401

from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401
60 changes: 57 additions & 3 deletions src/sagemaker/pytorch/estimator.py
@@ -28,6 +28,7 @@
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable

@@ -51,7 +52,8 @@ def __init__(
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
distribution: Optional[Dict] = None,
-**kwargs
+compiler_config: Optional[TrainingCompilerConfig] = None,
+**kwargs,
):
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.

@@ -208,6 +210,31 @@ def __init__(
To learn more, see `Training with parameter servers
<https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.

**To enable distributed training with
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
for PyTorch:**

.. code:: python

{
"pytorchxla": {
"enabled": True
}
}

To learn more, see `SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
in the *Amazon SageMaker Developer Guide*.

.. note::

When you use this PyTorch XLA option as your distributed training
strategy, you must also set the ``compiler_config`` parameter to
activate SageMaker Training Compiler.

compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
Configures SageMaker Training Compiler to accelerate training.

**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.

@@ -250,6 +277,25 @@ def __init__(

self.distribution = distribution or {}

if compiler_config is not None:
if not isinstance(compiler_config, TrainingCompilerConfig):
error_string = (
f"Expected instance of type {TrainingCompilerConfig}"
f"for argument compiler_config. "
f"Instead got {type(compiler_config)}"
)
raise ValueError(error_string)
if compiler_config:
compiler_config.validate(self)
elif distribution is not None and "pytorchxla" in distribution:
raise ValueError(
"Distributed training through PyTorch XLA is currently only supported "
"when SageMaker Training Compiler is enabled. To learn more, "
"see Enable SageMaker Training Compiler at "
"https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
)
self.compiler_config = compiler_config
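A sketch of the failure mode this guard produces when ``pytorchxla`` is requested without the compiler (same placeholder arguments as above):

    from sagemaker.pytorch import PyTorch

    try:
        PyTorch(
            entry_point="train.py",
            role="arn:aws:iam::111122223333:role/ExampleRole",
            framework_version="1.12.0",
            py_version="py38",
            instance_type="ml.p4d.24xlarge",
            instance_count=2,
            distribution={"pytorchxla": {"enabled": True}},
            # compiler_config omitted -> ValueError from the elif branch above
        )
    except ValueError as err:
        print(err)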

def _pytorch_distribution_configuration(self, distribution):
"""Returns a dict of distribution config for PyTorch training

@@ -289,6 +335,12 @@ def hyperparameters(self):
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
if self.compiler_config:
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
)

return hyperparameters
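The extra hyperparameters come from the base ``sagemaker.training_compiler.config.TrainingCompilerConfig``; the key names below are an assumption based on that base class:

    from sagemaker.pytorch import TrainingCompilerConfig

    config = TrainingCompilerConfig(enabled=True, debug=False)
    print(config._to_hyperparameter_dict())
    # Assumed output:
    # {'sagemaker_training_compiler_enabled': True,
    #  'sagemaker_training_compiler_debug_mode': False}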

def create_model(
@@ -299,7 +351,7 @@
entry_point=None,
source_dir=None,
dependencies=None,
-**kwargs
+**kwargs,
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

@@ -350,7 +402,7 @@ def create_model(
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
-**kwargs
+**kwargs,
)

@classmethod
@@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
)
image_uri = init_params.pop("image_uri")
framework, py_version, tag, _ = framework_name_from_image(image_uri)
if framework:
framework = framework.split("-")[0]

if tag is None:
framework_version = None
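This normalization maps a trcomp framework name back to its base framework so that ``attach()`` reconstructs a plain ``PyTorch`` estimator, for example:

    framework = "pytorch-trcomp"         # as parsed from a trcomp image URI
    framework = framework.split("-")[0]  # -> "pytorch"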
Empty file.
151 changes: 151 additions & 0 deletions src/sagemaker/pytorch/training_compiler/config.py
@@ -0,0 +1,151 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Configuration for the SageMaker Training Compiler."""
from __future__ import absolute_import
import logging
from typing import Union
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)


class TrainingCompilerConfig(BaseConfig):
"""The SageMaker Training Compiler configuration class."""

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
"ml.g4dn.8xlarge",
"ml.g4dn.12xlarge",
"ml.g5.48xlarge",
"ml.p3dn.24xlarge",
"ml.p4d.24xlarge",
]

def __init__(
self,
enabled: Union[bool, PipelineVariable] = True,
debug: Union[bool, PipelineVariable] = False,
):
"""This class initializes a ``TrainingCompilerConfig`` instance.

`Amazon SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
is a feature of SageMaker Training
and speeds up training jobs by optimizing model execution graphs.

You can compile PyTorch models
by passing the object of this configuration class to the ``compiler_config``
parameter of the :class:`~sagemaker.pytorch.PyTorch`
estimator.

Args:
enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker
Training Compiler. The default is ``True``.
debug (bool or PipelineVariable): Optional. Whether to dump detailed logs
for debugging. This comes with a potential performance slowdown.
The default is ``False``.

**Example**: The following code shows the basic usage of the
:class:`sagemaker.pytorch.TrainingCompilerConfig()` class
to run a PyTorch training job with the compiler.

.. code-block:: python

from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

pytorch_estimator = PyTorch(
...
compiler_config=TrainingCompilerConfig()
)

.. seealso::

For more information about how to enable SageMaker Training Compiler
for various training settings such as distributed training,
see `Enable SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
in the `Amazon SageMaker Training Compiler developer guide
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

"""

super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

@classmethod
def validate(
cls,
estimator,
):
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object.
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

Raises:
ValueError: Raised if the requested configuration is not compatible
with SageMaker Training Compiler.
"""

super(TrainingCompilerConfig, cls).validate(estimator)

if estimator.image_uri:
error_helper_string = (
"Overriding the image URI is currently not supported "
"for SageMaker Training Compiler."
"Specify the following parameters to run the PyTorch training job "
"with SageMaker Training Compiler enabled: "
"framework_version, and compiler_config."
)
raise ValueError(error_helper_string)

if estimator.distribution:
pt_xla_present = "pytorchxla" in estimator.distribution
pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
if pt_xla_enabled:
if estimator.framework_version:
if Version(estimator.framework_version) in SpecifierSet("< 1.12"):
error_helper_string = (
"Distribution mechanism 'pytorchxla' is currently only supported for "
"PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
" Received framework_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.framework_version))
if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
logger.warning(
"Consider using instances with EFA support when "
"training with PyTorch >= 1.12 and SageMaker Training Compiler "
"enabled. SageMaker Training Compiler leverages EFA to provide better "
"performance for distributed training."
)
if not pt_xla_present:
if estimator.framework_version:
if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
error_helper_string = (
"'pytorchxla' is the only distribution mechanism currently supported "
"for PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
" Received distribution={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.distribution))
elif estimator.instance_count and estimator.instance_count > 1:
if estimator.framework_version:
if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
logger.warning(
"Consider setting 'distribution' to 'pytorchxla' for distributed "
"training with PyTorch >= 1.12 and SageMaker Training Compiler enabled."
)
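The version gates in ``validate`` rely on ``packaging``; a quick check of how they evaluate:

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    print(Version("1.11.0") in SpecifierSet("< 1.12"))   # True: 'pytorchxla' rejected
    print(Version("1.12.0") in SpecifierSet(">= 1.12"))  # True: 'pytorchxla' expected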
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -73,6 +73,7 @@
"neo_pytorch",
"neo_tensorflow",
"pytorch",
"pytorch_training_compiler",
"ray_pytorch",
"ray_tensorflow",
"sklearn",
2 changes: 2 additions & 0 deletions tests/data/huggingface_byoc/requirements.txt
@@ -0,0 +1,2 @@
transformers
datasets