
Commit 0f715e0

feature: Add support for torch_distributed distribution strategy for Trainium instances
1 parent 0914f17 commit 0f715e0

File tree: 7 files changed (+499 −1 lines changed)


doc/frameworks/pytorch/using_pytorch.rst (+97)
@@ -292,7 +292,104 @@ using two ``ml.p4d.24xlarge`` instances:

    pt_estimator.fit("s3://bucket/path/to/training/data")

.. _distributed-pytorch-training-on-trainium:

Distributed PyTorch Training on Trainium
=============================================

SageMaker Training on Trainium instances now supports the ``xla``
package through ``torchrun``. With this, you do not need to manually pass ``RANK``,
``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. You can launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``torch_distributed`` option as the distribution strategy.

.. note::

    This ``torch_distributed`` support is available
    in the SageMaker Trainium (trn1) PyTorch Deep Learning Containers starting with v1.11.0.
    To find a complete list of supported versions of PyTorch Neuron, see
    `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_
    in the *AWS Deep Learning Containers GitHub repository*.

    SageMaker Debugger and Profiler are currently not supported with Trainium instances.
Adapt Your Training Script
--------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the ``xla`` backend as shown below.

.. code:: python

    import torch.distributed as dist

    dist.init_process_group('xla')

SageMaker takes care of ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` for you via ``torchrun``.

For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
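For orientation only, the following is a minimal, hypothetical sketch of such a training script (the model, data, and hyperparameters are placeholders, and it assumes the ``torch_xla`` package shipped in the PyTorch Neuron container):

.. code:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend

    # torchrun (launched for you by SageMaker) sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT.
    dist.init_process_group("xla")

    device = xm.xla_device()  # the NeuronCore assigned to this worker
    model = nn.Linear(16, 1).to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, gradient_as_bucket_view=True)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for step in range(10):
        optimizer.zero_grad()
        inputs = torch.randn(8, 16).to(device)
        targets = torch.randn(8, 1).to(device)
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()
        optimizer.step()
        xm.mark_step()  # flush the lazily built XLA graph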
**Currently supported backends:**

- ``xla`` for Trainium (Trn1) instances

For up-to-date information on supported backends for Trainium instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.
Launching a Distributed Training Job
------------------------------------

You can run multi-node distributed PyTorch training jobs on Trainium instances using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch Neuron, sets up the environment, and launches
the training job using the ``torchrun`` command on each worker with the given information.

.. note::

    The following examples show how to run a PyTorch training job with ``torch_distributed`` in SageMaker,
    first on one ``ml.trn1.2xlarge`` instance and then on two ``ml.trn1.32xlarge`` instances:
.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")
.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

*********************
Deploy PyTorch Models

src/sagemaker/fw_utils.py (+147 −1)
@@ -134,6 +134,15 @@

    "1.12.0",
]

TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.11",
    "1.11.0",
]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = [
    "torch_distributed",
]

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
@@ -701,7 +710,7 @@ def _validate_smdataparallel_args(


def validate_distribution(
-    distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs
+    distribution, instance_groups, framework_name, framework_version, py_version, image_uri, entry_point, kwargs
):
    """Check if distribution strategy is correctly invoked by the user.
@@ -726,6 +735,8 @@ def validate_distribution(

        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.
        entry_point (str or PipelineVariable): Path (absolute or relative) to the
            Python source file which should be executed as the entry point to training.
        kwargs(dict): Additional kwargs passed to this function

    Returns:
767778
f"Invalid training instance group {train_instance_group.instance_group_name} !"
768779
)
769780
instance_type = train_instance_group.instance_type
781+
validate_distribution_for_instance_type(
782+
instance_type=instance_type,
783+
distribution=distribution,
784+
)
770785
validate_smdistributed(
771786
instance_type=instance_type,
772787
framework_name=framework_name,
@@ -782,6 +797,15 @@ def validate_distribution(
782797
py_version=py_version,
783798
image_uri=image_uri,
784799
)
800+
validate_torch_distributed_distribution(
801+
instance_type=instance_type,
802+
distribution=distribution,
803+
framework_name=framework_name,
804+
framework_version=framework_version,
805+
py_version=py_version,
806+
image_uri=image_uri,
807+
entry_point=entry_point,
808+
)
785809
warn_if_parameter_server_with_multi_gpu(
786810
training_instance_type=instance_type, distribution=distribution
787811
)
@@ -793,6 +817,10 @@ def validate_distribution(
793817
instance_type = renamed_kwargs(
794818
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
795819
)
820+
validate_distribution_for_instance_type(
821+
instance_type=instance_type,
822+
distribution=distribution,
823+
)
796824
validate_smdistributed(
797825
instance_type=instance_type,
798826
framework_name=framework_name,
@@ -808,11 +836,53 @@ def validate_distribution(
808836
py_version=py_version,
809837
image_uri=image_uri,
810838
)
839+
validate_torch_distributed_distribution(
840+
instance_type=instance_type,
841+
distribution=distribution,
842+
framework_name=framework_name,
843+
framework_version=framework_version,
844+
py_version=py_version,
845+
image_uri=image_uri,
846+
entry_point=entry_point,
847+
)
811848
warn_if_parameter_server_with_multi_gpu(
812849
training_instance_type=instance_type, distribution=distribution
813850
)
814851
return distribution
815852

853+
def validate_distribution_for_instance_type(
854+
instance_type, distribution
855+
):
856+
"""Check if the provided distribution strategy is supported for the instance_type
857+
858+
Args:
859+
instance_type (str): A string representing the type of training instance selected.
860+
distribution (dict): A dictionary with information to enable distributed training.
861+
"""
862+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
863+
err_msg = ""
864+
if match and match[1].startswith("trn"):
865+
keys = distribution.keys()
866+
if len(keys) == 0:
867+
return
868+
elif len(keys) == 1:
869+
distribution_strategy = keys[0]
870+
if distribution_strategy != "torch_distributed":
871+
err_msg += (
872+
f"Provided distribution strategy {distribution_strategy} is not supported for"
873+
" Trainium instances.\n"
874+
"Please specify one of the following supported distribution strategies:"
875+
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
876+
)
877+
elif len(keys) > 1:
878+
err_msg += (
879+
f"Multiple distribution strategies are not supported for Trainium instances.\n"
880+
"Please specify one of the following supported distribution strategies:"
881+
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
882+
)
883+
884+
if err_msg:
885+
raise ValueError(err_msg)
816886
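For orientation only (this snippet is not part of the commit), the new per-instance-type check could be exercised directly like this; it accepts torch_distributed on a trn1 instance, rejects any other strategy there, and leaves non-Trainium instance types untouched:

    from sagemaker.fw_utils import validate_distribution_for_instance_type

    # Passes: torch_distributed is the only strategy allowed on Trainium (trn1) instances.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )

    # Raises ValueError: pytorchddp is not in TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.32xlarge",
        distribution={"pytorchddp": {"enabled": True}},
    )

    # No-op: the check only constrains trn* instance types.
    validate_distribution_for_instance_type(
        instance_type="ml.p4d.24xlarge",
        distribution={"pytorchddp": {"enabled": True}},
    )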

def validate_pytorch_distribution(
    distribution, framework_name, framework_version, py_version, image_uri
@@ -870,6 +940,82 @@ def validate_pytorch_distribution(

    if err_msg:
        raise ValueError(err_msg)


def validate_torch_distributed_distribution(
    instance_type,
    distribution,
    framework_name,
    framework_version,
    py_version,
    image_uri,
    entry_point,
):
    """Check if torch_distributed distribution strategy is correctly invoked by the user.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "torch_distributed": {
                        "enabled": True
                    }
                }
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.
        entry_point (str): Path (absolute or relative) to the
            Python source file which should be executed as the entry point to training.

    Raises:
        ValueError: if
            `py_version` is not python3 or
            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
    """
    if framework_name and framework_name != "pytorch":
        # We need to validate only for the PyTorch framework
        return

    torch_distributed_enabled = False
    if "torch_distributed" in distribution:
        torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
    if not torch_distributed_enabled:
        # A distribution strategy other than torch_distributed is selected
        return

    err_msg = ""
    if not image_uri:
        # ignore framework_version and py_version if image_uri is set;
        # if image_uri is not set, then both are mandatory
        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " torch_distributed.\n"
                "Please specify one of the supported framework versions:"
                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
            )
        if "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
                "Please specify py_version>=py3"
            )

    # Check instance compatibility
    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    if match and match[1].startswith("trn"):
        return
    else:
        err_msg += (
            "torch_distributed is currently supported only for Trainium instances."
            " Please refer to https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training"
            " for information regarding distributed training on non-Trainium instances.\n"
        )

    # Check entry point type
    if not entry_point.endswith(".py"):
        err_msg += (
            "Unsupported entry point type for torch_distributed.\n"
            "Only python programs (*.py) are supported."
        )

    if err_msg:
        raise ValueError(err_msg)
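To make the control flow above concrete, here is a hypothetical call sequence (not part of the diff); the first call returns without error, while the second raises because the instance type is not a Trainium instance:

    from sagemaker.fw_utils import validate_torch_distributed_distribution

    # Valid: Trainium instance, supported framework/py versions, python entry point.
    validate_torch_distributed_distribution(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_name="pytorch",
        framework_version="1.11.0",
        py_version="py38",
        image_uri=None,
        entry_point="train_ptddp.py",
    )

    # Raises ValueError: torch_distributed is currently supported only for Trainium instances.
    validate_torch_distributed_distribution(
        instance_type="ml.p3.2xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_name="pytorch",
        framework_version="1.11.0",
        py_version="py38",
        image_uri=None,
        entry_point="train_ptddp.py",
    )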

def python_deprecation_warning(framework, latest_supported_version):
    """Placeholder docstring"""

src/sagemaker/pytorch/estimator.py (+21)
@@ -39,6 +39,7 @@ class PyTorch(Framework):

    _framework_name = "pytorch"
    LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
    LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
    INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

    def __init__(
@@ -167,6 +168,17 @@ def __init__(

                To learn more, see `Distributed PyTorch Training
                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.

                **To enable Torch Distributed (Trainium Instances):**

                .. code:: python

                    {
                        "torch_distributed": {
                            "enabled": True
                        }
                    }

                To learn more, see `Distributed PyTorch Training on Trainium
                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.

                **To enable MPI:**
@@ -227,6 +239,7 @@ def __init__(

            framework_version,
            py_version,
            image_uri,
            entry_point,
            kwargs,
        )
@@ -242,13 +255,21 @@ def _pytorch_distribution_configuration(self, distribution):

        """
        distribution_config = {}
        pytorch_ddp_enabled = False
        torch_distributed_enabled = False

        if "pytorchddp" in distribution:
            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
        elif "torch_distributed" in distribution:
            torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)

        if pytorch_ddp_enabled:
            distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        elif torch_distributed_enabled:
            distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        else:
            distribution_config = self._distribution_configuration(distribution=distribution)
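As an illustration only (not part of the diff), with the changes above an estimator configured with distribution={"torch_distributed": {"enabled": True}} and instance_type="ml.trn1.32xlarge" would yield a distribution configuration along the lines of:

    {
        "sagemaker_torch_distributed_enabled": True,
        "sagemaker_instance_type": "ml.trn1.32xlarge",
    }

which the estimator folds into the training job's hyperparameters so that the container launches the training script through torchrun.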

tests/conftest.py (+11)
@@ -447,6 +447,17 @@ def pytorch_ddp_framework_version(request):

    return request.param


@pytest.fixture(scope="module")
def torch_distributed_py_version():
    return "py3"


@pytest.fixture(scope="module", params=["1.11.0"])
def torch_distributed_framework_version(request):
    return request.param
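A sketch (not part of this commit) of how a test might consume these fixtures; the test name, entry point, and assertion here are purely illustrative:

    def test_torch_distributed_estimator_config(
        sagemaker_session, torch_distributed_framework_version, torch_distributed_py_version
    ):
        from sagemaker.pytorch import PyTorch

        estimator = PyTorch(
            entry_point="train_ptddp.py",  # placeholder script path
            role="SageMakerRole",
            framework_version=torch_distributed_framework_version,
            py_version=torch_distributed_py_version,
            instance_count=1,
            instance_type="ml.trn1.2xlarge",
            sagemaker_session=sagemaker_session,
            distribution={"torch_distributed": {"enabled": True}},
        )

        assert estimator.distribution == {"torch_distributed": {"enabled": True}}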
@pytest.fixture(scope="session")
def cpu_instance_type(sagemaker_session, request):
    region = sagemaker_session.boto_session.region_name
