
Commit 1896135

add all api docs
1 parent 864cb55 commit 1896135

8 files changed, +1129 −341 lines changed

doc/api/training/smd_model_parallel.rst

Lines changed: 17 additions & 8 deletions
@@ -9,32 +9,41 @@ allowing you to increase prediction accuracy by creating larger models with more
 You can use the library to automatically partition your existing TensorFlow and PyTorch workloads
 across multiple GPUs with minimal code changes. The library's API can be accessed through the Amazon SageMaker SDK.

-Use the following sections to learn more about the model parallelism and the library.
+See the following sections to learn more about the SageMaker model parallel library APIs.

 Use with the SageMaker Python SDK
 =================================

 Use the following page to learn how to configure and enable distributed model parallel
-when you configure an Amazon SageMaker Python SDK `Estimator`.
+when you construct an Amazon SageMaker Python SDK `Estimator`.

 .. toctree::
    :maxdepth: 1

    smd_model_parallel_general

-API Documentation
-=================
+The library's API to Adapt Training Scripts
+===========================================

-The library contains a Common API that is shared across frameworks, as well as APIs
-that are specific to supported frameworks, TensorFlow and PyTorch.
+The library contains a Common API that is shared across frameworks,
+as well as framework-specific APIs for TensorFlow and PyTorch.

-Select a version to see the API documentation for version. To use the library, reference the
+Select the latest or one of the previous versions of the API documentation
+depending on which version of the library you need to use.
+To use the library, reference the
 **Common API** documentation alongside the framework specific API documentation.

 .. toctree::
-   :maxdepth: 1
+   :maxdepth: 2

    smp_versions/latest.rst
+
+To find archived API documentation for the previous versions of the library,
+see the following link:
+
+.. toctree::
+   :maxdepth: 1
+
    smp_versions/archives.rst

 It is recommended to use this documentation alongside `SageMaker Distributed Model Parallel
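
For orientation, here is a minimal sketch (not part of this diff) of how distributed model parallel is typically enabled when constructing a SageMaker Python SDK `Estimator`. The instance type, framework/Python versions, and parameter values below are illustrative assumptions, not recommendations; see ``smd_model_parallel_general`` for the actual parameter list.

.. code:: python

   # Hypothetical sketch: enable the SageMaker model parallel library on a
   # PyTorch Estimator. All values below are placeholders.
   from sagemaker.pytorch import PyTorch

   smp_options = {
       "enabled": True,
       "parameters": {        # placeholder values, not recommendations
           "partitions": 2,   # pipeline-parallel degree
           "microbatches": 4,
           "ddp": True,       # combine with data parallelism
       },
   }

   estimator = PyTorch(
       entry_point="train.py",                  # your adapted training script
       role="<your SageMaker execution role>",
       instance_type="ml.p3.16xlarge",
       instance_count=1,
       framework_version="1.8.1",
       py_version="py36",
       distribution={
           "smdistributed": {"modelparallel": smp_options},
           "mpi": {"enabled": True, "processes_per_host": 8},
       },
   )
   # estimator.fit("s3://<bucket>/<training-data>")  # launch the training job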

doc/api/training/smd_model_parallel_general.rst

Lines changed: 300 additions & 298 deletions
Large diffs are not rendered by default.

doc/api/training/smp_versions/archives.rst

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+.. _smdmp-pt-version-archive:
+
 Version Archive
 ===============

doc/api/training/smp_versions/latest.rst

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@ To use the library, reference the Common API documentation alongside the framewo

    latest/smd_model_parallel_common_api
    latest/smd_model_parallel_pytorch
+   latest/smd_model_parallel_pytorch_tensor_parallel
    latest/smd_model_parallel_tensorflow

doc/api/training/smp_versions/latest/smd_model_parallel_common_api.rst

Lines changed: 66 additions & 17 deletions
@@ -258,23 +258,72 @@ MPI Basics

 The library exposes the following basic MPI primitives to its Python API:

-- ``smp.rank()``: The rank of the current process.
-- ``smp.size()``: The total number of processes.
-- ``smp.mp_rank()``: The rank of the process among the processes that
-  hold the current model replica.
-- ``smp.dp_rank()``: The rank of the process among the processes that
-  hold different replicas of the same model partition.
-- ``smp.dp_size()``: The total number of model replicas.
-- ``smp.local_rank()``: The rank among the processes on the current
-  instance.
-- ``smp.local_size()``: The total number of processes on the current
-  instance.
-- ``smp.get_mp_group()``: The list of ranks over which the current
-  model replica is partitioned.
-- ``smp.get_dp_group()``: The list of ranks that hold different
-  replicas of the same model partition.
-
-.. _communication_api:
+**Global**
+
+- ``smp.rank()``: The global rank of the current process.
+- ``smp.size()``: The total number of processes.
+- ``smp.get_world_process_group()``: The
+  ``torch.distributed.ProcessGroup`` that contains all processes.
+- `smp.CommGroup.WORLD <https://sagemaker.readthedocs.io/en/stable/api/training/smp_versions/latest/smd_model_parallel_common_api.html#smp.CommGroup>`__:
+  The communication group corresponding to all processes.
+- ``smp.local_rank()``: The rank among the processes on the current instance.
+- ``smp.local_size()``: The total number of processes on the current instance.
+- ``smp.get_mp_group()``: The list of ranks over which the current model replica is partitioned.
+- ``smp.get_dp_group()``: The list of ranks that hold different replicas of the same model partition.
+
+**Tensor Parallelism**
+
+- ``smp.tp_rank()``: The rank of the process within its
+  tensor-parallelism group.
+- ``smp.tp_size()``: The size of the tensor-parallelism group.
+- ``smp.get_tp_process_group()``: The
+  ``torch.distributed.ProcessGroup`` that contains the processes in the
+  current tensor-parallelism group.
+- ``smp.CommGroup.TP_GROUP``: The communication group corresponding to
+  the current tensor-parallelism group.
+
+**Pipeline Parallelism**
+
+- ``smp.pp_rank()``: The rank of the process within its
+  pipeline-parallelism group.
+- ``smp.pp_size()``: The size of the pipeline-parallelism group.
+- ``smp.get_pp_process_group()``: The ``torch.distributed.ProcessGroup``
+  that contains the processes in the current pipeline-parallelism group.
+- ``smp.CommGroup.PP_GROUP``: The communication group corresponding to
+  the current pipeline-parallelism group.
+
+**Reduced-Data Parallelism**
+
+- ``smp.rdp_rank()``: The rank of the process within its
+  reduced-data-parallelism group.
+- ``smp.rdp_size()``: The size of the reduced-data-parallelism group.
+- ``smp.get_rdp_process_group()``: The ``torch.distributed.ProcessGroup``
+  that contains the processes in the current reduced-data-parallelism
+  group.
+- ``smp.CommGroup.RDP_GROUP``: The communication group corresponding
+  to the current reduced-data-parallelism group.
+
+**Model Parallelism**
+
+- ``smp.mp_rank()``: The rank of the process within its model-parallelism
+  group.
+- ``smp.mp_size()``: The size of the model-parallelism group.
+- ``smp.get_mp_process_group()``: The ``torch.distributed.ProcessGroup``
+  that contains the processes in the current model-parallelism group.
+- ``smp.CommGroup.MP_GROUP``: The communication group corresponding to
+  the current model-parallelism group.
+
+**Data Parallelism**
+
+- ``smp.dp_rank()``: The rank of the process within its data-parallelism
+  group.
+- ``smp.dp_size()``: The size of the data-parallelism group.
+- ``smp.get_dp_process_group()``: The ``torch.distributed.ProcessGroup``
+  that contains the processes in the current data-parallelism group.
+- ``smp.CommGroup.DP_GROUP``: The communication group corresponding to
+  the current data-parallelism group.
+
+.. _communication_api:

 Communication API
 ^^^^^^^^^^^^^^^^^
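
As a quick illustration (not part of this diff), a minimal sketch of querying the primitives listed above from a PyTorch training script; it assumes the script is launched by SageMaker with the library configured and tensor/pipeline parallelism enabled.

.. code:: python

   # Illustrative sketch: report where the current process sits in each
   # parallelism group. Assumes the SageMaker model parallel library
   # (PyTorch flavor) is enabled for the job.
   import smdistributed.modelparallel.torch as smp

   smp.init()

   if smp.local_rank() == 0:  # print once per instance
       print(
           f"global {smp.rank()}/{smp.size()} | "
           f"tp {smp.tp_rank()}/{smp.tp_size()} | "
           f"pp {smp.pp_rank()}/{smp.pp_size()} | "
           f"dp {smp.dp_rank()}/{smp.dp_size()} | "
           f"rdp {smp.rdp_rank()}/{smp.rdp_size()}"
       )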

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 112 additions & 13 deletions
@@ -1,14 +1,8 @@
-.. admonition:: Contents
-
-   - :ref:`pytorch_saving_loading`
-   - :ref:`pytorch_saving_loading_instructions`
-
 PyTorch API
 ===========

-**Supported versions: 1.7.1, 1.8.1**
-
-This API document assumes you use the following import statements in your training scripts.
+To use the PyTorch-specific APIs for SageMaker distributed model parallelism,
+you need to add the following import statement at the top of your training script.

 .. code:: python

@@ -19,10 +13,10 @@ This API document assumes you use the following import statements in your traini

 Refer to
 `Modify a PyTorch Training Script
-<https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-customize-training-script.html#model-parallel-customize-training-script-pt>`_
+<https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-customize-training-script-pt.html>`_
 to learn how to use the following API in your PyTorch training script.

-.. class:: smp.DistributedModel
+.. py:class:: smp.DistributedModel()

    A sub-class of ``torch.nn.Module`` which specifies the model to be
    partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is
@@ -42,7 +36,6 @@ This API document assumes you use the following import statements in your traini
    is \ ``model``) can only be made inside a ``smp.step``-decorated
    function.

-
    Since ``DistributedModel`` is a ``torch.nn.Module``, a forward pass can
    be performed by calling the \ ``DistributedModel`` object on the input
    tensors.
@@ -56,7 +49,6 @@ This API document assumes you use the following import statements in your traini
    arguments, replacing the PyTorch operations \ ``torch.Tensor.backward``
    or ``torch.autograd.backward``.

-
    The API for ``model.backward`` is very similar to
    ``torch.autograd.backward``. For example, the following
    ``backward`` calls:
@@ -90,7 +82,7 @@ This API document assumes you use the following import statements in your traini

    **Using DDP**

-   If DDP is enabled, do not not place a PyTorch
+   If DDP is enabled with the SageMaker model parallel library, do not place a PyTorch
    ``DistributedDataParallel`` wrapper around the ``DistributedModel`` because
    the ``DistributedModel`` wrapper will also handle data parallelism.

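To connect these pieces, a minimal sketch (not part of this diff) of the training-step pattern the API implies: ``model.backward`` is called inside an ``smp.step``-decorated function in place of ``loss.backward()``. ``Net`` and ``train_loader`` are placeholders, and the microbatch reduction follows the library's ``StepOutput`` pattern.

.. code:: python

   # Hypothetical sketch of the smp.step / model.backward pattern.
   import torch
   import torch.nn.functional as F
   import smdistributed.modelparallel.torch as smp

   @smp.step
   def train_step(model, data, target):
       output = model(data)
       loss = F.nll_loss(output, target, reduction="mean")
       model.backward(loss)        # instead of loss.backward()
       return loss

   smp.init()
   model = smp.DistributedModel(Net())               # Net is a placeholder nn.Module
   optimizer = smp.DistributedOptimizer(
       torch.optim.Adam(model.parameters(), lr=1e-3)  # create the optimizer after wrapping
   )

   for data, target in train_loader:                 # train_loader is a placeholder
       optimizer.zero_grad()
       loss_mb = train_step(model, data, target)     # returns per-microbatch outputs
       loss = loss_mb.reduce_mean()                  # average across microbatches
       optimizer.step()
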
@@ -284,6 +276,113 @@ This API document assumes you use the following import statements in your traini
    `register_comm_hook <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.register_comm_hook>`__
    in the PyTorch documentation.

+   **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism**
+
+   When a model is wrapped by ``smp.DistributedModel``, the library
+   immediately traverses the modules of the model object, and replaces the
+   modules that are supported for tensor parallelism with their distributed
+   counterparts. This replacement happens in place. If there are no other
+   references to the original modules in the script, they are
+   garbage-collected. The module attributes that previously referred to the
+   original submodules now refer to the distributed versions of those
+   submodules.
+
+   **Example:**
+
+   .. code:: python
+
+      # register DistributedSubmodule as the distributed version of Submodule
+      # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist)
+      smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule)
+
+      class MyModule(nn.Module):
+          def __init__(self):
+              ...
+              self.submodule = Submodule()
+              ...
+
+      # enabling tensor parallelism for the entire model
+      with smp.tensor_parallelism():
+          model = MyModule()
+
+      # here model.submodule is still a Submodule object
+      assert isinstance(model.submodule, Submodule)
+
+      model = smp.DistributedModel(model)
+
+      # now model.submodule is replaced with an equivalent instance
+      # of smp.nn.DistributedSubmodule
+      assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule)
+
+   If ``pipeline_parallel_degree`` (equivalently, ``partitions``) is 1, the
+   placement of model partitions into GPUs and the initial broadcast of
+   model parameters and buffers across data-parallel ranks take place
+   immediately, because there is no need to wait for the model partition
+   when the ``smp.DistributedModel`` wrapper is called. For other cases
+   with ``pipeline_parallel_degree`` greater than 1, the broadcast and
+   device placement are deferred until the first call of an
+   ``smp.step``-decorated function, because that first call is when model
+   partitioning happens if pipeline parallelism is enabled.
+
+   Because of the module replacement during the ``smp.DistributedModel``
+   call, any ``load_state_dict`` calls on the model, as well as any direct
+   access to model parameters, such as during the optimizer creation,
+   should be done **after** the ``smp.DistributedModel`` call.
+
+   Since the broadcast of the model parameters and buffers happens
+   immediately during the ``smp.DistributedModel`` call when the degree of
+   pipeline parallelism is 1, using ``@smp.step`` decorators is not
+   required when tensor parallelism is used by itself (without pipeline
+   parallelism).
+
+   For more information about the library's tensor parallelism APIs for PyTorch,
+   see :ref:`smdmp-pytorch-tensor-parallel`.
+
+   **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism**
+
+   The following are the new methods of ``smp.DistributedModel``, in
+   addition to the ones listed in the
+   `documentation <https://sagemaker.readthedocs.io/en/stable/api/training/smp_versions/v1.2.0/smd_model_parallel_pytorch.html#smp.DistributedModel>`__.
+
+   .. function:: distributed_modules()
+
+      - An iterator that runs over the set of distributed
+        (tensor-parallelized) modules in the model.
+
+   .. function:: is_distributed_parameter(param)
+
+      - Returns ``True`` if the given ``nn.Parameter`` is distributed over
+        tensor-parallel ranks.
+
+   .. function:: is_distributed_buffer(buf)
+
+      - Returns ``True`` if the given buffer is distributed over
+        tensor-parallel ranks.
+
+   .. function:: is_scaled_batch_parameter(param)
+
+      - Returns ``True`` if the given ``nn.Parameter`` operates on the
+        scaled batch (batch over the entire ``TP_GROUP``, and not only the
+        local batch).
+
+   .. function:: is_scaled_batch_buffer(buf)
+
+      - Returns ``True`` if the parameter corresponding to the given
+        buffer operates on the scaled batch (batch over the entire
+        ``TP_GROUP``, and not only the local batch).
+
+   .. function:: default_reducer_named_parameters()
+
+      - Returns an iterator that runs over ``(name, param)`` tuples, for
+        ``param`` that is allreduced over the ``DP_GROUP``.
+
+   .. function:: scaled_batch_reducer_named_parameters()
+
+      - Returns an iterator that runs over ``(name, param)`` tuples, for
+        ``param`` that is allreduced over the ``RDP_GROUP``.
+


 .. class:: smp.DistributedOptimizer
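
As a usage illustration (not part of this diff), a minimal sketch that exercises the tensor-parallelism methods listed above after wrapping a model. ``MyModule`` is a placeholder, and which submodules actually get distributed depends on the modules the library supports and the job configuration.

.. code:: python

   # Hypothetical sketch: inspect what was tensor-parallelized after wrapping.
   import torch.nn as nn
   import smdistributed.modelparallel.torch as smp

   smp.init()

   with smp.tensor_parallelism():       # request tensor parallelism for supported submodules
       model = MyModule()               # MyModule is a placeholder nn.Module

   model = smp.DistributedModel(model)  # supported modules are replaced here

   # Iterate over the modules that were replaced with distributed counterparts.
   for dist_module in model.distributed_modules():
       print(type(dist_module).__name__)

   # Count parameters that are sharded across the tensor-parallel group.
   num_dist = sum(1 for p in model.parameters() if model.is_distributed_parameter(p))
   print(f"{num_dist} parameters are distributed over TP_GROUP")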
