Skip to content

feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11 #3307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
336204d
feature: Adding support in HuggingFace estimator for Training Compiler enhanced PyTorch 1.11
Lokiiiiii Aug 16, 2022
3b8738f
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 18, 2022
352c5f6
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
5a21b26
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
883cabb
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 18, 2022
66a42f6
Update src/sagemaker/training_compiler/config.py
Lokiiiiii Aug 18, 2022
9098c89
fix: renaming distribution parameters pytorch_xla -> pytorchxla
Lokiiiiii Aug 19, 2022
a0aa2fc
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 19, 2022
1db9ed0
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
2ea39c4
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
de3077c
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
192437c
Update src/sagemaker/huggingface/training_compiler/config.py
Lokiiiiii Aug 22, 2022
33b3174
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
10b7c4e
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 22, 2022
5c589d5
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 26, 2022
e9be4c1
Fix: syntax error in trcomp tests
Lokiiiiii Aug 26, 2022
8a7827d
fix: linting
Lokiiiiii Aug 26, 2022
0b2a16c
Merge remote-tracking branch 'aws/master' into trcomp-hf-pt-111
Lokiiiiii Aug 30, 2022
37223ec
fix: linting to break up long lines
Lokiiiiii Aug 30, 2022
8ce021c
fix: fixture scoping issue in integ test
Lokiiiiii Aug 30, 2022
d696653
fix: broken unit tests for trcomp
Lokiiiiii Aug 30, 2022
6688892
fix: broken skip logic in version fixtures
Lokiiiiii Aug 30, 2022
e0580b5
fix: update test and version compatibility
Lokiiiiii Aug 30, 2022
65086e5
feature: added warning recommending EFA instances with training compiler
Lokiiiiii Aug 30, 2022
26f47d4
Update src/sagemaker/huggingface/estimator.py
Lokiiiiii Aug 30, 2022
280f625
Update src/sagemaker/training_compiler/config.py
Lokiiiiii Aug 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
instance.
"""

LAUNCH_PT_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
Expand Down Expand Up @@ -3248,6 +3249,10 @@ def _distribution_configuration(self, distribution):
"instance_groups"
]

if "pytorchxla" in distribution:
pt_xla_enabled = distribution.get("pytorchxla").get("enabled", False)
distribution_config[self.LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled

if "parameter_server" in distribution:
ps_enabled = distribution.get("parameter_server").get("enabled", False)
distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
Expand Down
15 changes: 15 additions & 0 deletions src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ def __init__(
}
}
}

To enable distributed training with Training Compiler for PyTorch:

.. code:: python

{
"pytorchxla": {
"enabled": True
}
}
compiler_config (:class:`~sagemaker.huggingface.TrainingCompilerConfig`):
Configures SageMaker Training Compiler to accelerate training.

Expand Down Expand Up @@ -204,6 +214,11 @@ def __init__(
raise ValueError(error_string)
if compiler_config:
compiler_config.validate(self)
elif distribution is not None and "pytorchxla" in distribution:
raise ValueError(
"Distributed training through PyTorch XLA is currently only supported "
"when Training Compiler is enabled."
)
self.compiler_config = compiler_config

def _validate_args(self, image_uri):
Expand Down
42 changes: 40 additions & 2 deletions src/sagemaker/huggingface/training_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"""Configuration for the SageMaker Training Compiler."""
from __future__ import absolute_import
import logging
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from typing import Union

from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
Expand All @@ -24,7 +26,7 @@
class TrainingCompilerConfig(BaseConfig):
"""The SageMaker Training Compiler configuration class."""

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you mean to remove p4? is it intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p4d is the only available flavor of p4 AFAIK. Just making the check more specific.


def __init__(
self,
Expand Down Expand Up @@ -85,7 +87,7 @@ def validate(
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

Expand All @@ -105,3 +107,39 @@ def validate(
"transformer_version, tensorflow_version or pytorch_version, and compiler_config."
)
raise ValueError(error_helper_string)

if estimator.distribution:
pt_xla_present = "pytorchxla" in estimator.distribution
pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
if pt_xla_enabled:
if estimator.tensorflow_version:
error_helper_string = (
"Distribution mechanism 'pytorchxla' is currently only supported for "
"PyTorch >= 1.11 when Training Compiler is enabled. Received "
"tensorflow_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.tensorflow_version))
elif estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet("< 1.11"):
error_helper_string = (
"Distribution mechanism 'pytorchxla' is currently only supported for "
"PyTorch >= 1.11 when Training Compiler is enabled. Received "
"pytorch_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.pytorch_version))
if not pt_xla_present:
if estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
error_helper_string = (
"'pytorchxla' is the only distribution mechanism currently supported "
"for PyTorch >= 1.11 when Training Compiler is enabled. Received "
"distribution={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.distribution))
elif estimator.instance_count and estimator.instance_count > 1:
if estimator.pytorch_version:
if Version(estimator.pytorch_version) in SpecifierSet(">= 1.11"):
logger.warning(
"Consider setting 'distribution' to 'pytorchxla' for distributed "
"training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"processors": ["gpu"],
"version_aliases": {
"4.11": "4.11.0",
"4.17": "4.17.0"
"4.17": "4.17.0",
"4.21": "4.21.1"
},
"versions": {
"4.11.0": {
Expand Down Expand Up @@ -97,6 +98,40 @@
"repository": "huggingface-tensorflow-trcomp-training",
"container_version": {"gpu":"cu112-ubuntu20.04"}
}
},
"4.21.1": {
"version_aliases": {
"pytorch1.11": "pytorch1.11.0"
},
"pytorch1.11.0": {
"py_versions": ["py38"],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-trcomp-training",
"container_version": {"gpu":"cu113-ubuntu20.04"}
}
}
}
}
Expand Down
53 changes: 31 additions & 22 deletions src/sagemaker/training_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TrainingCompilerConfig(object):
"""The SageMaker Training Compiler configuration class."""

DEBUG_PATH = "/opt/ml/output/data/compiler/"
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]

HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
Expand Down Expand Up @@ -123,7 +123,7 @@ def validate(
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.estimator.Estimator`): An estimator object.
When SageMaker Training Compiler is enabled, it validates if
the estimator is configured to be compatible with Training Compiler.

Expand All @@ -132,31 +132,31 @@ def validate(
ValueError: Raised if the requested configuration is not compatible
with SageMaker Training Compiler.
"""

