feature: support torch_distributed distribution for Trainium instances #3424

Merged: 2 commits merged on Oct 20, 2022

102 changes: 101 additions & 1 deletion doc/frameworks/pytorch/using_pytorch.rst
@@ -262,7 +262,8 @@ during the PyTorch DDP initialization.

.. note::

The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
The SageMaker PyTorch estimator runs ``mpirun`` in the backend.
It doesn’t use ``torchrun`` for distributed training.

For more information about setting up PyTorch DDP in your training script,
see `Getting Started with Distributed Data Parallel
@@ -292,7 +293,106 @@ using two ``ml.p4d.24xlarge`` instances:

pt_estimator.fit("s3://bucket/path/to/training/data")

.. _distributed-pytorch-training-on-trainium:

Distributed PyTorch Training on Trainium
========================================

SageMaker Training on Trainium instances now supports the ``xla``
package through ``torchrun``. With this, you do not need to manually pass ``RANK``,
``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. You can launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``torch_distributed`` option as the distribution strategy.

.. note::

This ``torch_distributed`` support is available
in the SageMaker Trainium (trn1) PyTorch Deep Learning Containers starting v1.11.0.
To find a complete list of supported versions of PyTorch Neuron, see `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_ in the *AWS Deep Learning Containers GitHub repository*.

SageMaker Debugger and Profiler are currently not supported with Trainium instances.

Adapt Your Training Script to Initialize with the XLA backend
-------------------------------------------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the ``xla`` backend as shown below.

.. code:: python

    import torch.distributed as dist

    dist.init_process_group('xla')

SageMaker takes care of ``MASTER_ADDR`` and ``MASTER_PORT`` for you via ``torchrun``.

For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
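
Putting the pieces together, a data-parallel training script for Trainium might look like the
following minimal sketch. This is an illustrative outline only, not part of the SageMaker SDK:
the ``main`` function, model, optimizer, and random data are placeholders, and the ``torch_xla``
calls follow the Neuron ``torchrun`` tutorial linked above.

.. code:: python

    import torch
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend  # registers the 'xla' backend with torch.distributed


    def main():
        # torchrun, launched for you by SageMaker, supplies rank, world size, and rendezvous info.
        dist.init_process_group("xla")

        device = xm.xla_device()  # the XLA device (NeuronCore) assigned to this worker
        model = torch.nn.Linear(10, 1).to(device)  # placeholder model
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.MSELoss()

        for step in range(10):  # placeholder loop over random data
            inputs = torch.randn(8, 10).to(device)
            targets = torch.randn(8, 1).to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            # All-reduce gradients across workers, apply the update, and mark the XLA step.
            xm.optimizer_step(optimizer, barrier=True)


    if __name__ == "__main__":
        main()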

**Currently supported backends:**

- ``xla`` for Trainium (Trn1) instances

For up-to-date information on supported backends for Trainium instances, see `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.

Launching a Distributed Training Job on Trainium
------------------------------------------------

You can run multi-node distributed PyTorch training jobs on Trainium instances using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch Neuron, sets up the environment, and launches
the training job using the ``torchrun`` command on each worker with the given information.
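
Because ``torchrun`` exports the standard rendezvous environment variables on every worker, an
entry point can read them from the environment if it needs them. The snippet below is a minimal
illustration, assuming only the default variables that ``torchrun`` itself sets; it is not
SageMaker-specific code.

.. code:: python

    import os

    # torchrun populates these for every worker process, so the training script
    # never has to receive them as command-line arguments.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])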

**Examples**

The following examples show how to run a PyTorch training job using ``torch_distributed`` in SageMaker,
first on one ``ml.trn1.2xlarge`` instance and then on two ``ml.trn1.32xlarge`` instances:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

*********************
Deploy PyTorch Models
182 changes: 167 additions & 15 deletions src/sagemaker/fw_utils.py
@@ -134,6 +134,10 @@
"1.12.0",
]

TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]


@@ -701,7 +705,13 @@ def _validate_smdataparallel_args(


def validate_distribution(
distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs
distribution,
instance_groups,
framework_name,
framework_version,
py_version,
image_uri,
kwargs,
):
"""Check if distribution strategy is correctly invoked by the user.

@@ -767,21 +777,35 @@ def validate_distribution(
f"Invalid training instance group {train_instance_group.instance_group_name} !"
)
instance_type = train_instance_group.instance_type
validate_smdistributed(
validate_distribution_for_instance_type(
instance_type=instance_type,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
distribution=distribution,
image_uri=image_uri,
)
validate_pytorch_distribution(
distribution=distribution,
validate_smdistributed(
instance_type=instance_type,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
distribution=distribution,
image_uri=image_uri,
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
validate_pytorch_distribution(
distribution=distribution,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
)
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=distribution,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
entry_point=kwargs["entry_point"],
)
warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
@@ -793,27 +817,75 @@
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
validate_smdistributed(
validate_distribution_for_instance_type(
instance_type=instance_type,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
distribution=distribution,
image_uri=image_uri,
)
validate_pytorch_distribution(
distribution=distribution,
validate_smdistributed(
instance_type=instance_type,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
distribution=distribution,
image_uri=image_uri,
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
validate_pytorch_distribution(
distribution=distribution,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
)
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=distribution,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
entry_point=kwargs["entry_point"],
)
warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
return distribution


def validate_distribution_for_instance_type(instance_type, distribution):
"""Check if the provided distribution strategy is supported for the instance_type

Args:
instance_type (str): A string representing the type of training instance selected.
distribution (dict): A dictionary with information to enable distributed training.
"""
err_msg = ""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match and match[1].startswith("trn"):
keys = list(distribution.keys())
if len(keys) == 0:
return
if len(keys) == 1:
distribution_strategy = keys[0]
if distribution_strategy != "torch_distributed":
err_msg += (
f"Provided distribution strategy {distribution_strategy} is not supported"
" for Trainium instances.\n"
"Please specify one of the following supported distribution strategies:"
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
)
elif len(keys) > 1:
err_msg += (
"Multiple distribution strategies are not supported for Trainium instances.\n"
"Please specify one of the following supported distribution strategies:"
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
)

if err_msg:
raise ValueError(err_msg)


def validate_pytorch_distribution(
distribution, framework_name, framework_version, py_version, image_uri
):
@@ -871,6 +943,86 @@ def validate_pytorch_distribution(
raise ValueError(err_msg)


def validate_torch_distributed_distribution(
instance_type,
distribution,
framework_version,
py_version,
image_uri,
entry_point,
):
"""Check if torch_distributed distribution strategy is correctly invoked by the user.

Args:
instance_type (str): A string representing the type of training instance selected.
distribution (dict): A dictionary with information to enable distributed training.
(Defaults to None if distributed training is not enabled.) For example:

.. code:: python

{
"torch_distributed": {
"enabled": True
}
}
framework_version (str): A string representing the framework version selected.
py_version (str): A string representing the python version selected.
image_uri (str): A string representing a Docker image URI.
entry_point (str or PipelineVariable): The absolute or relative path to the local Python
source file that should be executed as the entry point to
training.

Raises:
ValueError: if
`py_version` is not python3 or
`framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
"""

torch_distributed_enabled = False
if "torch_distributed" in distribution:
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
if not torch_distributed_enabled:
# Distribution strategy other than torch_distributed is selected
return

err_msg = ""
if not image_uri:
# ignore framework_version and py_version if image_uri is set
# in case image_uri is not set, then both are mandatory
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
" torch_distributed.\n"
"Please specify one of the supported framework versions:"
f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
)
if "py3" not in py_version:
err_msg += (
f"Provided py_version {py_version} is not supported by torch_distributed.\n"
"Please specify py_version>=py3"
)

# Check instance compatibility
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if not match[1].startswith("trn"):
err_msg += (
"torch_distributed is currently supported only for trainium instances.\n"
" Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n" # noqa E501 # pylint: disable=c0301
"for information regarding distributed training on non-trainium instances"
)

# Check entry point type
if not entry_point.endswith(".py"):
err_msg += (
"Unsupported entry point type for the distribution torch_distributed.\n"
"Only python programs (*.py) are supported."
)

if err_msg:
raise ValueError(err_msg)


def python_deprecation_warning(framework, latest_supported_version):
"""Placeholder docstring"""
return PYTHON_2_DEPRECATION_WARNING.format(