feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 #3307


Merged: 26 commits, Aug 30, 2022

Commits:
336204d
feature: Adding support in HuggingFace estimator for Training Compile…
Lokiiiiii Aug 16, 2022
3b8738f
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 18, 2022
352c5f6
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
5a21b26
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
883cabb
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
66a42f6
Update src/sagemaker/training_compiler/config.py
Lokiiiiii Aug 18, 2022
9098c89
fix: renaming distribution parameters pytorch_xla -> pytorchxla
Lokiiiiii Aug 19, 2022
a0aa2fc
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 19, 2022
1db9ed0
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
2ea39c4
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
de3077c
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
192437c
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
33b3174
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
10b7c4e
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
5c589d5
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 26, 2022
e9be4c1
Fix: syntax error in trcomp tests
Lokiiiiii Aug 26, 2022
8a7827d
fix: linting
Lokiiiiii Aug 26, 2022
0b2a16c
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 30, 2022
37223ec
fix: linting to break up long lines
Lokiiiiii Aug 30, 2022
8ce021c
fix: fixture scoping issue in integ test
Lokiiiiii Aug 30, 2022
d696653
fix: broken unit tests for trcomp
Lokiiiiii Aug 30, 2022
6688892
fix: broken skip logic in version fixtures
Lokiiiiii Aug 30, 2022
e0580b5
fix: update test and version compatibility
Lokiiiiii Aug 30, 2022
65086e5
feature: added warning recommending EFA instances with training compiler
Lokiiiiii Aug 30, 2022
26f47d4
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 30, 2022
280f625
Update src/sagemaker/training_compiler/config.py
Lokiiiiii Aug 30, 2022
5 changes: 5 additions & 0 deletions src/sagemaker/estimator.py
@@ -100,6 +100,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
instance.
"""

LAUNCH_PT_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
@@ -3316,6 +3317,10 @@ def _distribution_configuration(self, distribution):
"instance_groups"
]

if "pytorchxla" in distribution:
pt_xla_enabled = distribution.get("pytorchxla").get("enabled", False)
distribution_config[self.LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled

if "parameter_server" in distribution:
ps_enabled = distribution.get("parameter_server").get("enabled", False)
distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
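For readers skimming the hunk above: a minimal standalone sketch of the new mapping, assuming only the dict handling shown in the diff (the real logic lives on ``EstimatorBase._distribution_configuration``):

.. code:: python

    # Standalone sketch of the "pytorchxla" branch added above; illustrative only.
    LAUNCH_PT_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"

    def distribution_configuration(distribution):
        """Translate the user-facing `distribution` dict into training hyperparameters."""
        distribution_config = {}
        if "pytorchxla" in distribution:
            # Defaults to False when the "enabled" key is absent, as in the hunk.
            pt_xla_enabled = distribution.get("pytorchxla").get("enabled", False)
            distribution_config[LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled
        return distribution_config

    print(distribution_configuration({"pytorchxla": {"enabled": True}}))
    # {'sagemaker_pytorch_xla_multi_worker_enabled': True}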
29 changes: 29 additions & 0 deletions src/sagemaker/huggingface/estimator.py
@@ -141,6 +141,28 @@ def __init__(
}
}
}

To enable distributed training with
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
for PyTorch:

.. code:: python

{
"pytorchxla": {
"enabled": True
}
}

To learn more, see `SageMaker Training Compiler
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
in the *Amazon SageMaker Developer Guide*.

.. note::

When you use this PyTorch XLA option as the distributed training
strategy, you must add the ``compiler_config`` parameter and activate
SageMaker Training Compiler.
compiler_config (:class:`~sagemaker.huggingface.TrainingCompilerConfig`):
Configures SageMaker Training Compiler to accelerate training.

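Putting the two parameters together, a hedged end-to-end sketch (entry point, role, and instance choices are placeholders; versions follow the 4.21.1/PyTorch 1.11 images added in this PR):

.. code:: python

    from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

    # Illustrative configuration only; confirm supported version/instance
    # combinations in the SageMaker developer guide.
    estimator = HuggingFace(
        entry_point="train.py",  # placeholder training script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        instance_type="ml.p4d.24xlarge",  # EFA-capable, per the new warning
        instance_count=2,
        transformers_version="4.21.1",
        pytorch_version="1.11.0",
        py_version="py38",
        compiler_config=TrainingCompilerConfig(),  # required for "pytorchxla"
        distribution={"pytorchxla": {"enabled": True}},
    )
    # estimator.fit({"training": "s3://my-bucket/train"})  # hypothetical channel

Omitting ``compiler_config`` while passing ``pytorchxla`` raises the ValueError added in the hunk below.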
@@ -204,6 +226,13 @@ def __init__(
raise ValueError(error_string)
if compiler_config:
compiler_config.validate(self)
elif distribution is not None and "pytorchxla" in distribution:
raise ValueError(
"Distributed training through PyTorch XLA is currently only supported "
"when SageMaker Training Compiler is enabled. To learn more, "
"see Enable SageMaker Training Compiler at "
"https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
)
self.compiler_config = compiler_config

def _validate_args(self, image_uri):
56 changes: 54 additions & 2 deletions src/sagemaker/huggingface/training_compiler/config.py
@@ -14,6 +14,8 @@
from __future__ import absolute_import
import logging
from typing import Union
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
from sagemaker.workflow.entities import PipelineVariable
@@ -24,7 +26,14 @@
class TrainingCompilerConfig(BaseConfig):
"""The SageMaker Training Compiler configuration class."""

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
[Review comment, Member]
did you mean to remove p4? is it intentional?

[Reply, Contributor (author)]
p4d is the only available flavor of p4 AFAIK. Just making the check more specific.

SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
"ml.g4dn.8xlarge",
"ml.g4dn.12xlarge",
"ml.g5.48xlarge",
"ml.p3dn.24xlarge",
"ml.p4d.24xlarge",
]
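
A small sketch of how these two allow-lists differ (hypothetical helpers; the prefix list is matched against the instance class, the EFA list against the exact instance type):

.. code:: python

    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
    SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
        "ml.g4dn.8xlarge",
        "ml.g4dn.12xlarge",
        "ml.g5.48xlarge",
        "ml.p3dn.24xlarge",
        "ml.p4d.24xlarge",
    ]

    def is_supported(instance_type: str) -> bool:
        # "ml.p4d.24xlarge" -> class "p4d"; prefix match, as in validate()
        instance_class = instance_type.split(".")[1]
        return any(instance_class.startswith(p) for p in SUPPORTED_INSTANCE_CLASS_PREFIXES)

    def has_efa(instance_type: str) -> bool:
        return instance_type in SUPPORTED_INSTANCE_TYPES_WITH_EFA

    print(is_supported("ml.p4d.24xlarge"), has_efa("ml.p4d.24xlarge"))  # True True
    print(is_supported("ml.g5.xlarge"), has_efa("ml.g5.xlarge"))        # True False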

def __init__(
self,
@@ -85,7 +94,7 @@ def validate(
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

@@ -105,3 +114,46 @@
"transformer_version, tensorflow_version or pytorch_version, and compiler_config."
)
raise ValueError(error_helper_string)

if estimator.distribution:
pt_xla_present = "pytorchxla" in estimator.distribution
pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
if pt_xla_enabled:
if estimator.tensorflow_version:
error_helper_string = (
"Distribution mechanism 'pytorchxla' is currently only supported for "
"PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
"tensorflow_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.tensorflow_version))
if estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"):
error_helper_string = (
"Distribution mechanism 'pytorchxla' is currently only supported for "
"PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
" Received pytorch_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.pytorch_version))
if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
logger.warning(
"Consider using instances with EFA support when "
"training with PyTorch >= 1.11 and SageMaker Training Compiler "
"enabled. SageMaker Training Compiler leverages EFA to provide better "
"performance for distributed training."
)
if not pt_xla_present:
if estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
error_helper_string = (
"'pytorchxla' is the only distribution mechanism currently supported "
"for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
" Received distribution={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.distribution))
elif estimator.instance_count and estimator.instance_count > 1:
if estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
logger.warning(
"Consider setting 'distribution' to 'pytorchxla' for distributed "
"training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
)
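
The version gates above rely on ``packaging``; a quick sketch of the containment semantics they use:

.. code:: python

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    # `version in specifier_set` tests whether the concrete version
    # satisfies the constraint, exactly as validate() uses it.
    print(Version("1.10.2") in SpecifierSet("< 1.11"))   # True  -> pytorchxla rejected
    print(Version("1.11.0") in SpecifierSet("< 1.11"))   # False
    print(Version("1.11.0") in SpecifierSet(">= 1.11"))  # True  -> pytorchxla expected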
src/sagemaker/image_uri_config/huggingface-training-compiler.json
@@ -3,7 +3,8 @@
"processors": ["gpu"],
"version_aliases": {
"4.11": "4.11.0",
"4.17": "4.17.0"
"4.17": "4.17.0",
"4.21": "4.21.1"
},
"versions": {
"4.11.0": {
@@ -97,6 +98,40 @@
"repository": "huggingface-tensorflow-trcomp-training",
"container_version": {"gpu":"cu112-ubuntu20.04"}
}
},
"4.21.1": {
"version_aliases": {
"pytorch1.11": "pytorch1.11.0"
},
"pytorch1.11.0": {
"py_versions": ["py38"],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-trcomp-training",
"container_version": {"gpu":"cu113-ubuntu20.04"}
}
}
}
}
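Given these registry entries, image retrieval should resolve along the following lines (a sketch; the ``retrieve`` arguments for the trcomp images are an assumption based on the config layout and may differ):

.. code:: python

    from sagemaker import image_uris

    # Assumed: the framework key matches the image_uri_config file name.
    uri = image_uris.retrieve(
        framework="huggingface-training-compiler",
        region="us-west-2",
        version="4.21.1",                        # transformers release above
        base_framework_version="pytorch1.11.0",  # selects the new PyTorch entry
        py_version="py38",
        instance_type="ml.p4d.24xlarge",
        image_scope="training",
    )
    print(uri)
    # Expected shape (assumption):
    # 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-trcomp-training:<tag>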
56 changes: 34 additions & 22 deletions src/sagemaker/training_compiler/config.py
@@ -21,7 +21,7 @@ class TrainingCompilerConfig(object):
"""The SageMaker Training Compiler configuration class."""

DEBUG_PATH = "/opt/ml/output/data/compiler/"
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
@@ -123,7 +123,7 @@ def validate(
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.estimator.Estimator`): An estimator object.
When SageMaker Training Compiler is enabled, it validates if
the estimator is configured to be compatible with Training Compiler.

@@ -132,31 +132,34 @@
ValueError: Raised if the requested configuration is not compatible
with SageMaker Training Compiler.
"""

if "local" not in estimator.instance_type:
requested_instance_class = estimator.instance_type.split(".")[
1
] # Expecting ml.class.size
if not any(
[
requested_instance_class.startswith(i)
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
]
):
if estimator.instance_type:
if "local" not in estimator.instance_type:
requested_instance_class = estimator.instance_type.split(".")[
1
] # Expecting ml.class.size
if not any(
[
requested_instance_class.startswith(i)
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
]
):
[Review comment, harryzorus, Aug 30, 2022]
can we add info logging to inform customers that p3dn.24xlarge & p4d.24xlarge instances offer the best performance for distributed training?

[Reply, Contributor (author)]
Do we just want to generalize to EFA-enabled instances, or what is special about these 2 instance types?

[Reply, harryzorus]
That's a good idea. Eventually we would recommend all EFA instances that improve performance. We should start with recommending the EFA instances we benchmarked thoroughly, and we can add other instances to the list once we benchmark them internally.

error_helper_string = (
"Unsupported Instance class {}."
"SageMaker Training Compiler only supports {}"
)
error_helper_string = error_helper_string.format(
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
raise ValueError(error_helper_string)
elif estimator.instance_type == "local":
error_helper_string = (
"Unsupported Instance class {}. SageMaker Training Compiler only supports {}"
"The local mode is not supported by SageMaker Training Compiler."
"It only supports the following GPU instances: {}"
)
error_helper_string = error_helper_string.format(
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
raise ValueError(error_helper_string)
elif estimator.instance_type == "local":
error_helper_string = (
"The local mode is not supported by SageMaker Training Compiler."
"It only supports the following GPU instances: {}"
)
error_helper_string = error_helper_string.format(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES)
raise ValueError(error_helper_string)

if estimator.distribution and "smdistributed" in estimator.distribution:
raise ValueError(
@@ -180,3 +183,12 @@
estimator.debugger_hook_config, estimator.disable_profiler
)
logger.warning(helper_string)

if estimator.instance_groups:
raise ValueError(
"SageMaker Training Compiler currently only supports homogeneous clusters of "
"the following GPU instance families: {}. Please use the 'instance_type' "
"and 'instance_count' parameters instead of 'instance_groups'".format(
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
)
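
To illustrate the new guard, a hedged sketch of the failure mode it prevents (the ``InstanceGroup`` import path and construction-time validation are assumptions):

.. code:: python

    from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
    from sagemaker.instance_group import InstanceGroup  # path assumed

    # Hypothetical heterogeneous cluster; Training Compiler requires a
    # homogeneous one, so validation is expected to raise.
    try:
        HuggingFace(
            entry_point="train.py",  # placeholder
            role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
            instance_groups=[
                InstanceGroup("group_a", "ml.p4d.24xlarge", 1),
                InstanceGroup("group_b", "ml.g5.48xlarge", 1),
            ],
            transformers_version="4.21.1",
            pytorch_version="1.11.0",
            py_version="py38",
            compiler_config=TrainingCompilerConfig(),
        )
    except ValueError as err:
        print(err)  # "...only supports homogeneous clusters..."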
29 changes: 24 additions & 5 deletions tests/conftest.py
@@ -252,27 +252,46 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version

@pytest.fixture(scope="module")
def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_version):
return _huggingface_base_fm_version(
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
)[0]
)
if not versions:
pytest.skip(
f"Hugging Face Training Compiler version {huggingface_training_compiler_version} does "
f"not have a PyTorch release."
)
return versions[0]


@pytest.fixture(scope="module")
def huggingface_training_compiler_tensorflow_version(huggingface_training_compiler_version):
return _huggingface_base_fm_version(
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
)[0]
)
if not versions:
pytest.skip(
f"Hugging Face Training Compiler version {huggingface_training_compiler_version} "
f"does not have a TensorFlow release."
)
return versions[0]


@pytest.fixture(scope="module")
def huggingface_training_compiler_py_version(huggingface_training_compiler_tensorflow_version):
def huggingface_training_compiler_tensorflow_py_version(
huggingface_training_compiler_tensorflow_version,
):
return (
"py37"
if Version(huggingface_training_compiler_tensorflow_version) < Version("2.6")
else "py38"
)


@pytest.fixture(scope="module")
def huggingface_training_compiler_pytorch_py_version(huggingface_training_compiler_pytorch_version):
return "py38"
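
The skip-inside-a-fixture pattern above marks every dependent test as skipped instead of erroring on an empty version list; a minimal standalone illustration (names are illustrative):

.. code:: python

    import pytest

    @pytest.fixture(scope="module")
    def pytorch_release():
        versions = []  # e.g., no PyTorch release for this trcomp version
        if not versions:
            # Skipping inside a fixture skips every test that requests it.
            pytest.skip("no PyTorch release for this Training Compiler version")
        return versions[0]

    def test_uses_release(pytorch_release):
        # Skipped, not failed, when the fixture calls pytest.skip().
        assert pytorch_release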


@pytest.fixture(scope="module")
def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version):
return (