Commit 0d5c634

feature: support torch_distributed distribution for Trainium instances (#3424)
1 parent: 5e7aa6b

7 files changed: +545 -19 lines changed


doc/frameworks/pytorch/using_pytorch.rst

+101 -1
@@ -262,7 +262,8 @@ during the PyTorch DDP initialization.

.. note::

-    The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
+    The SageMaker PyTorch estimator uses ``mpirun`` in the backend.
+    It doesn’t use ``torchrun`` for distributed training.

For more information about setting up PyTorch DDP in your training script,
see `Getting Started with Distributed Data Parallel
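For context, the data-parallel path that this note refers to is enabled through the estimator's ``distribution`` argument. A minimal sketch follows, assuming the ``pytorchddp`` option documented elsewhere in this guide; the script name, versions, and instance type are illustrative placeholders.

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Minimal sketch of the mpirun-backed PyTorch DDP path described in the note above.
    # "train_ddp.py", the versions, and the instance type are illustrative placeholders.
    pt_estimator = PyTorch(
        entry_point="train_ddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        # Launched with mpirun under the hood, not torchrun.
        distribution={"pytorchddp": {"enabled": True}},
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")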
@@ -292,7 +293,106 @@ using two ``ml.p4d.24xlarge`` instances:

    pt_estimator.fit("s3://bucket/path/to/training/data")

+.. _distributed-pytorch-training-on-trainium:

+Distributed PyTorch Training on Trainium
+========================================
+
+SageMaker Training on Trainium instances now supports the ``xla``
+package through ``torchrun``. With this, you do not need to manually pass ``RANK``,
+``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. You can launch the training job using the
+:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
+with the ``torch_distributed`` option as the distribution strategy.
+
+.. note::
+
+    This ``torch_distributed`` support is available
+    in the SageMaker Trainium (trn1) PyTorch Deep Learning Containers starting v1.11.0.
+    To find a complete list of supported versions of PyTorch Neuron, see `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_ in the *AWS Deep Learning Containers GitHub repository*.
+
+    SageMaker Debugger and Profiler are currently not supported with Trainium instances.
+
+Adapt Your Training Script to Initialize with the XLA backend
+--------------------------------------------------------------
+
+To initialize distributed training in your script, call
+`torch.distributed.init_process_group
+<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
+with the ``xla`` backend as shown below.
+
+.. code:: python
+
+    import torch.distributed as dist
+
+    dist.init_process_group('xla')
+
+SageMaker takes care of ``MASTER_ADDR`` and ``MASTER_PORT`` for you via ``torchrun``.
+
+For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
+
+**Currently supported backends:**
+
+- ``xla`` for Trainium (Trn1) instances
+
+For up-to-date information on supported backends for Trainium instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.
+
+Launching a Distributed Training Job on Trainium
+------------------------------------------------
+
+You can run multi-node distributed PyTorch training jobs on Trainium instances using the
+:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
+With ``instance_count=1``, the estimator submits a
+single-node training job to SageMaker; with ``instance_count`` greater
+than one, a multi-node training job is launched.
+
+With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
+training container for PyTorch Neuron, sets up the environment, and launches
+the training job using the ``torchrun`` command on each worker with the given information.
+
+**Examples**
+
+The following examples show how to run PyTorch training with ``torch_distributed`` in SageMaker
+on one ``ml.trn1.2xlarge`` instance and on two ``ml.trn1.32xlarge`` instances:
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    pt_estimator = PyTorch(
+        entry_point="train_ptddp.py",
+        role="SageMakerRole",
+        framework_version="1.11.0",
+        py_version="py38",
+        instance_count=1,
+        instance_type="ml.trn1.2xlarge",
+        distribution={
+            "torch_distributed": {
+                "enabled": True
+            }
+        }
+    )
+
+    pt_estimator.fit("s3://bucket/path/to/training/data")
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    pt_estimator = PyTorch(
+        entry_point="train_ptddp.py",
+        role="SageMakerRole",
+        framework_version="1.11.0",
+        py_version="py38",
+        instance_count=2,
+        instance_type="ml.trn1.32xlarge",
+        distribution={
+            "torch_distributed": {
+                "enabled": True
+            }
+        }
+    )
+
+    pt_estimator.fit("s3://bucket/path/to/training/data")

*********************
Deploy PyTorch Models
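As a companion to the documentation added above, the sketch below shows what a training script launched with ``torch_distributed`` might look like. It relies only on details stated above (the ``xla`` backend and the rank/world-size variables that ``torchrun`` provides); the script structure and the ``torch_xla`` backend-registration import are assumptions, not part of this commit.

.. code:: python

    # Hypothetical skeleton of an entry-point script (e.g. train_ptddp.py) for a
    # Trainium job launched with distribution={"torch_distributed": {"enabled": True}}.
    # torchrun supplies RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT,
    # so the script does not set them itself.
    import os

    import torch.distributed as dist
    # Assumption: with PyTorch/XLA, importing this module registers the "xla" backend
    # (see the Neuron torchrun tutorial linked above).
    import torch_xla.distributed.xla_backend  # noqa: F401


    def main():
        # Initialize the process group with the XLA backend, as documented above.
        dist.init_process_group("xla")

        rank = int(os.environ["RANK"])              # provided by torchrun
        world_size = int(os.environ["WORLD_SIZE"])  # provided by torchrun
        print(f"worker {rank} of {world_size} initialized")

        # ... build the model, load this worker's data shard, run the training loop ...

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()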

src/sagemaker/fw_utils.py

+167 -15
@@ -134,6 +134,10 @@
    "1.12.0",
]

+TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
+
+TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
+
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

@@ -701,7 +705,13 @@ def _validate_smdataparallel_args(


def validate_distribution(
-    distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs
+    distribution,
+    instance_groups,
+    framework_name,
+    framework_version,
+    py_version,
+    image_uri,
+    kwargs,
):
    """Check if distribution strategy is correctly invoked by the user.
@@ -767,21 +777,35 @@ def validate_distribution(
                f"Invalid training instance group {train_instance_group.instance_group_name} !"
            )
            instance_type = train_instance_group.instance_type
-            validate_smdistributed(
+            validate_distribution_for_instance_type(
                instance_type=instance_type,
-                framework_name=framework_name,
-                framework_version=framework_version,
-                py_version=py_version,
                distribution=distribution,
-                image_uri=image_uri,
            )
-            validate_pytorch_distribution(
-                distribution=distribution,
+            validate_smdistributed(
+                instance_type=instance_type,
                framework_name=framework_name,
                framework_version=framework_version,
                py_version=py_version,
+                distribution=distribution,
                image_uri=image_uri,
            )
+            if framework_name and framework_name == "pytorch":
+                # We need to validate only for PyTorch framework
+                validate_pytorch_distribution(
+                    distribution=distribution,
+                    framework_name=framework_name,
+                    framework_version=framework_version,
+                    py_version=py_version,
+                    image_uri=image_uri,
+                )
+                validate_torch_distributed_distribution(
+                    instance_type=instance_type,
+                    distribution=distribution,
+                    framework_version=framework_version,
+                    py_version=py_version,
+                    image_uri=image_uri,
+                    entry_point=kwargs["entry_point"],
+                )
            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type, distribution=distribution
            )
@@ -793,27 +817,75 @@ def validate_distribution(
        instance_type = renamed_kwargs(
            "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
        )
-        validate_smdistributed(
+        validate_distribution_for_instance_type(
            instance_type=instance_type,
-            framework_name=framework_name,
-            framework_version=framework_version,
-            py_version=py_version,
            distribution=distribution,
-            image_uri=image_uri,
        )
-        validate_pytorch_distribution(
-            distribution=distribution,
+        validate_smdistributed(
+            instance_type=instance_type,
            framework_name=framework_name,
            framework_version=framework_version,
            py_version=py_version,
+            distribution=distribution,
            image_uri=image_uri,
        )
+        if framework_name and framework_name == "pytorch":
+            # We need to validate only for PyTorch framework
+            validate_pytorch_distribution(
+                distribution=distribution,
+                framework_name=framework_name,
+                framework_version=framework_version,
+                py_version=py_version,
+                image_uri=image_uri,
+            )
+            validate_torch_distributed_distribution(
+                instance_type=instance_type,
+                distribution=distribution,
+                framework_version=framework_version,
+                py_version=py_version,
+                image_uri=image_uri,
+                entry_point=kwargs["entry_point"],
+            )
        warn_if_parameter_server_with_multi_gpu(
            training_instance_type=instance_type, distribution=distribution
        )
    return distribution


+def validate_distribution_for_instance_type(instance_type, distribution):
+    """Check if the provided distribution strategy is supported for the instance_type
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected.
+        distribution (dict): A dictionary with information to enable distributed training.
+    """
+    err_msg = ""
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match and match[1].startswith("trn"):
+            keys = list(distribution.keys())
+            if len(keys) == 0:
+                return
+            if len(keys) == 1:
+                distribution_strategy = keys[0]
+                if distribution_strategy != "torch_distributed":
+                    err_msg += (
+                        f"Provided distribution strategy {distribution_strategy} is not supported"
+                        " for Trainium instances.\n"
+                        "Please specify one of the following supported distribution strategies:"
+                        f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
+                    )
+            elif len(keys) > 1:
+                err_msg += (
+                    "Multiple distribution strategies are not supported for Trainium instances.\n"
+                    "Please specify one of the following supported distribution strategies:"
+                    f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
+                )
+
+    if err_msg:
+        raise ValueError(err_msg)
+
+
def validate_pytorch_distribution(
    distribution, framework_name, framework_version, py_version, image_uri
):
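A short illustration of the new ``validate_distribution_for_instance_type`` check added above (hypothetical inputs; the function is module-level in ``sagemaker.fw_utils``):

.. code:: python

    from sagemaker.fw_utils import validate_distribution_for_instance_type

    # Passes: torch_distributed is the only strategy allowed on trn1 instances.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )

    # Raises ValueError: smdistributed is not in TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.2xlarge",
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )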
@@ -871,6 +943,86 @@ def validate_pytorch_distribution(
        raise ValueError(err_msg)


+def validate_torch_distributed_distribution(
+    instance_type,
+    distribution,
+    framework_version,
+    py_version,
+    image_uri,
+    entry_point,
+):
+    """Check if torch_distributed distribution strategy is correctly invoked by the user.
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected.
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "torch_distributed": {
+                        "enabled": True
+                    }
+                }
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+        image_uri (str): A string representing a Docker image URI.
+        entry_point (str or PipelineVariable): The absolute or relative path to the local Python
+            source file that should be executed as the entry point to
+            training.
+
+    Raises:
+        ValueError: if
+            `py_version` is not python3 or
+            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
+    """
+
+    torch_distributed_enabled = False
+    if "torch_distributed" in distribution:
+        torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
+    if not torch_distributed_enabled:
+        # Distribution strategy other than torch_distributed is selected
+        return
+
+    err_msg = ""
+    if not image_uri:
+        # ignore framework_version and py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                " torch_distributed.\n"
+                "Please specify one of the supported framework versions:"
+                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
+            )
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
+                "Please specify py_version>=py3"
+            )
+
+    # Check instance compatibility
+    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+    if match:
+        if not match[1].startswith("trn"):
+            err_msg += (
+                "torch_distributed is currently supported only for trainium instances.\n"
+                " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n"  # noqa E501 # pylint: disable=c0301
+                "for information regarding distributed training on non-trainium instances"
+            )
+
+    # Check entry point type
+    if not entry_point.endswith(".py"):
+        err_msg += (
+            "Unsupported entry point type for the distribution torch_distributed.\n"
+            "Only python programs (*.py) are supported."
+        )
+
+    if err_msg:
+        raise ValueError(err_msg)
+
+
def python_deprecation_warning(framework, latest_supported_version):
    """Placeholder docstring"""
    return PYTHON_2_DEPRECATION_WARNING.format(
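To make the instance-type gate above concrete, the sketch below walks through the regex used by ``validate_torch_distributed_distribution`` and shows the validator rejecting a non-Trainium instance type (illustrative values; not part of this commit):

.. code:: python

    import re

    from sagemaker.fw_utils import validate_torch_distributed_distribution

    pattern = r"^ml[\._]([a-z\d]+)\.?\w*$"
    print(re.match(pattern, "ml.trn1.32xlarge")[1])  # "trn1" -> Trainium, accepted
    print(re.match(pattern, "ml.p4d.24xlarge")[1])   # "p4d"  -> not Trainium, rejected

    # Raises ValueError: torch_distributed is currently supported only for trainium instances.
    validate_torch_distributed_distribution(
        instance_type="ml.p4d.24xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_version="1.11.0",
        py_version="py38",
        image_uri=None,
        entry_point="train_ptddp.py",
    )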
