
Commit 6f72e3c

vishwakaria and Ubuntu authored
feature: Add PyTorch DDP distribution support (#3270)
Co-authored-by: Ubuntu <[email protected]>
1 parent 60872f3 commit 6f72e3c

File tree

9 files changed, +722 -14 lines changed

doc/frameworks/pytorch/using_pytorch.rst

+80-11
@@ -200,29 +200,98 @@ fit Optional Arguments
 Distributed PyTorch Training
 ============================
 
-You can run a multi-machine, distributed PyTorch training using the PyTorch Estimator. By default, PyTorch objects will
-submit single-machine training jobs to SageMaker. If you set ``instance_count`` to be greater than one, multi-machine
-training jobs will be launched when ``fit`` is called. When you run multi-machine training, SageMaker will import your
-training script and run it on each host in the cluster.
+SageMaker supports the `PyTorch DistributedDataParallel (DDP)
+<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_
+package. You simply need to check the variables in your training script,
+such as the world size and the rank of the current host, when initializing
+process groups for distributed training.
+And then, launch the training job using the
+:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
+with the ``pytorchddp`` option as the distribution strategy.
 
-To initialize distributed training in your script you would call ``dist.init_process_group`` providing desired backend
-and rank and setting 'WORLD_SIZE' environment variable similar to how you would do it outside of SageMaker using
-environment variable initialization:
+.. note::
+
+   This PyTorch DDP support is available
+   in the SageMaker PyTorch Deep Learning Containers v1.12 and later.
+
+Adapt Your Training Script
+--------------------------
+
+To initialize distributed training in your script, call
+`torch.distributed.init_process_group
+<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
+with the desired backend and the rank of the current host.
 
 .. code:: python
 
+    import torch.distributed as dist
+
     if args.distributed:
         # Initialize the distributed environment.
         world_size = len(args.hosts)
         os.environ['WORLD_SIZE'] = str(world_size)
         host_rank = args.hosts.index(args.current_host)
         dist.init_process_group(backend=args.backend, rank=host_rank)
 
-SageMaker sets 'MASTER_ADDR' and 'MASTER_PORT' environment variables for you, but you can overwrite them.
+SageMaker sets ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` environment variables for you,
+but you can also overwrite them.
+
+**Supported backends:**
+
+- ``gloo`` and ``tcp`` for CPU instances
+- ``gloo`` and ``nccl`` for GPU instances
+
+Launching a Distributed Training Job
+------------------------------------
+
+You can run multi-node distributed PyTorch training jobs using the
+:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
+With ``instance_count=1``, the estimator submits a
+single-node training job to SageMaker; with ``instance_count`` greater
+than one, a multi-node training job is launched.
+
+To run a distributed training script that adopts
+the `PyTorch DistributedDataParallel (DDP) package
+<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_,
+choose the ``pytorchddp`` as the distributed training option in the ``PyTorch`` estimator.
+
+With the ``pytorchddp`` option, the SageMaker PyTorch estimator runs a SageMaker
+training container for PyTorch, sets up the environment for MPI, and launches
+the training job using the ``mpirun`` command on each worker with the given information
+during the PyTorch DDP initialization.
+
+.. note::
+
+   The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
+
+For more information about setting up PyTorch DDP in your training script,
+see `Getting Started with Distributed Data Parallel
+<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
+PyTorch documentation.
+
+The following example shows how to run a PyTorch DDP training in SageMaker
+using two ``ml.p4d.24xlarge`` instances:
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    pt_estimator = PyTorch(
+        entry_point="train_ptddp.py",
+        role="SageMakerRole",
+        framework_version="1.12.0",
+        py_version="py38",
+        instance_count=2,
+        instance_type="ml.p4d.24xlarge",
+        distribution={
+            "pytorchddp": {
+                "enabled": True
+            }
+        }
+    )
+
+    pt_estimator.fit("s3://bucket/path/to/training/data")
 
-Supported backends:
-- `gloo` and `tcp` for cpu instances
-- `gloo` and `nccl` for gpu instances
 
 
 *********************
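
For reference, the documentation example above names an entry point ``train_ptddp.py`` but does not show its contents. Below is only a minimal sketch of what such a DDP entry point could look like; the environment-variable rendezvous (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and optionally LOCAL_RANK) and the backend selection are assumptions about the launcher, not details specified by this commit.

# Hypothetical minimal DDP entry point (illustrative sketch, not part of this commit).
# Assumes the launcher exports MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE,
# so init_process_group() can use its default "env://" rendezvous.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # nccl on GPU instances, gloo on CPU instances (see the supported backends above).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    # LOCAL_RANK is an assumed convention; fall back to 0 if the launcher does not set it.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    model = torch.nn.Linear(10, 1).to(device)
    ddp_model = DDP(model, device_ids=[local_rank] if device.type == "cuda" else None)

    # ... build the optimizer and data loader, then run the training loop with ddp_model ...

    dist.destroy_process_group()


if __name__ == "__main__":
    main()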

src/sagemaker/fw_utils.py

+82
@@ -103,6 +103,17 @@
         "1.11.0",
     ],
 }
+
+PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
+    "1.10",
+    "1.10.0",
+    "1.10.2",
+    "1.11",
+    "1.11.0",
+    "1.12",
+    "1.12.0",
+]
+
 SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
 
 
@@ -728,6 +739,13 @@ def validate_distribution(
             distribution=distribution,
             image_uri=image_uri,
         )
+        validate_pytorch_distribution(
+            distribution=distribution,
+            framework_name=framework_name,
+            framework_version=framework_version,
+            py_version=py_version,
+            image_uri=image_uri,
+        )
         warn_if_parameter_server_with_multi_gpu(
             training_instance_type=instance_type, distribution=distribution
         )
@@ -747,12 +765,76 @@ def validate_distribution(
             distribution=distribution,
             image_uri=image_uri,
         )
+        validate_pytorch_distribution(
+            distribution=distribution,
+            framework_name=framework_name,
+            framework_version=framework_version,
+            py_version=py_version,
+            image_uri=image_uri,
+        )
         warn_if_parameter_server_with_multi_gpu(
             training_instance_type=instance_type, distribution=distribution
        )
     return distribution
 
 
+def validate_pytorch_distribution(
+    distribution, framework_name, framework_version, py_version, image_uri
+):
+    """Check if pytorch distribution strategy is correctly invoked by the user.
+
+    Args:
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "pytorchddp": {
+                        "enabled": True
+                    }
+                }
+        framework_name (str): A string representing the name of framework selected.
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+        image_uri (str): A string representing a Docker image URI.
+
+    Raises:
+        ValueError: if
+            `py_version` is not python3 or
+            `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
+    """
+    if framework_name and framework_name != "pytorch":
+        # We need to validate only for PyTorch framework
+        return
+
+    pytorch_ddp_enabled = False
+    if "pytorchddp" in distribution:
+        pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
+    if not pytorch_ddp_enabled:
+        # Distribution strategy other than pytorchddp is selected
+        return
+
+    err_msg = ""
+    if not image_uri:
+        # ignore framework_version and py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                " pytorchddp.\n"
+                "Please specify one of the supported framework versions:"
+                f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
+            )
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported by pytorchddp.\n"
+                "Please specify py_version>=py3"
+            )
+    if err_msg:
+        raise ValueError(err_msg)
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
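
As a usage illustration of the new check, the sketch below exercises ``validate_pytorch_distribution`` the way ``validate_distribution`` invokes it above. The behaviour shown is inferred from the function body in this diff, and the ECR image URI is a hypothetical placeholder.

# Sketch of the validation behaviour implied by validate_pytorch_distribution above.
from sagemaker.fw_utils import validate_pytorch_distribution

ddp = {"pytorchddp": {"enabled": True}}

# Supported framework_version and py_version: no exception is raised.
validate_pytorch_distribution(
    distribution=ddp,
    framework_name="pytorch",
    framework_version="1.12.0",
    py_version="py38",
    image_uri=None,
)

# Unsupported framework_version and a non-py3 py_version: both messages are
# accumulated in err_msg and a ValueError is raised.
try:
    validate_pytorch_distribution(
        distribution=ddp,
        framework_name="pytorch",
        framework_version="1.8.0",  # not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
        py_version="py2",
        image_uri=None,
    )
except ValueError as err:
    print(err)

# When image_uri is set, framework_version and py_version are not checked.
validate_pytorch_distribution(
    distribution=ddp,
    framework_name="pytorch",
    framework_version="1.8.0",
    py_version="py2",
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-pytorch:latest",  # hypothetical URI
)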

src/sagemaker/pytorch/estimator.py

+38-1
@@ -38,6 +38,8 @@ class PyTorch(Framework):
     """Handle end-to-end training and deployment of custom PyTorch code."""
 
     _framework_name = "pytorch"
+    LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
+    INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
 
     def __init__(
         self,
@@ -153,6 +155,19 @@ def __init__(
                 To find a complete list of parameters for SageMaker model parallelism,
                 see :ref:`sm-sdk-modelparallel-general`.
 
+                **To enable PyTorch DDP:**
+
+                    .. code:: python
+
+                        {
+                            "pytorchddp": {
+                                "enabled": True
+                            }
+                        }
+
+                To learn more, see `Distributed PyTorch Training
+                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
+
                 **To enable MPI:**
 
                 .. code:: python
@@ -217,10 +232,32 @@ def __init__(
 
         self.distribution = distribution or {}
 
+    def _pytorch_distribution_configuration(self, distribution):
+        """Returns a dict of distribution config for PyTorch training
+
+        Args:
+            distribution (dict): A dictionary with information on how to run distributed training.
+        Returns:
+            dict containing Pytorch DDP config
+        """
+        distribution_config = {}
+        pytorch_ddp_enabled = False
+        if "pytorchddp" in distribution:
+            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
+
+        if pytorch_ddp_enabled:
+            distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
+            if self.instance_type is not None:
+                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
+        else:
+            distribution_config = self._distribution_configuration(distribution=distribution)
+
+        return distribution_config
+
     def hyperparameters(self):
         """Return hyperparameters used by your custom PyTorch code during model training."""
         hyperparameters = super(PyTorch, self).hyperparameters()
-        additional_hyperparameters = self._distribution_configuration(
+        additional_hyperparameters = self._pytorch_distribution_configuration(
             distribution=self.distribution
         )
         hyperparameters.update(
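
To make the effect of the new ``_pytorch_distribution_configuration`` method concrete, the sketch below mirrors the documentation example from this commit. The expected keys noted in the comment are inferred from LAUNCH_PYTORCH_DDP_ENV_NAME and INSTANCE_TYPE_ENV_NAME above, not captured from a real training job.

# Sketch: what _pytorch_distribution_configuration contributes to hyperparameters()
# when the pytorchddp distribution is enabled.
from sagemaker.pytorch import PyTorch

pt_estimator = PyTorch(
    entry_point="train_ptddp.py",
    role="SageMakerRole",
    framework_version="1.12.0",
    py_version="py38",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    distribution={"pytorchddp": {"enabled": True}},
)

# Per the method body above, the distribution config resolves to roughly:
#     {"sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": "ml.p4d.24xlarge"}
# and hyperparameters() merges it into the framework hyperparameters.
print(pt_estimator.hyperparameters())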

tests/conftest.py

+12
@@ -411,6 +411,18 @@ def tf_full_py_version(tf_full_version):
     return "py39"
 
 
+@pytest.fixture(scope="module")
+def pytorch_ddp_py_version():
+    return "py3"
+
+
+@pytest.fixture(
+    scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
+)
+def pytorch_ddp_framework_version(request):
+    return request.param
+
+
 @pytest.fixture(scope="session")
 def cpu_instance_type(sagemaker_session, request):
     region = sagemaker_session.boto_session.region_name
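
These fixtures are presumably consumed by the new DDP tests elsewhere in the commit (not shown on this page). The following is a hypothetical sketch of such a test; the test name, the estimator arguments, and the use of an existing ``sagemaker_session`` fixture are illustrative assumptions, not code from this commit.

# Hypothetical test sketch using the two new fixtures; not part of this commit.
from sagemaker.pytorch import PyTorch


def test_pytorchddp_distribution(
    sagemaker_session, pytorch_ddp_framework_version, pytorch_ddp_py_version
):
    estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version=pytorch_ddp_framework_version,  # parametrized over supported versions
        py_version=pytorch_ddp_py_version,
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        sagemaker_session=sagemaker_session,
        distribution={"pytorchddp": {"enabled": True}},
    )
    assert estimator.distribution == {"pytorchddp": {"enabled": True}}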

0 commit comments