
Commit 2fe3114

satishpasumarthi authored and claytonparnell committed
feature: Add support for torch_distributed distribution for Trainium instances
1 parent f2d5e41 commit 2fe3114

File tree

7 files changed: +500 -0 lines changed


doc/frameworks/pytorch/using_pytorch.rst

+98
@@ -292,7 +292,105 @@ using two ``ml.p4d.24xlarge`` instances:
    pt_estimator.fit("s3://bucket/path/to/training/data")

Distributed PyTorch Training on Trainium (trn1) Instances
==========================================================

SageMaker Training on Trainium instances now supports the ``xla``
backend through ``torchrun``. With this, you do not need to manually pass ``RANK``,
``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. You can launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``torch_distributed`` option as the distribution strategy.

.. note::

    This torch_distributed support is available in the SageMaker Trainium (trn1) PyTorch
    Deep Learning Containers starting with v1.11.0. See the
    `Supported Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_ list.

    SageMaker Debugger and Profiler are currently not supported with Trainium instances.
Adapt Your Training Script
--------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the ``xla`` backend as shown below.

.. code:: python

    import torch.distributed as dist

    dist.init_process_group('xla')

SageMaker takes care of ``MASTER_ADDR`` and ``MASTER_PORT`` for you via ``torchrun``.

For detailed documentation about modifying your training script for Trainium, refer to the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#id5>`_.
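Beyond the ``init_process_group`` call, a Trainium training script typically places tensors on the
XLA device via ``torch_xla`` (shipped in the Neuron PyTorch containers). The following is a minimal,
illustrative sketch, not part of the SageMaker SDK; the model, data, and hyperparameters are
placeholders:

.. code:: python

    import torch
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm  # available in the Trainium (Neuron) PyTorch containers

    dist.init_process_group("xla")   # as shown above
    device = xm.xla_device()         # the local XLA (NeuronCore) device

    model = torch.nn.Linear(10, 1).to(device)                  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()

    for step in range(10):                                     # placeholder data and loop
        inputs = torch.randn(8, 10).to(device)
        targets = torch.randn(8, 1).to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients across workers, then steps the optimizer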
**Currently supported backends:**

- ``xla`` for Trainium (Trn1) instances

For up-to-date information on supported backends for Trainium instances, refer to the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.
Launching a Distributed Training Job
------------------------------------

You can run multi-node distributed PyTorch training jobs on Trainium instances using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch, sets up the environment, and launches
the training job using the ``torchrun`` command on each worker, passing the information
required for PyTorch distributed initialization.

.. note::

    The following examples show how to run PyTorch training with ``torch_distributed`` in SageMaker,
    first on one ``ml.trn1.2xlarge`` instance and then on two ``ml.trn1.32xlarge`` instances:
.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")
.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")
*********************
Deploy PyTorch Models

src/sagemaker/fw_utils.py

+132
@@ -134,6 +134,14 @@
    "1.12.0",
]

TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.11.0"
]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = [
    "torch_distributed"
]

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

@@ -767,6 +775,10 @@ def validate_distribution(
                    f"Invalid training instance group {train_instance_group.instance_group_name} !"
                )
            instance_type = train_instance_group.instance_type
            validate_supported_distributions(
                instance_type=instance_type,
                distribution=distribution,
            )
            validate_smdistributed(
                instance_type=instance_type,
                framework_name=framework_name,
@@ -782,6 +794,14 @@ def validate_distribution(
                py_version=py_version,
                image_uri=image_uri,
            )
            validate_torch_distributed_distribution(
                instance_type=instance_type,
                distribution=distribution,
                framework_name=framework_name,
                framework_version=framework_version,
                py_version=py_version,
                image_uri=image_uri,
            )
            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type, distribution=distribution
            )
@@ -793,6 +813,10 @@ def validate_distribution(
        instance_type = renamed_kwargs(
            "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
        )
        validate_supported_distributions(
            instance_type=instance_type,
            distribution=distribution,
        )
        validate_smdistributed(
            instance_type=instance_type,
            framework_name=framework_name,
@@ -808,11 +832,52 @@ def validate_distribution(
            py_version=py_version,
            image_uri=image_uri,
        )
        validate_torch_distributed_distribution(
            instance_type=instance_type,
            distribution=distribution,
            framework_name=framework_name,
            framework_version=framework_version,
            py_version=py_version,
            image_uri=image_uri,
        )
        warn_if_parameter_server_with_multi_gpu(
            training_instance_type=instance_type, distribution=distribution
        )
    return distribution

def validate_supported_distributions(instance_type, distribution):
    """Check if the provided distribution strategy is supported for the instance_type.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
    """
    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    err_msg = ""
    if match and match[1].startswith("trn"):
        keys = list(distribution.keys())
        if len(keys) == 0:
            return
        elif len(keys) == 1:
            distribution_strategy = keys[0]
            if distribution_strategy != "torch_distributed":
                err_msg += (
                    f"Provided distribution strategy {distribution_strategy} is not supported by"
                    " Trainium instances yet.\n"
                    "Please specify one of the following supported distribution strategies:"
                    f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
                )
        elif len(keys) > 1:
            err_msg += (
                "Multiple distribution strategies are not supported for Trainium instances yet.\n"
                "Please specify one of the following supported distribution strategies:"
                f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
            )

    if err_msg:
        raise ValueError(err_msg)
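As a quick illustration (not part of the diff) of how this validator behaves, assuming it is imported
from ``sagemaker.fw_utils``; the instance types and distribution dicts below are examples only:

# Trainium instance with torch_distributed passes silently.
validate_supported_distributions(
    instance_type="ml.trn1.2xlarge",
    distribution={"torch_distributed": {"enabled": True}},
)

# Trainium instance with any other strategy raises a ValueError.
try:
    validate_supported_distributions(
        instance_type="ml.trn1.32xlarge",
        distribution={"pytorchddp": {"enabled": True}},
    )
except ValueError as err:
    print(err)  # lists TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES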

def validate_pytorch_distribution(
    distribution, framework_name, framework_version, py_version, image_uri
@@ -870,6 +935,73 @@ def validate_pytorch_distribution(
    if err_msg:
        raise ValueError(err_msg)

def validate_torch_distributed_distribution(
    instance_type, distribution, framework_name, framework_version, py_version, image_uri
):
    """Check if torch_distributed distribution strategy is correctly invoked by the user.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "torch_distributed": {
                        "enabled": True
                    }
                }

        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if
            `py_version` is not python3 or
            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
    """
    if framework_name and framework_name != "pytorch":
        # We need to validate only for the PyTorch framework
        return

    torch_distributed_enabled = False
    if "torch_distributed" in distribution:
        torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
    if not torch_distributed_enabled:
        # A distribution strategy other than torch_distributed is selected
        return

    err_msg = ""
    if not image_uri:
        # framework_version and py_version are ignored if image_uri is set;
        # when image_uri is not set, both are mandatory
        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " torch_distributed.\n"
                "Please specify one of the supported framework versions:"
                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
            )
        if "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
                "Please specify py_version>=py3\n"
            )

    # Check instance compatibility
    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    if not (match and match[1].startswith("trn")):
        err_msg += (
            "torch_distributed is currently supported only for Trainium instances.\n"
            "Please refer to "
            "https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training"
            " for information regarding distributed training on non-Trainium instances."
        )

    if err_msg:
        raise ValueError(err_msg)
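A sketch (not part of the commit) of the expected outcomes of this validator, using example
instance types and the values introduced above:

# Valid: Trainium instance, supported framework and Python versions.
validate_torch_distributed_distribution(
    instance_type="ml.trn1.32xlarge",
    distribution={"torch_distributed": {"enabled": True}},
    framework_name="pytorch",
    framework_version="1.11.0",
    py_version="py38",
    image_uri=None,
)

# Invalid: a non-Trainium instance type should raise a ValueError.
try:
    validate_torch_distributed_distribution(
        instance_type="ml.p3.2xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_name="pytorch",
        framework_version="1.11.0",
        py_version="py38",
        image_uri=None,
    )
except ValueError as err:
    print(err)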

def python_deprecation_warning(framework, latest_supported_version):
    """Placeholder docstring"""

src/sagemaker/pytorch/estimator.py

+21
@@ -39,6 +39,7 @@ class PyTorch(Framework):
    _framework_name = "pytorch"
    LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
    LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
    INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

    def __init__(
@@ -167,6 +168,18 @@ def __init__(
                To learn more, see `Distributed PyTorch Training
                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.

                **To enable Torch Distributed (Trainium instances):**

                    .. code:: python

                        {
                            "torch_distributed": {
                                "enabled": True
                            }
                        }

                To learn more, see `Distributed PyTorch Training on Trainium
                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.

                **To enable MPI:**
@@ -242,13 +255,21 @@ def _pytorch_distribution_configuration(self, distribution):
        """
        distribution_config = {}
        pytorch_ddp_enabled = False
        torch_distributed_enabled = False

        if "pytorchddp" in distribution:
            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
        elif "torch_distributed" in distribution:
            torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)

        if pytorch_ddp_enabled:
            distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        elif torch_distributed_enabled:
            distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        else:
            distribution_config = self._distribution_configuration(distribution=distribution)
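For illustration (not part of the diff), with a ``torch_distributed`` distribution the method above
produces flags like the following; the instance type is an assumed example:

distribution = {"torch_distributed": {"enabled": True}}

# Assuming an estimator configured with instance_type="ml.trn1.2xlarge", the
# expected result of _pytorch_distribution_configuration(distribution) is:
expected_config = {
    "sagemaker_torch_distributed_enabled": True,   # LAUNCH_TORCH_DISTRIBUTED_ENV_NAME
    "sagemaker_instance_type": "ml.trn1.2xlarge",  # INSTANCE_TYPE_ENV_NAME
}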

tests/conftest.py

+11
@@ -447,6 +447,17 @@ def pytorch_ddp_framework_version(request):
    return request.param


@pytest.fixture(scope="module")
def torch_distributed_py_version():
    return "py3"


@pytest.fixture(scope="module", params=["1.11.0"])
def torch_distributed_framework_version(request):
    return request.param


@pytest.fixture(scope="session")
def cpu_instance_type(sagemaker_session, request):
    region = sagemaker_session.boto_session.region_name
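A hypothetical test using the new fixtures might look like the sketch below; the test name and body
are illustrative only and not part of this commit:

def test_torch_distributed_versions(
    torch_distributed_framework_version, torch_distributed_py_version
):
    # pytest injects "1.11.0" and "py3" here; a real test would build a PyTorch
    # estimator with these values and an ml.trn1.* instance type.
    assert torch_distributed_framework_version == "1.11.0"
    assert torch_distributed_py_version == "py3"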
