Common API
==========

The following SageMaker distributed model parallel APIs are common across all frameworks.

.. contents:: Table of Contents
   :depth: 3
   :local:

The Library's Core APIs
-----------------------

This API document assumes you use the following import statement in your training scripts.

**TensorFlow**

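A minimal sketch of the import convention referred to here, assuming the
library's standard TensorFlow module path
(``smdistributed.modelparallel.tensorflow``) and a script running as a
SageMaker distributed model parallel training job:

.. code:: python

   # Import the library's TensorFlow API under the ``smp`` alias used
   # throughout this document.
   import smdistributed.modelparallel.tensorflow as smp

   # Initialize the library before calling any of its other APIs.
   smp.init()
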
.. _mpi_basics:

MPI Basics
----------

The library exposes the following basic MPI primitives to its Python API
(a short usage sketch follows these lists):

**Global**

- ``smp.rank()``: The global rank of the current process.
- ``smp.size()``: The total number of processes.
- ``smp.get_world_process_group()``: The ``torch.distributed.ProcessGroup`` that contains all processes.
- ``smp.CommGroup.WORLD``: The communication group corresponding to all processes.
- ``smp.local_rank()``: The rank among the processes on the current instance.
- ``smp.local_size()``: The total number of processes on the current instance.
- ``smp.get_mp_group()``: The list of ranks over which the current model replica is partitioned.
- ``smp.get_dp_group()``: The list of ranks that hold different replicas of the same model partition.

**Tensor Parallelism**

- ``smp.tp_rank()``: The rank of the process within its tensor-parallelism group.
- ``smp.tp_size()``: The size of the tensor-parallelism group.
- ``smp.get_tp_process_group()``: The ``torch.distributed.ProcessGroup`` that contains the processes in the current tensor-parallelism group.
- ``smp.CommGroup.TP_GROUP``: The communication group corresponding to the current tensor-parallelism group.

**Pipeline Parallelism**

- ``smp.pp_rank()``: The rank of the process within its pipeline-parallelism group.
- ``smp.pp_size()``: The size of the pipeline-parallelism group.
- ``smp.get_pp_process_group()``: The ``torch.distributed.ProcessGroup`` that contains the processes in the current pipeline-parallelism group.
- ``smp.CommGroup.PP_GROUP``: The communication group corresponding to the current pipeline-parallelism group.

**Reduced-Data Parallelism**

- ``smp.rdp_rank()``: The rank of the process within its reduced-data-parallelism group.
- ``smp.rdp_size()``: The size of the reduced-data-parallelism group.
- ``smp.get_rdp_process_group()``: The ``torch.distributed.ProcessGroup`` that contains the processes in the current reduced-data-parallelism group.
- ``smp.CommGroup.RDP_GROUP``: The communication group corresponding to the current reduced-data-parallelism group.

**Model Parallelism**

- ``smp.mp_rank()``: The rank of the process within its model-parallelism group.
- ``smp.mp_size()``: The size of the model-parallelism group.
- ``smp.get_mp_process_group()``: The ``torch.distributed.ProcessGroup`` that contains the processes in the current model-parallelism group.
- ``smp.CommGroup.MP_GROUP``: The communication group corresponding to the current model-parallelism group.

**Data Parallelism**

- ``smp.dp_rank()``: The rank of the process within its data-parallelism group.
- ``smp.dp_size()``: The size of the data-parallelism group.
- ``smp.get_dp_process_group()``: The ``torch.distributed.ProcessGroup`` that contains the processes in the current data-parallelism group.
- ``smp.CommGroup.DP_GROUP``: The communication group corresponding to the current data-parallelism group.

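The following is a minimal sketch of how these primitives are typically
queried, not an excerpt from the library's own examples. It assumes the
PyTorch variant of the library and a training job launched with model and
data parallelism enabled.

.. code:: python

   # Minimal sketch (assumes the PyTorch variant of the library, running as
   # a SageMaker distributed model parallel training job).
   import smdistributed.modelparallel.torch as smp

   smp.init()  # must run before any rank/size primitive is queried

   # Report where this process sits in each group.
   print(
       f"global rank {smp.rank()}/{smp.size()}, "
       f"local rank {smp.local_rank()}/{smp.local_size()}, "
       f"mp rank {smp.mp_rank()}/{smp.mp_size()}, "
       f"dp rank {smp.dp_rank()}/{smp.dp_size()}"
   )

   # A common pattern: only one data-parallel rank per partition does
   # logging or checkpointing, so replicas do not duplicate the work.
   if smp.dp_rank() == 0:
       print("ranks holding my model replica:", smp.get_mp_group())
       print("ranks holding my partition across replicas:", smp.get_dp_group())

   # The ``get_*_process_group()`` variants return
   # ``torch.distributed.ProcessGroup`` objects, so they can be passed to
   # ``torch.distributed`` collectives, for example:
   # torch.distributed.all_reduce(tensor, group=smp.get_dp_process_group())
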
.. _communication_api:

Communication API
-----------------

The library provides a few communication primitives which can be helpful while
developing the training script. These primitives use the following