@@ -24,10 +24,12 @@ The following SageMaker distributed model parallel APIs are common across all frameworks
.. function:: smp.init()
+   :noindex:

   Initializes the library. Must be called at the beginning of the training script.

.. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs])
+   :noindex:

   A decorator that must be placed over a function that represents a single
   forward and backward pass (for training use cases), or a single forward
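A minimal sketch of how these two entries fit together in a TF2 training script (``model``, ``loss_object``, and ``images``/``labels`` are assumed to be defined elsewhere; this is illustrative, not the library's canonical example)::

   import tensorflow as tf
   import smdistributed.modelparallel.tensorflow as smp

   smp.init()  # must run before any other smp call in the script

   @smp.step
   def train_step(images, labels):
       # One forward and backward pass; the library splits the inputs
       # into microbatches along the batch dimension.
       with tf.GradientTape() as tape:
           predictions = model(images, training=True)
           loss = loss_object(labels, predictions)
       grads = tape.gradient(loss, model.trainable_variables)
       return grads, loss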
@@ -162,6 +164,7 @@ The following SageMaker distributed model parallel APIs are common across all frameworks
.. class:: StepOutput
+   :noindex:

   A class that encapsulates all versions of a ``tf.Tensor``
@@ -191,27 +194,32 @@ The following SageMaker distributed model parallel APIs are common across all frameworks
   post-processing operations on tensors.

.. data:: StepOutput.outputs
+   :noindex:

   Returns a list of the underlying tensors, indexed by microbatch.

.. function:: StepOutput.reduce_mean()
+   :noindex:

   Returns a ``tf.Tensor``/``torch.Tensor`` that averages the constituent
   ``tf.Tensor``\ s/``torch.Tensor``\ s. This is commonly used for averaging
   loss and gradients across microbatches.

.. function:: StepOutput.reduce_sum()
+   :noindex:

   Returns a ``tf.Tensor``/``torch.Tensor`` that sums the constituent
   ``tf.Tensor``\ s/``torch.Tensor``\ s.

.. function:: StepOutput.concat()
+   :noindex:

   Returns a ``tf.Tensor``/``torch.Tensor`` that concatenates tensors along
   the batch dimension using ``tf.concat``/``torch.cat``.

.. function:: StepOutput.stack()
+   :noindex:

   Applies the ``tf.stack``/``torch.stack``
   operation to the list of constituent ``tf.Tensor``\ s /
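A rough sketch of how these accessors compose with the hypothetical ``train_step`` above, where each returned value arrives wrapped as a ``StepOutput``::

   grads, loss = train_step(images, labels)

   per_microbatch = loss.outputs   # plain list, one tensor per microbatch
   mean_loss = loss.reduce_mean()  # single tensor averaged across microbatches
   total_loss = loss.reduce_sum()  # single tensor summed across microbatches
   stacked = loss.stack()          # adds a leading microbatch axis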
@@ -220,13 +228,15 @@ The following SageMaker distributed model parallel APIs are common across all frameworks
**TensorFlow-only methods**

.. function:: StepOutput.merge()
+   :noindex:

   Returns a ``tf.Tensor`` that
   concatenates the constituent ``tf.Tensor``\ s along the batch
   dimension. This is commonly used for merging the model predictions
   across microbatches.

.. function:: StepOutput.accumulate(method="variable", var=None)
+   :noindex:

   Functionally the same as ``StepOutput.reduce_mean()``. However, it is
   more memory-efficient, especially for large numbers of microbatches,
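For gradient averaging specifically, the memory-efficient variant might look like the following sketch (assuming ``train_step`` returns a list of per-variable gradients, so each element arrives as its own ``StepOutput``)::

   grads, loss = train_step(images, labels)
   # Same value as [g.reduce_mean() for g in grads], but accumulates into a
   # variable as each microbatch finishes instead of keeping every version.
   avg_grads = [g.accumulate(method="variable") for g in grads]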
@@ -252,6 +262,7 @@ The following SageMaker distributed model parallel APIs are common across all frameworks
   ignored.

.. _mpi_basics:
+   :noindex:

MPI Basics
^^^^^^^^^^
@@ -275,6 +286,7 @@ The library exposes the following basic MPI primitives to its Python API:
   replicas of the same model partition.

.. _communication_api:
+   :noindex:

Communication API
^^^^^^^^^^^^^^^^^
@@ -288,6 +300,7 @@ should involve.
**Helper structures**

.. data:: smp.CommGroup
+   :noindex:

   An ``enum`` that takes the values
   ``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``.
@@ -306,6 +319,7 @@ should involve.
   themselves.

.. data:: smp.RankType
+   :noindex:

   An ``enum`` that takes the values
   ``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``.
@@ -321,6 +335,7 @@ should involve.
**Communication primitives:**

.. function:: smp.broadcast(obj, group)
+   :noindex:

   Sends the object to all processes in the
   group. The receiving process must call ``smp.recv_from`` to receive the
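A sketch of the sending side, paired with the receive shown in the next hunk's context (``some_object`` is a placeholder for any picklable Python object)::

   if smp.rank() == 0:
       smp.broadcast(some_object, smp.CommGroup.WORLD)
   else:
       some_object = smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)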
@@ -353,6 +368,7 @@ should involve.
       smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)

.. function:: smp.send(obj, dest_rank, rank_type)
+   :noindex:

   Sends the object ``obj`` to
   ``dest_rank``, which is of a type specified by ``rank_type``.
@@ -376,6 +392,7 @@ should involve.
   ``recv_from`` call.

.. function:: smp.recv_from(src_rank, rank_type)
+   :noindex:

   Receives an object from a peer process. Can be used with a matching
   ``smp.send`` or a ``smp.broadcast`` call.
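A point-to-point sketch pairing the two calls (``obj`` is again a placeholder)::

   if smp.rank() == 0:
       smp.send(obj, 1, smp.RankType.WORLD_RANK)   # matched by the recv below
   elif smp.rank() == 1:
       obj = smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)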
@@ -401,6 +418,7 @@ should involve.
   ``broadcast`` call, and the object is received.

.. function:: smp.allgather(obj, group)
+   :noindex:

   A collective call that gathers all the
   submitted objects across all ranks in the specified ``group``. Returns a
@@ -434,6 +452,7 @@ should involve.
       out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2]

.. function:: smp.barrier(group=smp.WORLD)
+   :noindex:

   A statement that blocks until all
   processes in the specified group reach the barrier statement, similar to
@@ -455,12 +474,14 @@ should involve.
   processes outside that ``mp_group``.

.. function:: smp.dp_barrier()
+   :noindex:

   Same as passing ``smp.DP_GROUP`` to ``smp.barrier()``.
   Waits for the processes in the same ``dp_group`` as
   the current process to reach the same point in execution.

.. function:: smp.mp_barrier()
+   :noindex:

   Same as passing ``smp.MP_GROUP`` to
   ``smp.barrier()``. Waits for the processes in the same ``mp_group`` as
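A sketch of how the three barrier forms line up; each call blocks the current process until the named group catches up::

   smp.barrier()                   # waits on smp.WORLD by default
   smp.barrier(smp.DP_GROUP)       # equivalent to the next line
   smp.dp_barrier()                # waits on this process's dp_group
   smp.mp_barrier()                # waits on this process's mp_group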