
feature: Add PyTorch DDP distribution support #3270


Merged: 3 commits, Aug 1, 2022
91 changes: 80 additions & 11 deletions doc/frameworks/pytorch/using_pytorch.rst
@@ -200,29 +200,98 @@ fit Optional Arguments
Distributed PyTorch Training
============================

SageMaker supports the `PyTorch DistributedDataParallel (DDP)
<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_
package. In your training script, set up the variables that DDP needs,
such as the world size and the rank of the current host, when initializing
process groups for distributed training.
Then launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``pytorchddp`` option as the distribution strategy.

.. note::

  This PyTorch DDP support is available
  in the SageMaker PyTorch Deep Learning Containers v1.12 and later.

Adapt Your Training Script
--------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the desired backend and the rank of the current host.

.. code:: python

    import os

    import torch.distributed as dist

    if args.distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        dist.init_process_group(backend=args.backend, rank=host_rank)

SageMaker sets ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` environment variables for you,
but you can also overwrite them.
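
As a minimal sketch (not from the SageMaker SDK), a training script could read
these values, and optionally override them, before initializing the process
group; the fallback values below are only illustrative:

.. code:: python

    import os

    # SageMaker populates these for you; override only if your setup requires it.
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "29500")
    os.environ["MASTER_PORT"] = master_port  # example of pinning the rendezvous port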

**Supported backends:**

- ``gloo`` and ``tcp`` for CPU instances
- ``gloo`` and ``nccl`` for GPU instances
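
As a hedged illustration (not required by the SDK), a script can pick a backend
at runtime to match the list above:

.. code:: python

    import torch

    # nccl for GPU instances, gloo otherwise
    backend = "nccl" if torch.cuda.is_available() else "gloo"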

Launching a Distributed Training Job
------------------------------------

You can run multi-node distributed PyTorch training jobs using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

To run a distributed training script that adopts
the `PyTorch DistributedDataParallel (DDP) package
<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_,
choose ``pytorchddp`` as the distributed training option in the ``PyTorch`` estimator.

With the ``pytorchddp`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch, sets up the environment for MPI, and launches
the training job with the ``mpirun`` command, passing each worker the information
it needs for PyTorch DDP initialization.

.. note::

  The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
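
Because worker processes are spawned by ``mpirun`` rather than ``torchrun``, a
script that must also run outside SageMaker can resolve its rank in a
launcher-agnostic way. The sketch below is illustrative only; it assumes
``torchrun``-style variables with an Open MPI fallback, which is not a
documented SageMaker contract:

.. code:: python

    import os

    def get_rank_and_world_size():
        # Prefer torchrun-style variables when present, fall back to Open MPI's.
        rank = int(os.environ.get("RANK", os.environ.get("OMPI_COMM_WORLD_RANK", 0)))
        world_size = int(
            os.environ.get("WORLD_SIZE", os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
        )
        return rank, world_size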

For more information about setting up PyTorch DDP in your training script,
see `Getting Started with Distributed Data Parallel
<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
PyTorch documentation.

The following example shows how to run a PyTorch DDP training job in SageMaker
using two ``ml.p4d.24xlarge`` instances:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        distribution={
            "pytorchddp": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")
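
Inside the entry point (``train_ptddp.py`` in the example above), the model is
wrapped with DDP once the process group has been initialized. The following is a
minimal sketch; the tiny ``nn.Linear`` model and the hard-coded ``local_rank``
are placeholders, not part of the SageMaker SDK:

.. code:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    local_rank = 0                            # placeholder: derive from your launcher or env
    model = nn.Linear(10, 10).to(local_rank)  # stand-in for your real model
    ddp_model = DDP(model, device_ids=[local_rank])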



*********************
82 changes: 82 additions & 0 deletions src/sagemaker/fw_utils.py
@@ -103,6 +103,17 @@
"1.11.0",
],
}

PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.10",
    "1.10.0",
    "1.10.2",
    "1.11",
    "1.11.0",
    "1.12",
    "1.12.0",
]

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]


@@ -728,6 +739,13 @@ def validate_distribution(
                distribution=distribution,
                image_uri=image_uri,
            )
            validate_pytorch_distribution(
                distribution=distribution,
                framework_name=framework_name,
                framework_version=framework_version,
                py_version=py_version,
                image_uri=image_uri,
            )
            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type, distribution=distribution
            )
@@ -747,12 +765,76 @@ def validate_distribution(
            distribution=distribution,
            image_uri=image_uri,
        )
        validate_pytorch_distribution(
            distribution=distribution,
            framework_name=framework_name,
            framework_version=framework_version,
            py_version=py_version,
            image_uri=image_uri,
        )
        warn_if_parameter_server_with_multi_gpu(
            training_instance_type=instance_type, distribution=distribution
        )
    return distribution


def validate_pytorch_distribution(
    distribution, framework_name, framework_version, py_version, image_uri
):
    """Check if pytorch distribution strategy is correctly invoked by the user.

    Args:
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "pytorchddp": {
                        "enabled": True
                    }
                }

        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if
            `py_version` is not python3 or
            `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
    """
    if framework_name and framework_name != "pytorch":
        # We need to validate only for PyTorch framework
        return

    pytorch_ddp_enabled = False
    if "pytorchddp" in distribution:
        pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
    if not pytorch_ddp_enabled:
        # Distribution strategy other than pytorchddp is selected
        return

    err_msg = ""
    if not image_uri:
        # ignore framework_version and py_version if image_uri is set
        # in case image_uri is not set, then both are mandatory
        if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " pytorchddp.\n"
                "Please specify one of the supported framework versions:"
                f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
            )
        if "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by pytorchddp.\n"
                "Please specify py_version>=py3"
            )
    if err_msg:
        raise ValueError(err_msg)
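

# Hedged usage sketch (illustrative, not part of this PR): an unsupported
# framework_version with no image_uri is expected to make the validator above
# raise ValueError. The helper below is hypothetical and never called by the SDK.
def _example_pytorchddp_validation():
    try:
        validate_pytorch_distribution(
            distribution={"pytorchddp": {"enabled": True}},
            framework_name="pytorch",
            framework_version="1.9.0",  # intentionally outside the supported list
            py_version="py38",
            image_uri=None,
        )
    except ValueError as err:
        print(f"pytorchddp validation failed as expected: {err}")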


def python_deprecation_warning(framework, latest_supported_version):
    """Placeholder docstring"""
    return PYTHON_2_DEPRECATION_WARNING.format(
39 changes: 38 additions & 1 deletion src/sagemaker/pytorch/estimator.py
@@ -38,6 +38,8 @@ class PyTorch(Framework):
"""Handle end-to-end training and deployment of custom PyTorch code."""

_framework_name = "pytorch"
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

def __init__(
self,
@@ -153,6 +155,19 @@ def __init__(
                To find a complete list of parameters for SageMaker model parallelism,
                see :ref:`sm-sdk-modelparallel-general`.

                **To enable PyTorch DDP:**

                    .. code:: python

                        {
                            "pytorchddp": {
                                "enabled": True
                            }
                        }

                    To learn more, see `Distributed PyTorch Training
                    <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.

                **To enable MPI:**

                    .. code:: python
@@ -217,10 +232,32 @@

        self.distribution = distribution or {}

    def _pytorch_distribution_configuration(self, distribution):
        """Returns a dict of distribution config for PyTorch training.

        Args:
            distribution (dict): A dictionary with information on how to run distributed training.
        Returns:
            dict containing PyTorch DDP config
        """
        distribution_config = {}
        pytorch_ddp_enabled = False
        if "pytorchddp" in distribution:
            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)

        if pytorch_ddp_enabled:
            distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
            if self.instance_type is not None:
                distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
        else:
            distribution_config = self._distribution_configuration(distribution=distribution)

        return distribution_config
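
    # Hedged illustration (not part of this PR): for an estimator created with
    # instance_type="ml.p4d.24xlarge" and distribution={"pytorchddp": {"enabled": True}},
    # this method is expected to return
    #     {"sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": "ml.p4d.24xlarge"}
    # which hyperparameters() below merges into the training job hyperparameters.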

    def hyperparameters(self):
        """Return hyperparameters used by your custom PyTorch code during model training."""
        hyperparameters = super(PyTorch, self).hyperparameters()
        additional_hyperparameters = self._pytorch_distribution_configuration(
            distribution=self.distribution
        )
        hyperparameters.update(
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -411,6 +411,18 @@ def tf_full_py_version(tf_full_version):
return "py39"


@pytest.fixture(scope="module")
def pytorch_ddp_py_version():
return "py3"


@pytest.fixture(
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
)
def pytorch_ddp_framework_version(request):
return request.param
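

# Hedged sketch (not part of this PR): an integration test consuming these
# fixtures might look roughly like the helper below. The entry point, role, and
# instance settings are illustrative assumptions, and the function is deliberately
# not named test_* so pytest will not collect it from conftest.py.
def _example_pytorchddp_integration(
    sagemaker_session, pytorch_ddp_framework_version, pytorch_ddp_py_version
):
    from sagemaker.pytorch import PyTorch

    estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version=pytorch_ddp_framework_version,
        py_version=pytorch_ddp_py_version,
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        sagemaker_session=sagemaker_session,
        distribution={"pytorchddp": {"enabled": True}},
    )
    estimator.fit("s3://bucket/path/to/training/data")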


@pytest.fixture(scope="session")
def cpu_instance_type(sagemaker_session, request):
    region = sagemaker_session.boto_session.region_name