if "local" not in estimator.instance_type:
requested_instance_class = estimator.instance_type.split(".")[
1
] # Expecting ml.class.size
if not any(
[
requested_instance_class.startswith(i)
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
]
):
if estimator.instance_type:
if "local" not in estimator.instance_type:
requested_instance_class = estimator.instance_type.split(".")[
1
] # Expecting ml.class.size
if not any(
[
requested_instance_class.startswith(i)
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
]
):
error_helper_string = "Unsupported Instance class {}. SageMaker Training Compiler only supports {}"
error_helper_string = error_helper_string.format(
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
raise ValueError(error_helper_string)
elif estimator.instance_type == "local":
error_helper_string = (
"Unsupported Instance class {}. SageMaker Training Compiler only supports {}"
"The local mode is not supported by SageMaker Training Compiler."
"It only supports the following GPU instances: {}"
)
error_helper_string = error_helper_string.format(
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
raise ValueError(error_helper_string)
elif estimator.instance_type == "local":
error_helper_string = (
"The local mode is not supported by SageMaker Training Compiler."
"It only supports the following GPU instances: {}"
)
error_helper_string = error_helper_string.format(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES)
raise ValueError(error_helper_string)

if estimator.distribution and "smdistributed" in estimator.distribution:
raise ValueError(
Expand All @@ -180,3 +180,12 @@ def validate(
estimator.debugger_hook_config, estimator.disable_profiler
)
logger.warning(helper_string)

if estimator.instance_groups:
raise ValueError(
"SageMaker Training Compiler currently only supports homogeneous clusters of "
"the following GPU instance families: {}. Please use the 'instance_type' "
"and 'instance_count' parameters instead of 'instance_groups'".format(
cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
)
)
26 changes: 20 additions & 6 deletions tests/integ/test_training_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
from tests.integ.timeout import timeout


@pytest.fixture(scope="module")
def gpu_instance_type(request):
    """Module-scoped fixture: default GPU instance type for Training Compiler integ tests."""
    default_gpu_type = "ml.p3.2xlarge"
    return default_gpu_type


def instance_count(request):
    """Return the default number of training instances used by the integ tests.

    The *request* argument mirrors the pytest fixture signature and is unused.
    """
    default_count = 1
    return default_count


@pytest.fixture(scope="module")
def imagenet_val_set(request, sagemaker_session, tmpdir_factory):
"""
Expand Down Expand Up @@ -63,20 +66,31 @@ def huggingface_dummy_dataset(request, sagemaker_session):


@pytest.fixture(scope="module", autouse=True)
def skip_if_incompatible(request):
def skip_if_incompatible(gpu_instance_type, request):
"""
These tests are for training compiler enabled images/estimators only.
"""
if integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS:
region = integ.test_region()
if region not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS:
pytest.skip("SageMaker Training Compiler is not supported in this region")
if integ.test_region() in integ.TRAINING_NO_P3_REGIONS:
if gpu_instance_type == "ml.p3.16xlarge" and region not in integ.DATA_PARALLEL_TESTING_REGIONS:
pytest.skip("Data parallel testing is not allowed in this region")
if gpu_instance_type == "ml.p3.2xlarge" and region in integ.TRAINING_NO_P3_REGIONS:
pytest.skip("no ml.p3 instances in this region")


@pytest.mark.release
@pytest.mark.parametrize(
"gpu_instance_type, instance_count",
[
("ml.p3.2xlarge", 1),
("ml.p3.16xlarge", 2),
],
)
def test_huggingface_pytorch(
sagemaker_session,
gpu_instance_type,
instance_count,
huggingface_training_compiler_latest_version,
huggingface_training_compiler_pytorch_latest_version,
huggingface_dummy_dataset,
Expand All @@ -93,7 +107,7 @@ def test_huggingface_pytorch(
role="SageMakerRole",
transformers_version=huggingface_training_compiler_latest_version,
pytorch_version=huggingface_training_compiler_pytorch_latest_version,
instance_count=1,
instance_count=instance_count,
instance_type=gpu_instance_type,
hyperparameters={
"model_name_or_path": "distilbert-base-cased",
Expand All @@ -105,10 +119,10 @@ def test_huggingface_pytorch(
"per_device_train_batch_size": 128,
"output_dir": "/opt/ml/model",
},
environment={"GPU_NUM_DEVICES": "1"},
sagemaker_session=sagemaker_session,
disable_profiler=True,
compiler_config=HFTrainingCompilerConfig(),
distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
)

hf.fit(huggingface_dummy_dataset)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/training_compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

EC2_GPU_INSTANCE_CLASSES = {"p2", "g4dn", "g4ad", "p3", "p3dn", "p4dn"}
EC2_GPU_INSTANCE_CLASSES = {"p2", "g4dn", "g4ad", "p3", "p3dn", "p4d", "g5"}
Loading