From a46ced7e6d6d98a0e9c03848543c9ab0e53d7b80 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 22 Jun 2022 20:08:31 -0700 Subject: [PATCH 01/23] archive doc for past versions --- doc/api/training/smp_versions/archives.rst | 1 + doc/api/training/smp_versions/latest.rst | 2 +- .../v1.9.0/smd_model_parallel_common_api.rst | 538 +++++++++++ .../v1.9.0/smd_model_parallel_pytorch.rst | 678 ++++++++++++++ ...model_parallel_pytorch_tensor_parallel.rst | 875 ++++++++++++++++++ .../v1.9.0/smd_model_parallel_tensorflow.rst | 171 ++++ doc/api/training/smp_versions/v1_9_0.rst | 13 + 7 files changed, 2277 insertions(+), 1 deletion(-) create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst create mode 100644 doc/api/training/smp_versions/v1_9_0.rst diff --git a/doc/api/training/smp_versions/archives.rst b/doc/api/training/smp_versions/archives.rst index fe893928ef..8c87476e99 100644 --- a/doc/api/training/smp_versions/archives.rst +++ b/doc/api/training/smp_versions/archives.rst @@ -3,6 +3,7 @@ .. toctree:: :maxdepth: 1 + v1_9_0.rst v1_6_0.rst v1_5_0.rst v1_4_0.rst diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 49085d9347..ee606b8c34 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,7 +10,7 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 (Latest) +Version 1.10.0 (Latest) =========================================== To use the library, reference the Common API documentation alongside the framework specific API documentation. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst new file mode 100644 index 0000000000..b4713b2707 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst @@ -0,0 +1,538 @@ +Common API +========== + +The following SageMaker distribute model parallel APIs are common across all frameworks. + +.. contents:: Table of Contents + :depth: 3 + :local: + +The Library's Core APIs +----------------------- + +This API document assumes you use the following import statement in your training scripts. + +**TensorFlow** + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +**PyTorch** + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. function:: smp.init( ) + :noindex: + + Initialize the library. Must be called at the beginning of training script. + +.. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs]) + :noindex: + + A decorator that must be placed over a function that represents a single + forward and backward pass (for training use cases), or a single forward + pass (for evaluation use cases). Any computation that is defined inside + the ``smp.step``-decorated function is executed in a pipelined manner. + + By default, every tensor input to the function is split across its batch + dimension into a number of microbatches specified while launching the + training job. This behavior can be customized through the arguments to + ``smp.step``, described below. The library then orchestrates the execution of + each microbatch across all partitions, based on the chosen pipeline + type. + + In a typical use case, forward pass and back-propagation are executed + inside an \ ``smp.step``-decorated function and gradients, loss, and + other relevant metrics (such as accuracy, etc.) are returned from + ``smp.step``-decorated function. + + Any gradient post-processing operation, such as gradient clipping and + allreduce, as well as ``optimizer.apply_gradients`` calls (for TF) or + ``optimizer.step`` (for PT) should be applied on the gradients returned + from the ``smp.step`` function, and not inside the ``smp.step`` + function. This is because every operation inside ``smp.step`` is + executed once per microbatch, so having these operations inside + ``smp.step`` can either be inefficient (in the case of allreduce), or + lead to wrong results (in the case of ``apply_gradients`` / + ``optimizer.step``). + + If the objects returned from the ``smp.step``-decorated function contain + ``tf.Tensor``\ s / ``torch.Tensor``\ s, they are converted to + ``StepOutput`` objects. A ``StepOutput`` object encapsulates all + versions of the tensor across different microbatches + (see ``StepOutput`` entry for more information). + + The argument to ``smp.step`` decorated function should either be a tensor + or an instance of list, tuple, dict or set for it to be split across + microbatches. If your object doesn't fall into this category, you can make + the library split your object, by implementing ``smp_slice`` method. + + Below is an example of how to use it with PyTorch. + + .. code:: python + + class CustomType: + def __init__(self, tensor): + self.data = tensor + + # The library will call this to invoke slicing on the object passing in total microbatches (num_mb) + # and the current microbatch index (mb). + def smp_slice(self, num_mb, mb, axis): + dim_size = list(self.data.size())[axis] + + split_size = dim_size // num_mb + sliced_tensor = self.data.narrow(axis, mb * split_size, split_size) + return CustomType(sliced_tensor, self.other) + + custom_obj = CustomType(torch.ones(4,)) + + @smp.step() + def step(custom_obj): + loss = model(custom_obj) + model.backward(loss) + return loss + + + **Important:** ``smp.step`` splits the batch into microbatches, and + executes everything inside the decorated function once per microbatch. + This might affect the behavior of batch normalization, any operation + that explicitly uses the batch size information, or any other Python + code that is expected to run once. + + **TensorFlow-specific behavior** + + ``smp.step`` is a wrapper that + inherits from and extends the behavior of ``tf.function``, and as such, + all the caveats that apply to the use of ``tf.function``\ s also apply + to ``smp.step``. In particular, any operation that is inside + ``smp.step`` executes in graph mode, and not eager mode. + + In the first call, ``smp.step`` performs tracing of the wrapped function every time + one of the tensor arguments changes their shape or dtype, or for every + new value of a Python argument, if there is one. Tracing is expensive, + so such scenarios should be avoided as much as possible or, + alternatively, an ``input_signature`` argument must be provided. For + more information on the usage of ``tf.function``, refer to the + TensorFlow documentation: + + - https://www.tensorflow.org/api_docs/python/tf/function\ + - https://www.tensorflow.org/guide/function\ + + Each ``smp.step`` decorated function must have a return value that depends on the + output of ``smp.DistributedModel``. + + **Common parameters** + + - ``non_split_inputs`` (``list``): The list of arguments to the decorated function + that should not be split along the batch dimension. Should be used + for all input tensors that do not have a batch dimension. Should be a + list of argument names as ``str``, as they appear in the signature of + the ``smp.step``-decorated function. By default it is considered an + empty list. + + - ``input_split_axes`` (``dict``): A dict that maps the argument name to its batch + axis. The keys should be the argument names as ``str``, as they + appear in the signature of the ``smp.step``-decorated function.  By + default all batch axes are assumed to be the 0-axis. + + **TensorFlow-only parameters** + + - All arguments of ``tf.function``. Note: + The \ ``experimental_compile`` argument of ``tf.function`` may not + work as expected with ``smp.step``, since it interferes with + pipelining and model partitioning. To enable XLA with the library, you can + instead use \ ``tf.config.optimizer.set_jit(True)``. + + **PyTorch-only parameters** + + - ``detach_outputs`` (``bool``) : If ``True``, calls ``torch.Tensor.detach()`` on + all returned ``torch.Tensor`` outputs. Setting it to ``False`` + increases memory consumption, unless ``detach()`` is manually called + on the returned tensors, because the model graph is not cleared from + memory after the training step. Set to \ ``True`` by default. + + **Returns** + + - The same object(s) returned from the decorated function. All + returned \ ``tf.Tensor``, \ ``tf.Variable``  objects (for TF) or + ``torch.Tensor`` objects (for PT) are wrapped inside + a \ ``StepOutput`` object, even when they are inside a Python + ``list``, ``tuple``, or ``dict``. + + + +.. class:: StepOutput + :noindex: + + + A class that encapsulates all versions of a ``tf.Tensor`` + or \ ``torch.Tensor`` across all microbatches. + + When a particular ``tf.Tensor`` or ``torch.Tensor`` is computed inside + ``smp.step``, different versions of the tensor are computed for each + microbatch. + + When this tensor is returned from ``smp.step`` and is accessed outside + of the decorated function, it appears as a ``StepOutput`` object, which + contains all such versions. For example, + + - In the case of Tensorflow, the gradient for a particular + ``tf.Variable`` is computed on each microbatch individually, and if + this gradient is returned from ``smp.step``, all gradients for this + ``tf.Variable`` become part of the same ``StepOutput`` object. The + ``StepOutput`` class offers the following API for commonly-used + post-processing operations on such tensors. + - In the case of PyTorch, the loss for each microbatch is computed + individually and all the ``torch.Tensor``\ s that represent the loss + for different microbatches become part of same ``StepOutput`` object, + if loss is returned from the ``smp.step`` function. + + + The ``StepOutput`` class offers the following API for commonly-used + post-processing operations on tensors. + + .. data:: StepOutput.outputs + :noindex: + + Returns a list of the underlying tensors, indexed by microbatch. + + .. function:: StepOutput.reduce_mean( ) + :noindex: + + Returns a ``tf.Tensor``, ``torch.Tensor`` that averages the constituent ``tf.Tensor`` s + ``torch.Tensor`` s. This is commonly used for averaging loss and gradients across microbatches. + + .. function:: StepOutput.reduce_sum( ) + :noindex: + + Returns a ``tf.Tensor`` / + ``torch.Tensor`` that sums the constituent + ``tf.Tensor``\ s/\ ``torch.Tensor``\ s. + + .. function:: StepOutput.concat( ) + :noindex: + + Returns a + ``tf.Tensor``/``torch.Tensor`` that concatenates tensors along the + batch dimension using ``tf.concat`` / ``torch.cat``. + + .. function:: StepOutput.stack( ) + :noindex: + + Applies ``tf.stack`` / ``torch.stack`` + operation to the list of constituent ``tf.Tensor``\ s / + ``torch.Tensor``\ s. + + **TensorFlow-only methods** + + .. function:: StepOutput.merge( ) + :noindex: + + Returns a ``tf.Tensor`` that + concatenates the constituent ``tf.Tensor``\ s along the batch + dimension. This is commonly used for merging the model predictions + across microbatches. + + .. function:: StepOutput.accumulate(method="variable", var=None) + :noindex: + + Functionally the same as ``StepOutput.reduce_mean()``. However, it is + more memory-efficient, especially for large numbers of microbatches, + since it does not wait for all constituent \ ``tf.Tensor``\ s to be + ready to start averaging them, thereby saving memory. + + In some cases (XLA for example) ``StepOutput.reduce_mean()`` might end + up being more memory-efficient than ``StepOutput.accumulate()``. + + **Parameters** + + - ``method`` (``"add_n"`` or ``"accumulate_n"`` or ``"variable"``): + If ``"add_n"`` or ``"accumulate_n"``, the library uses + ``tf.add_n`` and ``tf.accumulate_n``, respectively, to implement + accumulation. If ``"variable"``, the library uses an internal ``tf.Variable`` + into which to accumulate the tensors. Default is \ ``"variable"``. + Note: Memory usage behavior of these choices can depend on the model + and implementation. + + - ``var``: A ``tf.Variable`` into which, if provided, the library uses to + accumulate the tensors. If \ ``None``, the library internally creates a + variable. If ``method`` is not ``"variable"``, this argument is + ignored. + +.. _mpi_basics: + :noindex: + +MPI Basics +---------- + +The library exposes the following basic MPI primitives to its Python API: + +**Global** + +- ``smp.rank()`` : The global rank of the current process. +- ``smp.size()`` : The total number of processes. +- ``smp.get_world_process_group()`` : + ``torch.distributed.ProcessGroup`` that contains all processes. +- ``smp.CommGroup.WORLD``: The communication group corresponding to all processes. +- ``smp.local_rank()``: The rank among the processes on the current instance. +- ``smp.local_size()``: The total number of processes on the current instance. +- ``smp.get_mp_group()``: The list of ranks over which the current model replica is partitioned. +- ``smp.get_dp_group()``: The list of ranks that hold different replicas of the same model partition. + +**Tensor Parallelism** + +- ``smp.tp_rank()`` : The rank of the process within its + tensor-parallelism group. +- ``smp.tp_size()`` : The size of the tensor-parallelism group. +- ``smp.get_tp_process_group()`` : Equivalent to + ``torch.distributed.ProcessGroup`` that contains the processes in the + current tensor-parallelism group. +- ``smp.CommGroup.TP_GROUP`` : The communication group corresponding to + the current tensor parallelism group. + +**Pipeline Parallelism** + +- ``smp.pp_rank()`` : The rank of the process within its + pipeline-parallelism group. +- ``smp.pp_size()`` : The size of the pipeline-parallelism group. +- ``smp.get_pp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current pipeline-parallelism group. +- ``smp.CommGroup.PP_GROUP`` : The communication group corresponding to + the current pipeline parallelism group. + +**Reduced-Data Parallelism** + +- ``smp.rdp_rank()`` : The rank of the process within its + reduced-data-parallelism group. +- ``smp.rdp_size()`` : The size of the reduced-data-parallelism group. +- ``smp.get_rdp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current reduced data parallelism + group. +- ``smp.CommGroup.RDP_GROUP`` : The communication group corresponding + to the current reduced data parallelism group. + +**Model Parallelism** + +- ``smp.mp_rank()`` : The rank of the process within its model-parallelism + group. +- ``smp.mp_size()`` : The size of the model-parallelism group. +- ``smp.get_mp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current model-parallelism group. +- ``smp.CommGroup.MP_GROUP`` : The communication group corresponding to + the current model parallelism group. + +**Data Parallelism** + +- ``smp.dp_rank()`` : The rank of the process within its data-parallelism + group. +- ``smp.dp_size()`` : The size of the data-parallelism group. +- ``smp.get_dp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current data-parallelism group. +- ``smp.CommGroup.DP_GROUP`` : The communication group corresponding to + the current data-parallelism group. + +.. _communication_api: + :noindex: + +Communication API +----------------- + +The library provides a few communication primitives which can be helpful while +developing the training script. These primitives use the following +``enum`` s as arguments to specify which processes the communication +should involve. +​ + +**Helper structures** + +.. data:: smp.CommGroup + :noindex: + + An ``enum`` that takes the values + ``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``. + These values can also be accessed as ``smp.WORLD``, ``smp.MP_GROUP``, + and ``smp.DP_GROUP`` respectively. + + - ``CommGroup.WORLD``: Represents the entire group of processes used in + training + - ``CommGroup.MP_GROUP``: Represents the group of processes that hold + the same model replica as the current process. The processes in a + single ``MP_GROUP`` collectively store an entire replica of the + model. + - ``CommGroup.DP_GROUP``: Represents the group of processes that hold + the same model partition as the current process. The processes in a + single ``DP_GROUP`` perform data parallelism/allreduce among + themselves. + +.. data:: smp.RankType + :noindex: + + An ``enum`` that takes the values + ``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``. + + - ``RankType.WORLD_RANK``: The associated rank is to be interpreted as + the rank of the process across all processes used in training. + - ``RankType.MP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``MP_GROUP``. + - ``RankType.DP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``DP_GROUP``. + + +**Communication primitives:** + +.. function:: smp.broadcast(obj, group) + :noindex: + + Sends the object to all processes in the + group. The receiving process must call ``smp.recv_from`` to receive the + sent object. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be broadcast. + + - ``group``: A ``CommGroup`` argument that represents to which group of + processes the object will be sent. + + **Notes** + + - When you use ``broadcast`` on the sender process, there needs + to be an accompanying ``smp.recv_from()`` call on the receiver + processes. + + - This is a synchronous call; the ``broadcast`` statement + returns only after all ranks participating in the call have made a + matching ``recv_from`` call. + + **Example** + + .. code:: python + + if smp.rank() == 0: +     smp.broadcast(something, group=smp.CommGroup.WORLD) + else: +     smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK) + +.. function:: smp.send(obj, dest_rank, rank_type) + :noindex: + + Sends the object ``obj`` to + ``dest_rank``, which is of a type specified by ``rank_type``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be sent. + + - ``dest_rank`` (``int``): An integer denoting the rank of the receiving process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. For example if ``dest_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then ``obj`` is sent to process + with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the current + process. + + **Notes** + + - Note: \ This is a synchronous call; the ``send`` statement returns + only after the destination rank has made a matching + ``recv_from`` call. + +.. function:: smp.recv_from(src_rank, rank_type) + :noindex: + + Receive an object from a peer process. Can be used with a matching + ``smp.send`` or a ``smp.broadcast`` call. + + **Inputs** + + - ``src_rank`` (``int``): An integer denoting rank of the sending process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. For example if ``src_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then the object is received from + the process with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the + current process. + + **Returns** + + Returns the python object that is sent by the peer process. + + **Notes** + + - Note: This is a synchronous call; the ``recv_from`` statement returns + only after the source rank has made a matching ``send`` or + ``broadcast`` call, and the object is received. + +.. function:: smp.allgather(obj, group) + :noindex: + + A collective call that gathers all the + submitted objects across all ranks in the specified ``group``. Returns a + list whose ``i``\ th index contains the object submitted by the + ``i``\ th rank in ``group``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be + allgathered. + + - ``group`` : A ``CommGroup`` argument that represents which group of + processes participate in ``allgather``. + + **Notes** + + - Note: This is a synchronous call; the ``allgather`` statement returns + only after all ranks participating in the call have made a matching + ``allgather`` call, and all the objects are received at the current + rank. + + **Examples** + + .. code:: python + + # assuming mp_size() == 2 + + if smp.mp_rank() == 0: +     out = smp.allgather(obj1, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + else: +     out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + +.. function:: smp.barrier(group=smp.WORLD) + :noindex: + + A statement that hangs until all + processes in the specified group reach the barrier statement, similar to + ``MPI_Barrier()``. + + **Inputs** + + - ``group``: An ``smp.CommGroup`` ``enum`` that specifies the group of + processes participating in the barrier call. Defaults to + ``smp.WORLD``. + + **Examples** + + - Assume there are 8 processes and 2 model partitions, and + therefore 4 \ ``mp_group``\ s, and 2 ``dp_group``\ s. If + the \ ``barrier`` call is passed the value ``smp.MP_GROUP`` for its + group argument, then each process only waits until the other process + of its own ``mp_group`` reaches that point. It does not wait for + processes outside that ``mp_group``. + +.. function:: smp.dp_barrier() + :noindex: + + Same as passing ``smp.DP_GROUP``\ to ``smp.barrier()``. + Waits for the processes in the same \ ``dp_group`` as + the current process to reach the same point in execution. + +.. function:: smp.mp_barrier() + :noindex: + + Same as passing ``smp.MP_GROUP`` to + ``smp.barrier()``. Waits for the processes in the same ``mp_group`` as + the current process to reach the same point in execution. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst new file mode 100644 index 0000000000..055f2b6dde --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst @@ -0,0 +1,678 @@ +PyTorch API +=========== + +To use the PyTorch-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. tip:: + + Refer to + `Modify a PyTorch Training Script + `_ + to learn how to use the following API in your PyTorch training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of ``torch.nn.Module`` which specifies the model to be + partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is + the model to be partitioned. The returned ``DistributedModel`` object + internally manages model parallelism and data parallelism. Only one + model in the training script can be wrapped with + ``smp.DistributedModel``. + + **Example:** + + .. code:: python + + model = smp.DistributedModel(model) + + **Important**: The ``__call__`` and  ``backward`` method calls on the + ``smp.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smp.step``-decorated + function. + + Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can + be performed by calling the \ ``DistributedModel`` object on the input + tensors. + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + For a backward pass, one needs to call the backward function on + the \ ``DistributedModel`` object, with tensors and gradients as + arguments, replacing the PyTorch operations \ ``torch.Tensor.backward`` + or ``torch.autograd.backward``. + + The API for ``model.backward`` is very similar to + ``torch.autograd.backward``. For example, the following + ``backward`` calls: + + .. code:: python + + torch.autograd.backward(loss) or loss.backward() + + should be replaced with: + + .. code:: python + + model.backward(loss) # loss is a tensor with only one element as its data + + Similarly, for non-scalar tensors, replace the following + ``backward`` call containing incoming gradient arguments: + + .. code:: python + + torch.autograd.backward(outputs, out_grads) + + with the following line: + + .. code:: python + + model.backward(outputs, out_grads) + + In these examples, all ``__call__``  and ``backward`` method calls on + the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside + a ``smp.step``-decorated function. + + **Using DDP** + + If DDP is enabled with the SageMaker model parallel library, do not not place a PyTorch + ``DistributedDataParallel`` wrapper around the ``DistributedModel`` because + the ``DistributedModel`` wrapper will also handle data parallelism. + + Unlike the original DDP wrapper, when you use ``DistributedModel``, + model parameters and buffers are not immediately broadcast across + processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the + ``smp.step``-decorated function when the partition is done. + + **Parameters** + + - ``module`` (``torch.nn.Module``): Module to be distributed (data parallelism and model parallelism). + + - ``trace_device`` (``"cpu"`` or ``"gpu"``) (default: ``"gpu"``) + Whether to perform the tracing step on the GPU or CPU. The tracing step gathers + information on the order of execution of modules, the shapes of + intermediate outputs, and execution times, to be used by the + partitioning algorithm. If ``trace_device`` is set to GPU, accurate + module execution times can be gathered during tracing for potentially + improved partitioning decision. However, if the model is too large to + fit in a single GPU, then ``trace_device`` should be set to CPU. + + - ``trace_execution_times`` (``bool``) (default: ``False``): If ``True``, + the library profiles the execution time of each module during tracing, and uses + it in the partitioning decision. This improves the partitioning + decision, but it might make the tracing slower. It may also introduce + some degree of non-determinism in partitioning results, because of the + inherent randomness in module execution times. Must be ``False`` if + ``trace_device`` is ``"cpu"``. + + - ``overlapping_allreduce`` (``bool``) (default: ``True``): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` while launching training). The library uses this flag + to decide whether to do overlapping allreduce whenever a parameter + gradients are ready. This leads to overlapping of communication and + computation and can improve performance. If this is set to ``False`` , + allreduce is performed at the end of the step. + + - ``backward_passes_per_step`` (``int``) (default: 1): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` in config). This parameter indicates the + number of backward passes to perform before calling allreduce on DDP. + This allows accumulating updates over multiple mini-batches before + reducing and applying them. + + - ``average_grads_across_microbatches`` (``bool``) (default: ``True``): + Whether or not the computed gradients should be averaged across + microbatches. If ``False``, the computed gradients will be summed across + microbatches, but not divided by the number of microbatches. In typical + use case where the computed loss is averaged over the mini-batch, this + should be left as ``True``. If you use a loss function that only sums + the per-sample loss across the batch (and not divide by the batch size), + then this must be set to ``False`` for correctness. + + - ``bucket_cap_mb`` (default: 25): \ ``DistributedDataParallel`` buckets + parameters into multiple buckets so that gradient reduction of each + bucket can potentially overlap with backward + computation. \ ``bucket_cap_mb``\ controls the bucket size in MegaBytes + (MB). + + - ``trace_memory_usage`` (default: False): When set to True, the library attempts + to measure memory usage per module during tracing. If this is disabled, + memory usage will be estimated through the sizes of tensors returned from + the module. + + - ``broadcast_buffers`` (default: True): Flag to be used with ``ddp=True``. + This parameter is forwarded to the underlying ``DistributedDataParallel`` wrapper. + Please see: `broadcast_buffer `__. + + - ``gradient_as_bucket_view`` (default: False): To be + used with ``ddp=True``. This parameter is forwarded to the underlying + ``DistributedDataParallel`` wrapper. Please see `gradient_as_bucket_view `__. + + **Properties** + + - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` + otherwise. Initialized to ``False`` when ``DistributedModel`` is first + created. It becomes be ``True`` during the first call + to ``smp.step``-decorated function. Once the model is partitioned, the + local parameters or local ``state_dict`` can be fetched using the + following methods. + + **Methods** + + .. function:: backward(tensors, grad_tensors) + :noindex: + + Triggers a distributed backward + pass across model partitions. Example usage provided in the previous + section. The API is very similar + to https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward. + ``retain_grad`` and ``create_graph``  flags are not supported. + + .. function:: local_buffers( ) + :noindex: + + Returns an iterator over buffers for the modules in + the partitioned model that have been assigned to the current process. + + .. function:: local_named_buffers( ) + :noindex: + + Returns an iterator over buffers for the + modules in the partitioned model that have been assigned to the current + process. This yields both the name of the buffer as well as the buffer + itself. + + .. function:: local_parameters( ) + :noindex: + + Returns an iterator over parameters for the + modules in the partitioned model that have been assigned to the current + process. + + .. function:: local_named_parameters( ) + :noindex: + + Returns an iterator over parameters for + the modules in the partitioned model that have been assigned to the + current process. This yields both the name of the parameter as well as + the parameter itself. + + .. function:: local_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. + + .. function:: local_named_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. This + yields both the name of the module as well as the module itself. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains local + parameters that belong to the current \ ``mp_rank``. This ``state_dict`` + contains a key \ ``_smp_is_partial`` to indicate this is a + partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains parameters + for the entire model. It first collects the \ ``local_state_dict``  and + gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to + create a full ``state_dict``. Please note that this needs to be called on all ranks with + ``dp_rank()==0`` to ensure the gather happens properly. + If it is only called on all such ranks, it can hang. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.module.load_state_dict()`` , + except: It first gathers and merges the ``state_dict``\ s across + ``mp_rank``\ s, if they are partial. The actual loading happens after the + model partition so that each rank knows its local parameters. + + .. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. Returns a ``RemovableHandle`` object ``handle``, + which can be used to remove the hook by calling ``handle.remove()``. + + .. function:: cpu( ) + :noindex: + + Allgathers parameters and buffers across all ``mp_rank``\ s and moves them + to the CPU. + + .. function:: join( ) + :noindex: + + A context manager to be used in conjunction with an instance of + ``smp.DistributedModel`` to be able to train with uneven inputs across + participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped + ``DistributedDataParallel`` instance. For more information, see: + `join `__ + in the PyTorch documentation. + + .. function:: register_comm_hook( state, callable ) + :noindex: + + **Available for PyTorch 1.8.1 only** + Registers a communication hook which is an enhancement that provides + a flexible hook ``callable`` to users where they can specify how + gradients are aggregated across multiple workers. This method will be called on the wrapped ``DistributedDataParallel`` instance. + + Please note that when you register a comm hook you have full control of how the gradients are processed. + When using only data parallelism with Torch DDP you are expected to average grads across data parallel replicas within the hook. + Similarly, when using DistributedModel you have to averaging grads across data parallel replicas within the hook. + In addition to that, you also have to average grads across microbatches within the hook unless you explicitly desire to not average based on your loss function. + See ``average_grads_across_microbatches`` for more information about averaging grads across microbatches. + + This is only supported when ``ddp=True`` and ``overlapping_allreduce=True`` (default). + For more information, see: + `register_comm_hook `__ + in the PyTorch documentation. + + **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism** + + When a model is wrapped by ``smp.DistributedModel``, the library + immediately traverses the modules of the model object, and replaces the + modules that are supported for tensor parallelism with their distributed + counterparts. This replacement happens in place. If there are no other + references to the original modules in the script, they are + garbage-collected. The module attributes that previously referred to the + original submodules now refer to the distributed versions of those + submodules. + + **Example:** + + .. code:: python + + # register DistributedSubmodule as the distributed version of Submodule + # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist) + smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule) + + class MyModule(nn.Module): + def __init__(self): + ... + + self.submodule = Submodule() + ... + + # enabling tensor parallelism for the entire model + with smp.tensor_parallelism(): + model = MyModule() + + # here model.submodule is still a Submodule object + assert isinstance(model.submodule, Submodule) + + model = smp.DistributedModel(model) + + # now model.submodule is replaced with an equivalent instance + # of smp.nn.DistributedSubmodule + assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule) + + If ``pipeline_parallel_degree`` (equivalently, ``partitions``) is 1, the + placement of model partitions into GPUs and the initial broadcast of + model parameters and buffers across data-parallel ranks take place + immediately. This is because it does not need to wait for the model + partition when ``smp.DistributedModel`` wrapper is called. For other + cases with ``pipeline_parallel_degree`` greater than 1, the broadcast + and device placement will be deferred until the first call of an + ``smp.step``-decorated function happens. This is because the first + ``smp.step``-decorated function call is when the model partitioning + happens if pipeline parallelism is enabled. + + Because of the module replacement during the ``smp.DistributedModel`` + call, any ``load_state_dict`` calls on the model, as well as any direct + access to model parameters, such as during the optimizer creation, + should be done **after** the ``smp.DistributedModel`` call. + + Since the broadcast of the model parameters and buffers happens + immediately during ``smp.DistributedModel`` call when the degree of + pipeline parallelism is 1, using ``@smp.step`` decorators is not + required when tensor parallelism is used by itself (without pipeline + parallelism). + + For more information about the library's tensor parallelism APIs for PyTorch, + see :ref:`smdmp-pytorch-tensor-parallel`. + + **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + + The following are the new methods of ``smp.DistributedModel``, in + addition to the ones listed in the + `documentation `__. + + .. function:: distributed_modules() + :noindex: + + - An iterator that runs over the set of distributed + (tensor-parallelized) modules in the model + + .. function:: is_distributed_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is distributed over + tensor-parallel ranks. + + .. function:: is_distributed_buffer(buf) + :noindex: + + - Returns ``True`` if the given buffer is distributed over + tensor-parallel ranks. + + .. function:: is_scaled_batch_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is operates on the + scaled batch (batch over the entire ``TP_GROUP``, and not only the + local batch). + + .. function:: is_scaled_batch_buffer(buf) + :noindex: + + - Returns ``True`` if the parameter corresponding to the given + buffer operates on the scaled batch (batch over the entire + ``TP_GROUP``, and not only the local batch). + + .. function:: default_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``DP_GROUP``. + + .. function:: scaled_batch_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``RDP_GROUP``. + + + +.. class:: smp.DistributedOptimizer + :noindex: + + **Parameters** + - ``optimizer`` + + An optimizer wrapper for saving/loading optimizer states. This wrapper + returns ``optimizer`` with the following methods overridden: + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains optimizer state for the entire model. + It first collects the ``local_state_dict`` and gathers and merges + the ``local_state_dict`` from all ``mp_rank``s to create a full + ``state_dict``. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.optimizer.load_state_dict()`` , except: + + - It first gathers and merges the local ``state_dict``\ s if they are + partial. + - The actual loading happens after the model partition so that each + rank knows its local parameters. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains the + local optimizer state that belongs to the current \ ``mp_rank``. This + ``state_dict`` contains a key \ ``_smp_is_partial`` to indicate this is + a partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + ​ +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (int) - The index of the partition. + + A context manager which places all modules defined inside into the + partition with ID ``index``.  The ``index`` argument must be less than + the number of partitions. + + Use ``smp.partition`` to implement manual partitioning. + If ``"auto_partition"`` is ``True``, then the + ``smp.partition`` contexts are ignored. Any module that is not placed in + any ``smp.partition`` context is placed in the + ``default_partition`` defined through the SageMaker Python SDK. + + When ``smp.partition`` contexts are nested, the innermost context + overrides the rest (see the following example). In PyTorch, manual + partitioning should be done inside the module \ ``__init__``, and the + partition assignment applies to the modules that are *created* inside + the ``smp.partition`` context. + + Example: + + .. code:: python + + class Model(torch.nn.Module): +     def __init__(self): +         with smp.partition(1): +             self.child0 = Child0()            # child0 on partition 1 +             with smp.partition(2): +                 self.child1 = Child1()        # child1 on partition 2 +             self.child2 = Child2()            # child2 on partition 1 +         self.child3 = Child3()                # child3 on default_partition + +.. function:: smp.get_world_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all + processes, which can be used with the ``torch.distributed`` API. + Requires ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_mp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``MP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_dp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``DP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.is_initialized( ) + :noindex: + + Returns ``True`` if ``smp.init`` has already been called for the + process, and ``False`` otherwise. + +.. function::smp.is_tracing( ) + :noindex: + + Returns ``True`` if the current process is running the tracing step, and + ``False`` otherwise. + +.. data:: smp.nn.FusedLayerNorm + :noindex: + + `Apex Fused Layer Norm `__ is currently not + supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + ``FusedLayerNorm`` and provides the same functionality. This requires + ``apex`` to be installed on the system. + +.. data:: smp.optimizers.FusedNovoGrad + :noindex: + + + `Fused Novo Grad optimizer `__ is + currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + optimizer and provides the same functionality. This requires ``apex`` to + be installed on the system. + +.. data:: smp.optimizers.FusedLamb + :noindex: + + + `FusedLamb optimizer `__ + currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + ``apex`` ``FusedLamb`` optimizer and provides the same functionality. + This requires ``apex`` to be installed on the system. + +.. data:: smp.amp.GradScaler + :noindex: + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. _pytorch_saving_loading: + :noindex: + +APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. function:: smp.save( ) + :noindex: + + Saves an object. This operation is similar to ``torch.save()``, except + it has an additional keyword argument, ``partial``, and accepts only + string type for the argument ``f`` (file). If ``partial=True``, each + ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` + index to your saved file. + + **Parameters** + + - ``obj`` (dict): A saved object. + - ``f`` (str): A string containing a file name. + - ``partial`` (bool, default= ``True``):  When set to ``True``, each + ``mp_rank`` saves a separate checkpoint file and the library adds an + ``mp_rank`` index to the saved file. If you want to be able to load + and further train a model that you save with ``smp.save()``, you must + set ``partial=True``. + - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``): + A module used for pickling metadata and objects. + - ``pickle_protocol``  (int, default=2): Can be specified to + override the defaultprotocol. + +.. function:: smp.load( ) + :noindex: + + Loads an object saved with ``smp.save()`` from a file. + + Similar to, `torch.load() `__, + except it has an additional keyword argument, ``partial``, and accepts + only string type for the argument ``f`` (file). If \ ``partial=True``, + then each ``mp_rank`` loads a separate checkpoint file. + + **Parameters** + + - ``f`` (string): A string containing a file name. + - ``map_location`` (function): A function + `torch.device `__, + a string, or a dict specifying how to remap storage locations. + - ``pickle_module`` (pickle module): A module used for unpickling + metadata and objects (has to match the \ ``pickle_module``\ used to + serialize file). + - ``pickle_load_args`` (Python 3 only): Optional keyword arguments + passed to ``pickle_module.load()`` and ``pickle_module.Unpickler()``. + - ``partial`` (bool, default= ``True``): When set to ``True``, each + ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``. + Should be used when loading a model trained with the library. + +.. _pytorch_saving_loading_instructions: + :noindex: + +General Instruction For Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The library can save partial or full checkpoints. + +- For partial checkpoints, each ``mp_rank`` saves its own checkpoint + file with only the parameters that belong to that rank. +- For full checkpoints, the library saves a single checkpoint that contains + entire model parameters. + +When **saving** using ``smp.save()``, each rank only holds its own +parameters. If you want to save the full model, there will be some +communication between the ranks to create the full model. If you save +checkpoints often, you should save partial checkpoints for best +performance. + +When **loading** using ``smp.load()``, the library can load either partial or | +full checkpoints or full checkpoints saved by a non-model-parallel model. If you +want to resume training with a non-model-parallel model or do inference, you need +a full checkpoint. + +The following is an example of how you can save and load a checkpoint: + +.. code:: python + + # Original model and optimizer + model = MyModel(...) + optimizer = MyOpt(...) + + # model parallel wrapper + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + + # To save, always save on dp_rank 0 to avoid data racing + if partial: +     # To save the partial model on each mp rank +     # the library will create `checkpoint.pt_{mprank}` for each mp rank +     if save_partial_model: +         if smp.dp_rank() == 0: +             model_dict = model.local_state_dict() # save the partial model +             opt_dict = optimizer.local_state_dict() # save the partial optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 f"/checkpoint.pt", +                 partial=True, +             ) + +     # To save the full model +     if save_full_model: +         if smp.dp_rank() == 0: +             model_dict = model.state_dict() # save the full model +             opt_dict = optimizer.state_dict() # save the full optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 "/checkpoint.pt", +                 partial=False, +             ) + + # To load, load on all ranks. + # The only difference for partial/full loading is the partial flag in smp.load + # Load partial checkpoint + if partial_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=True) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + # Load full checkpoint + if full_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=False) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst new file mode 100644 index 0000000000..851408b4b8 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst @@ -0,0 +1,875 @@ +.. _smdmp-pytorch-tensor-parallel: + :noindex: + +PyTorch API for Tensor Parallelism +================================== + +SageMaker distributed tensor parallelism works by replacing specific submodules +in the model with their distributed implementations. The distributed modules +have their parameters and optimizer states partitioned across tensor-parallel +ranks. This is to compute the same output as it would have been computed by +the original modules. Since tensor parallelism occurs across data-parallel +ranks, a rank might collect slices of the activations corresponding to the +data shards on other devices that are part of the same tensor parallelism group. + +You can enable or disable tensor parallelism for specific parts of the model. +Within the enabled parts, the replacements with distributed modules will take +place on a best-effort basis for those module supported for tensor parallelism. +Alternatively, you can directly import and use the library’s distributed +modules in the model definition. + +Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +blocks that contain many operations. Because custom implementations +(as opposed to the built-in PyTorch modules) are typically used for these +high-level blocks, the library offers an API that you can use to register +specific distributed versions with such custom modules (provided that they +are functionally equivalent). This allows the library to automatically replace +the occurrences of such PyTorch modules with their distributed counterparts +provided by the library. +For more information, see the following topics. + +.. contents:: Topics + :depth: 3 + :local: + +.. _registering-tp-modules: + :noindex: + +Registering Tensor Parallelism Distributed Modules +-------------------------------------------------- + +Although PyTorch natively provides some of the commonly used (and +tensor-parallelizable) building blocks such as Transformer, users often +use custom implementations for such higher-level modules. To distribute +such modules with tensor parallelism, you need to register the +distributed modules to the custom module implementation in your class, +so that the library knows how to distribute the custom module. When you +register the distributed modules, make sure the custom module that you +use is functionally equivalent to the distributed module. You can verify +this by taking a look at the equivalent reference implementations in the +:ref:`smdmp-tp-appendix`. +These implementations are functionally equivalent to their distributed +versions in ``smp.nn`` module. + +.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) + + - A class decorator that registers the ``dist_module`` class with + the module class that it is attached to. The hooks can be used to + adapt to different interfaces used with ``__init__`` and + ``forward`` methods. + - **Arguments:** + + - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + that implements the distributed version of the module class the + decorator is attached to. Any distributed module class defined + in ``smp.nn`` module can be used. + - ``init_hook``: A callable that translates the arguments of the + original module ``__init__`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``__init__`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``__init__`` method (including argument order and default + values), except it must exclude ``self``. + - ``forward_hook``: A callable that translates the arguments of + the original module ``forward`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``forward`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``forward`` method (including argument order and default + values), except it must exclude ``self``. + - ``return_hook``: A callable that translates the object returned + from the distributed module to the return object expected of + the original module. + + - **Example:** + + .. code:: python + + init_hook = lambda config: ((), config.to_dict()) + + # register smp.nn.DistributedTransformer + # as the distributed version of MyTransformer + @smp.tp_register(smp.nn.DistributedTransformer, init_hook=init_hook) + class MyTransformer(nn.Module): + def __init__(self, config): + ... + + def forward(self, hidden_states, attention_mask): + ... + +.. function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) + :noindex: + + - When you do not have direct access to model definition code, you + can use this API to similarly register a distributed module with + an existing module class. + + - **Arguments:** + + - ``module_cls``: The existing module class that will be + distributed. + - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + that implements the distributed version of the module class the + decorator is attached to. Any distributed module class defined + in ``smp.nn`` module can be used. + - ``init_hook``: A callable that translates the arguments of the + original module ``__init__`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``__init__`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``__init__`` method (including argument order and default + values), except it must exclude ``self``. + - ``forward_hook``: A callable that translates the arguments of + the original module ``forward`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``forward`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``forward`` method (including argument order and default + values), except it must exclude ``self``. + - ``return_hook``: A callable that translates the object returned + from the distributed module to the return object expected of + the original module. + + - **Example:** + + .. code:: python + + from somelibrary import MyTransformer + + init_hook = lambda config: ((), config.to_dict()) + + # register smp.nn.DistributedTransformer as the distributed version of MyTransformer + smp.tp_register_with_module(MyTransformer, + smp.nn.DistributedTransformer, + init_hook=init_hook) + +.. _smdmp-supported-modules-for-tp: + :noindex: + +Supported Modules for Tensor Parallelism +---------------------------------------- + +The following modules are supported for tensor +parallelism. + +- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``) +- ``smp.nn.DistributedTransformerLMHead`` +- ``smp.nn.DistributedTransformer`` +- ``smp.nn.DistributedTransformerLayer`` +- ``smp.nn.DistributedAttentionLayer`` +- ``smp.nn.DistributedTransformerOutputLayer`` +- ``smp.nn.DistributedEmbedding`` + +.. contents:: Topics + :depth: 3 + :local: + +.. _tp-module-api: + :noindex: + +Tensor Parallelism Module APIs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. class:: smp.nn.DistributedLinear(in_features, out_features) + :noindex: + + - Tensor-parallel implementation of the ``nn.Linear`` class. + Functionally equivalent to an ``nn.Linear`` module with the same + ``in_features`` and ``out_features``. In other words, + ``in_features`` and ``out_features`` are the number of *global* + channels across tensor-parallel ranks. + - **Arguments:** + + - ``in_features``: The total number of input channels for the + linear layer across all tensor-parallel ranks. + - ``out_features``: The total number of output channels for the + linear layer across all tensor-parallel ranks. + +.. class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** + + - ``forward(self, inputs)`` + + - If ``add_cross_attention`` is ``True``, ``inputs`` must be a + tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, cross_states, cross_states, cross_mask, labels)``. + - Otherwise, ``inputs`` must be a tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, labels)``. + - If ``token_type_ids`` is ``None``, token type embedding will + not be used. + - ``input_ids`` is assumed to be of shape ``[N, S]``, where + ``N`` is the batch size and ``S`` is sequence length. + - ``attention_mask`` is assumed to be a 0-1 tensor of shape + ``[N, S]``, where 1 represents a masked position. + +.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + ``smp.nn.DistributedTransformerLayer``. + - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. + +.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. + - **Arguments:** + + - ``num_attention_heads``: The total number of attention heads + across tensor-parallel ranks + - ``attention_head_size``: The number of channels of a single + attention head. + - ``hidden_size``: The hidden dimension of the transformer. The + input tensor ``hidden_states`` is assumed to have its last + dimension size equal to ``hidden_size``. + - ``intermediate_size``: The number of output channels in the + first linear transformation of the transformer output layer. + ``DistributedTransformerOutputLayer`` first maps + ``hidden_size`` dimensions of its input tensor into + ``intermediate_size`` dimensions, and then maps it back into + ``hidden_size`` dimensions. + - ``attention_dropout_prob``: The dropout probability applied to + the attention probabilities. + - ``hidden_dropout_prob``: The dropout probability used in + dropout layers other than the one applied to the attention + probabilities. + - ``activation``: Choice of activation function to use at the + output layer. Must be ``"gelu"`` or ``"relu"``. + - ``layernorm_epsilon``: The epsilon added to the denominator of + layer normalization for numerical stability. + - ``initializer_range``: If ``use_normal_initialization`` is + ``True``, the standard deviation of the normal random variable + to initialize the weights with. + - ``use_normal_initialization``: If ``True``, the weights are + initialized with normal distribution with standard deviation + given by ``initializer_range``. Otherwise, default PyTorch + initialization is used. + - ``causal_mask_size``: If ``None``, no causal mask is used on + attentions. Otherwise, should be set to maximum sequence length + to apply a causal mask to the attention scores. This is used, + for instance, in GPT-2. + - ``add_cross_attention``: If ``True``, a cross-attention layer + will be added after the self-attention block. The + cross-attention layer computes the attention keys and values + based on the ``cross_states`` input (instead of + ``hidden_states`` input, as in self-attention. This is used in + the decoder block of encoder-decoder architectures. For + encoder-only architectures that only use self-attention, this + should be kept ``False``. + - ``pre_layernorm``: If ``True``, inserts layer normalization at + the input. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + - ``post_layernorm``: If ``True``, inserts layer normalization at + the output. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the transformer + layer. + + - **Arguments:** + + - If ``add_cross_attention=False``, ``inputs`` must be a + tuple ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the batch + size, and ``S`` is the sequence length. + - If ``add_cross_attention=True``, ``inputs`` must be a + tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is assumed to be a tensor of + dimensions ``[N, S_1, H]``, where ``N`` is batch size, + ``S_1`` is sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_1]``, where ``N`` is the batch + size, and ``S_1`` is the sequence length, and + ``cross_mask`` is assumed to be a tensor of size + ``[N, 1, 1, S_2]``. Keys and values for the attention + heads in the cross-attention layer (but not the + self-attention layer) are computed using + ``cross_states``, and ``cross_mask`` is applied as the + attention mask in the cross-attention layer (but not the + self-attention layer). + + - **Returns:** + + - If ``add_cross_attention=False``, a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is the output of the transformer, and + ``attention_mask`` is the same the ``attention_mask`` + argument. + - If ``add_cross_attention=True``, a tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is the output of the transformer, + and the next three tensors are the same as the input + arguments. + +.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``cross_attention``: If ``True``, it computes the attentions + with respect to the ``cross_states`` tensor of the ``forward`` + method input tuple. (Default: ``False``) + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the attention + layer. + + - **Arguments:** + + - If ``cross_attention=False``, ``inputs`` must be a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the + batch size, and ``S`` is the sequence length. + - If ``cross_attention=True``, ``inputs`` must be a tuple + ``(hidden_states, cross_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S_1, H]``, where ``N`` is batch size, ``S_1`` is + sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_2]``, where ``N`` is the batch + size, and ``S_2`` is the sequence length. Keys and values + for the attention heads are computed using + ``cross_states``. + + - **Returns:** + + - A single tensor that is the output of the attention + layer. + +.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) + :noindex: + + - Distributed implementation of a single transformer output layer. A + single :class:`smp.nn.DistributedTransformerLayer` with + ``add_cross_attention=False`` consists of a single + ``DistributedAttentionLayer`` immediately followed by a single + ``DistributedTransformerOutputLayer``. The latter linearly maps + the last channel of the input tensor from ``hidden_size`` to + ``intermediate_size``, and then maps it back to ``hidden_size``. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow + (NaN loss values) for large models with more than 100 billion parameters + when using FP16. (Default: False) + +.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) + :noindex: + + - Distributed implementation of a single Embedding Layer. Currently + only supports splitting across the embedding_dim. + - **Arguments:** + + - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + arguments. + +.. _enabling-tp: + :noindex: + +Enabling Tensor Parallelism +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are two ways tensor parallelism can be enabled. + +First, you can use +the distributed module implementations in ``smp.nn`` module directly in +your model definition. See :ref:`smdmp-supported-modules-for-tp` +for a complete list of built-in distributed modules. Here is an example +of how this can be done: + +.. code:: python + + import torch.nn as nn + import smdistributed.modelparallel.torch as smp + + class TransformerModel: + def __init__(self): + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # directly instantiate smp.nn.DistributedTransformer and use it + self.encoder = smp.nn.DistributedTransformer(num_layers, hidden_size, **kwargs) + + self.pooler = nn.Linear(hidden_size, hidden_size) + + def forward(self, hidden_states): + emb_out = self.embedding(hidden_states) + enc_out = self.encoder(emb_out) + return self.pooler(enc_out) + +Second, you can enable tensor parallelism for specific modules or blocks +of code, which will automatically enable tensor parallelism for the +supported modules within that scope. To do this, you can use the +following API: + +.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) + + - A context manager that enables or disables tensor parallelism for + any supported module that is created inside. If there are nested + contexts, the innermost overrides the rest. If there are + multiple supported modules created within the context, where one + is the submodule of the other, only the outermost module will be + distributed. If a supported module shares weights with another + (supported or unsupported) module, or if its hyperparameters do + not support distribution (e.g., not divisible by the tensor + parallelism degree), tensor parallelism will **not** be enabled + for this module even if this API is used. + + **Example:** + + .. code:: python + + with smp.tensor_parallelism(): + self.m0 = nn.Linear(20, 20) # will be distributed + with smp.tensor_parallelism(enabled=False): + self.m1 = nn.Linear(20, 20) # will not be distributed + + - ``kwargs`` - Keyword arguments that can be used to modify the configurations of + the distributed modules created inside the context. + If a keyword argument provided through it matches any ``__init__`` method arguments + of a ``DistributedModule`` that substitutes a module created inside + the ``smp.tensor_parallelism`` context, this keyword will override + the value defined in the ``init_hook``. + + - (*For v1.7.0 and later*) Through the following additional keyword arguments, + the library supports `NVIDIA Megatron’s fused kernels + `_ + + - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. + By default, it is set to ``True``. You can deactivate it by setting + ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. + By default, it is set to ``False``. You can activate it by setting + ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + + + +.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) + :noindex: + + - Enables or disables tensor parallelism for the supported + submodules of ``module``. If enabling, the outermost supported + modules will be distributed. If disabling, tensor parallelism will + be disabled for the entire module subtree of ``module``. Unlike + the context manager, this API can be used after the model creation + (but before wrapping with :class:`smp.DistributedModel`), so direct + access to model definition code is not required. If a supported + module shares weights with another (supported or unsupported) + module, or if its hyperparameters do not support distribution + (e.g., not divisible by the tensor parallelism degree), tensor + parallelism will **not** be enabled for this module. + - Keyword arguments ``kwargs`` can be used to modify the + configurations of the distributed modules created inside the + context. If a keyword argument provided here matches any + ``__init__`` method arguments of a :class:`smp.DistributedModel` that + substitutes a module created inside the ``smp.tensor_parallelism`` + context, this keyword will override the value defined in the + ``init_hook``. + - **Example:** + + .. code:: python + + model = MyModel() + smp.set_tensor_parallelism(model.encoder, True) + smp.set_tensor_parallelism(model.encoder.embedding, True) + + # outermost supported submodules in model.encoder will be distributed, except for + # model.encoder.embedding + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + +.. _activation-checkpointing-api: + :noindex: + +Activation Checkpointing APIs +----------------------------- + +``smdistributed.modelparallel`` provides three APIs to enable +activation checkpointing: one for checkpointing modules, +one for checkpointing sequential modules, and +one for checkpointing pretrained models. + +For a conceptual guide and examples, see +`Activation Checkpointing `_ +in the *SageMaker's Distributed Model Parallel developer guide*. + +.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint(module, *args, preserve_rng_state=True) + :noindex: + + - Checkpoints the module passed. Throws error if, during manual + partitioning, all children of module are not on same rank as the + module itself, i.e. the module tree is split across multiple + partitions. During auto-partitioning, if the module is split + across multiple partitions, then this call is ignored(with a + warning). Note that this call applies to the module instance only, + not to the module class. + + - **Arguments:** + + - ``module (Instance of nn.Module)``: The module to be + checkpointed. Note that unlike native checkpointing in + PyTorch’s, activation checkpointing in + ``smdistributed.modelparallel`` is at the granularity of a + module. A generic function cannot be passed here. + - ``args``: Tuple containing inputs to the module. + - ``preserve_rng_state (bool, default=True)``: Omit stashing and + restoring the RNG state during each checkpoint. + +.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint_sequential(sequential_module, input, strategy="each", preserve_rng_state=True, pack_args_as_tuple=False) + :noindex: + + - Checkpoints the modules inside + `nn.Sequential `__. + This can be used even if different layers that are part of the + sequential container lie on different partitions. Each layer part + of the sequential module that is checkpointed must lie completely + within one partition. If this is not the case during manual + partitioning, then an error will be thrown. If this is not the + case during auto partitioning, a warning will be raised and this + module will be run without checkpointing. + + - **Arguments** + + - ``sequential_module (nn.Sequential)``: the sequential module to + be checkpointed. + - ``input (torch.Tensor or a tuple of torch.Tensors)``: input to + the module, which can be a tensor or a tuple of tensors. If a + tuple is passed, then pack_args_as_tuple should be set to True. + - ``strategy (string, default=“each”)`` : Strategy determines how + many layers part of the sequential module need to be grouped + together for one checkpointing call. This determines how much + memory can be reduced. It can take the following values + + - ``each`` : The default is to checkpoint each module inside + the sequential separately. + - ``contiguous``: Groups consecutive layers on the same + partition together. For example, if a sequential consists of + [a, b, c, d] where a,b are on pp_rank0 and c,d are on + pp_rank 1, then this strategy would checkpoint a,b together + and then c,d together. This means effectively, inputs of a, + outputs of b, inputs of c, and outputs of d are in memory; + the reamining activations are recomputed. + - ``group_2, group_3, group_4, etc:`` More generally, + ``group_x`` where x is an integer. This strategy provides + more flexibility in how many layers to group together. + ``group_x`` groups x layers together on a best effort basis. + It can group x layers together if there are x layers + consecutively on the same partition. For example: + [a,b,c,d,e] where a,b are on pp_rank0 and c,d,e are on + pp_rank 1. If the strategy is ``group_3,`` then a,b are + checkpointed together on pp_rank0 and c,d,e are checkpointed + together on pp_rank1. + + - ``preserve_rng_state (bool, default=True)``: Set to ``False`` + to omit stashing and restoring the RNG state during each + checkpoint. + - ``pack_args_as_tuple (bool, default=False)``: To ensure that + backward works correctly, the autograd function has to unpack + any tuples received. If the checkpointed layer takes a tuple as + input, then this needs to be set to True. + +.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") + :noindex: + + - This API is recommended when importing pretrained models from + libraries, such as PyTorch and Hugging Face Transformers. This is + particularly useful when you don’t have access to the model + definition code and not be able to replace a module call with + checkpoint. + + - **Arguments**: + + - ``module (Instance of nn.Module or nn.Sequential)``: The module + to checkpoint. + - ``preserve_rng_state (bool, default=True)``: Set to ``False`` + to omit stashing and restoring the RNG state during each + checkpoint. + - ``pack_args_as_tuple (bool, default=False)``: *Can only be + passed when module is a sequential module.* To ensure that + backward works correctly, the autograd function has to unpack + any tuples received. If the layer checkpointed takes a tuple as + input, then this needs to be set to True. + - ``strategy: (string, default=“each”)``: *Can only be passed + when module is a sequential module.* Strategy determines how + many layers part of the sequential module need to be grouped + together for one checkpointing call. + - This determines how much memory can be reduced. It can take the + following values + + - ``each`` : The default is to checkpoint each module inside + the sequential separately. + - ``contiguous``: Groups consecutive layers on the same + partition together. For example if a sequential consists of + ``[a, b, c, d]`` where ``a, b`` are on ``pp_rank0`` and ``c, d`` are on + ``pp_rank 1``, then this strategy would checkpoint a,b together + and then ``c, d`` together. This means effectively, the inputs of + ``a``, outputs of ``b``, inputs of ``c``, and outputs of ``d`` are in + memory, and the rest of the activations are recomputed. + - ``group_2, group_3, group_4, etc:`` More generally, + ``group_x`` where x is an integer. This strategy provides + more flexibility in how many layers to group together. + ``group_x`` groups x number of layers together on a best + effort basis if there are x layers consecutively in the same + partition. **Example**: Assume a module with layers ``[a, b, + c, d, e]``. The layers a and b are on pp_rank0, and ``c``, ``d``, and + ``e`` are on ``pp_rank 1``. If the strategy is ``group_3,`` then ``a``, + ``b`` are checkpointed together on ``pp_rank0``, and ``c``, ``d``, ``e`` are + checkpointed together on ``pp_rank1``. + +.. _smdmp-tp-appendix: + :noindex: + +Appendix: Reference Implementations for Modules +----------------------------------------------- + +The following are reference implementations for transformer-related +modules. Note that this is not the actual ``smdistributed`` source code, +but the distributed implementations provided in the library are the +distributed versions of these reference implementations, and can be used +to determine whether the distributed modules perform the same operations +as the custom modules in your script. + +To keep the implementations simple, we only assume keyword arguments, +and assume the existence of a method ``parse_args(kwargs)``, which +parses the arguments to ``__init__`` methods and sets the relevant +attributes of the module, such as ``hidden_size`` and +``num_attention_heads``. + +``smp.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class Transformer(nn.Module): + def __init__(self, **kwargs): + super(Transformer, self).__init__() + self.parse_args(kwargs) + + self.layers = [] + for l in range(self.num_layers): + self.layers.append(TransformerLayer(**kwargs)) + + self.seq_layers = nn.Sequential(*self.layers) + + def forward(self, inp): + return self.seq_layers(inp) + +``smp.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class TransformerLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerLayer, self).__init__() + self.parse_args(kwargs) + + self.attention = AttentionLayer(**kwargs) + self.output = TransformerOutputLayer(**kwargs) + + if self.add_cross_attention: + self.cross_attention = AttentionLayer(cross_attention=True, **kwargs) + + def forward(self, inp): + if self.add_cross_attention: + hidden_states, cross_states, attention_mask, cross_mask = inp + else: + hidden_states, attention_mask = inp + + attention_output = self.attention((hidden_states, attention_mask)) + if self.add_cross_attention: + attention_output = self.cross_attention((attention_output, + cross_states, + cross_mask)) + + output = self.output(attention_output) + + if self.add_cross_attention: + return output, cross_states, attention_mask, cross_mask + else: + return output, attention_mask + +``smp.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class AttentionLayer(nn.Module): + def __init__(self, **kwargs): + super(AttentionLayer, self).__init__() + self.parse_args(kwargs) + self.attention_head_size = self.hidden_size // self.num_attention_heads + + self.query = nn.Linear(self.hidden_size, self.hidden_size) + self.key = nn.Linear(self.hidden_size, self.hidden_size) + self.value = nn.Linear(self.hidden_size, self.hidden_size) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout1 = nn.Dropout(self.attention_dropout_prob) + self.dropout2 = nn.Dropout(self.hidden_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def transpose(self, tensor, key=False): + shape = tensor.size()[:-1] + + (self.num_attention_heads, self.attention_head_size) + tensor = torch.reshape(tensor, shape) + if key: + return tensor.permute(0, 2, 3, 1) + else: + return tensor.permute(0, 2, 1, 3) + + def forward(self, inp): + if self.cross_attention: + hidden_states, cross_states, attention_mask = inp + else: + hidden_states, attention_mask = inp + + if self.pre_layernorm: + norm_states = self.pre_layernorm(hidden_states) + else: + norm_states = hidden_states + + query_layer = self.query(norm_states) + + if self.cross_attention: + key_layer = self.key(cross_states) + value_layer = self.value(cross_states) + else: + key_layer = self.key(norm_states) + value_layer = self.value(norm_states) + + query_layer = self.transpose(query_layer) + key_layer = self.transpose(key_layer, key=True) + value_layer = self.transpose(value_layer) + + attention_scores = torch.matmul(query_layer, key_layer) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if not self.cross_attention and self.causal_mask is not None: + attention_scores = self.apply_causal_mask(attention_scores) + + attention_scores = attention_scores + attention_mask + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout1(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.local_attention_size,) + context_layer = torch.reshape(context_layer, new_context_layer_shape) + + self_attention = self.dense(context_layer) + self_attention = self.dropout2(self_attention) + + if self.post_layernorm: + return self.layernorm(self_attention + hidden_states) + else: + return self_attention + +``smp.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class TransformerOutputLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerOutputLayer, self).__init__() + self.parse_args(kwargs) + + self.dense1 = nn.Linear(self.hidden_size, self.intermediate_size) + self.dense2 = nn.Linear(self.intermediate_size, self.hidden_size) + + self.dropout = nn.Dropout(self.attention_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def forward(self, inp): + if self.pre_layernorm: + norm_inp = self.pre_layernorm(inp) + else: + norm_inp = inp + + dense1_output = self.dense1(norm_inp) + if self.activation == "gelu": + act_output = F.gelu(dense1_output) + else: + act_output = F.relu(dense1_output) + + dense2_output = self.dense2(act_output) + output = self.dropout(dense2_output) + + if self.post_layernorm: + return self.layernorm(inp + output) + else: + return output diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst new file mode 100644 index 0000000000..54ec558fc5 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst @@ -0,0 +1,171 @@ +TensorFlow API +============== + +To use the TensorFlow-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +.. tip:: + + Refer to + `Modify a TensorFlow Training Script + `_ + to learn how to use the following APIs in your TensorFlow training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of the Keras \ ``Model`` class, which defines the model to + be partitioned. Model definition is done by sub-classing + ``smp.DistributedModel`` class, and implementing the ``call()`` method, + in the same way as the Keras model sub-classing API. Any operation that + is part of the \ ``smp.DistributedModel.call()`` method is subject to + partitioning, meaning that every operation placed inside executes in + exactly one of the devices (the operations outside run on all devices). + + + Similar to the regular Keras API, the forward pass is done by directly + calling the model object on the input tensors. For example: + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + However, ``model()`` calls can only be made inside a + ``smp.step``-decorated function. + + The outputs from a ``smp.DistributedModel`` are available in all ranks, + regardless of which rank computed the last operation. + + **Methods:** + + .. function:: save_model(save_path="/opt/ml/model") + :noindex: + + **Inputs** + - ``save_path`` (``string``): A path to save an unpartitioned model with latest training weights. + + Saves the entire, + unpartitioned model with the latest trained weights to ``save_path`` in + TensorFlow ``SavedModel`` format. Defaults to ``"/opt/ml/model"``, which + SageMaker monitors to upload the model artifacts to Amazon S3. + +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (``int``): The index of the partition. + + A context manager which places all operations defined inside into the + partition whose ID is equal to ``index``. When + ``smp.partition`` contexts are nested, the innermost context overrides + the rest. The ``index`` argument must be smaller than the number of + partitions. + + ``smp.partition`` is used in the manual partitioning API; + if \ ``"auto_partition"`` parameter is set to ``True`` while launching + training, then ``smp.partition`` contexts are ignored. Any operation + that is not placed in any ``smp.partition`` context is placed in the + ``default_partition``, as shown in the following example: + + .. code:: python + + # auto_partition: False + # default_partition: 0 + smp.init() + [...] + x = tf.constant(1.2)                     # placed in partition 0 + with smp.partition(1): +     y = tf.add(x, tf.constant(2.3))      # placed in partition 1 +     with smp.partition(3): +         z = tf.reduce_sum(y)             # placed in partition 3 + + +.. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. + + .. code:: python + + @smp.register_post_partition_hook + def test_eager(): + # All statements here will be executed right after partition but before the first forward pass + tf.print("Entered hook through eager context") + +.. class:: smp.CheckpointManager + :noindex: + + + A subclass of TensorFlow + `CheckpointManager `__, + which is used to manage checkpoints. The usage is similar to TensorFlow + ``CheckpointManager``. + + The following returns a ``CheckpointManager`` object. + + .. code:: python + + smp.CheckpointManager(checkpoint, +                       directory="/opt/ml/checkpoints", +                       max_to_keep=None, +                       checkpoint_name="ckpt") + + **Parameters** + + - ``checkpoint``: A `tf.train.Checkpoint + `__ instance + that represents a model checkpoint. + + - ``directory``: (``str``) The path to a directory in which to write + checkpoints. A file named "checkpoint" is also written to this + directory (in a human-readable text format) which contains the state + of the ``CheckpointManager``. Defaults to + ``"/opt/ml/checkpoints"``, which is the directory that SageMaker + monitors for uploading the checkpoints to Amazon S3. + - ``max_to_keep`` (``int``): The number of checkpoints to keep. If + ``None``, all checkpoints are kept. + - ``checkpoint_name`` (``str``): Custom name for the checkpoint file. + Defaults to ``"ckpt"``. + + + **Methods:** + + .. function:: save( ) + :noindex: + + Saves a new checkpoint in the specified directory. Internally uses ``tf.train.CheckpointManager.save()``. + + .. function:: restore( ) + :noindex: + + Restores the latest checkpoint in the specified directory. + Internally uses ``tf.train.CheckpointManager.restore()``. + + + **Examples:** + + .. code:: python + + checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) + ckpt_manager = smp.CheckpointManager(checkpoint, max_to_keep=5)  # use /opt/ml/checkpoints + + for inputs in train_ds: +     loss = train_step(inputs) +     # [...] +     ckpt_manager.save()  # save a new checkpoint in /opt/ml/checkpoints + + .. code:: python + + for step, inputs in enumerate(train_ds): +     if step == 0: +         ckpt_manager.restore() +     loss = train_step(inputs) diff --git a/doc/api/training/smp_versions/v1_9_0.rst b/doc/api/training/smp_versions/v1_9_0.rst new file mode 100644 index 0000000000..e2e9acd83a --- /dev/null +++ b/doc/api/training/smp_versions/v1_9_0.rst @@ -0,0 +1,13 @@ + +Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 +================================== + +To use the library, reference the Common API documentation alongside the framework specific API documentation. + +.. toctree:: + :maxdepth: 1 + + v1.9.0/smd_model_parallel_common_api + v1.9.0/smd_model_parallel_pytorch + v1.9.0/smd_model_parallel_pytorch_tensor_parallel + v1.9.0/smd_model_parallel_tensorflow From cbac1fad4a9e23254f4d4e841401b3c49d29bffc Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 22 Jun 2022 20:18:25 -0700 Subject: [PATCH 02/23] fix indexing --- .../v1.9.0/smd_model_parallel_pytorch.rst | 19 +++++++++---------- ...model_parallel_pytorch_tensor_parallel.rst | 1 + .../v1.9.0/smd_model_parallel_tensorflow.rst | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst index 055f2b6dde..88d1a42165 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst @@ -17,7 +17,7 @@ you need to add the following import statement at the top of your training scrip to learn how to use the following API in your PyTorch training script. .. class:: smp.DistributedModel - :noindex: + :noindex: A sub-class of ``torch.nn.Module`` which specifies the model to be partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is @@ -362,45 +362,45 @@ you need to add the following import statement at the top of your training scrip `documentation `__. .. function:: distributed_modules() - :noindex: + :noindex: - An iterator that runs over the set of distributed (tensor-parallelized) modules in the model .. function:: is_distributed_parameter(param) - :noindex: + :noindex: - Returns ``True`` if the given ``nn.Parameter`` is distributed over tensor-parallel ranks. .. function:: is_distributed_buffer(buf) - :noindex: + :noindex: - Returns ``True`` if the given buffer is distributed over tensor-parallel ranks. .. function:: is_scaled_batch_parameter(param) - :noindex: + :noindex: - Returns ``True`` if the given ``nn.Parameter`` is operates on the scaled batch (batch over the entire ``TP_GROUP``, and not only the local batch). .. function:: is_scaled_batch_buffer(buf) - :noindex: + :noindex: - Returns ``True`` if the parameter corresponding to the given buffer operates on the scaled batch (batch over the entire ``TP_GROUP``, and not only the local batch). .. function:: default_reducer_named_parameters() - :noindex: + :noindex: - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``DP_GROUP``. .. function:: scaled_batch_reducer_named_parameters() - :noindex: + :noindex: - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``RDP_GROUP``. @@ -512,6 +512,7 @@ you need to add the following import statement at the top of your training scrip .. function::smp.is_tracing( ) :noindex: + :noindex: Returns ``True`` if the current process is running the tracing step, and ``False`` otherwise. @@ -527,7 +528,6 @@ you need to add the following import statement at the top of your training scrip .. data:: smp.optimizers.FusedNovoGrad :noindex: - `Fused Novo Grad optimizer `__ is currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` optimizer and provides the same functionality. This requires ``apex`` to @@ -536,7 +536,6 @@ you need to add the following import statement at the top of your training scrip .. data:: smp.optimizers.FusedLamb :noindex: - `FusedLamb optimizer `__ currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces ``apex`` ``FusedLamb`` optimizer and provides the same functionality. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst index 851408b4b8..c66595ddf2 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst @@ -460,6 +460,7 @@ supported modules within that scope. To do this, you can use the following API: .. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) + :noindex: - A context manager that enables or disables tensor parallelism for any supported module that is created inside. If there are nested diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst index 54ec558fc5..2c658b487c 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst @@ -102,7 +102,7 @@ you need to add the following import statement at the top of your training scrip tf.print("Entered hook through eager context") .. class:: smp.CheckpointManager - :noindex: + :noindex: A subclass of TensorFlow From 4e9f251b34fe997b58a44621cd28186bf36865ce Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 15:50:59 -0700 Subject: [PATCH 03/23] add new smp cpu memory apis --- .../latest/smd_model_parallel_pytorch.rst | 249 +++++++++++++----- ...model_parallel_pytorch_tensor_parallel.rst | 102 +++---- 2 files changed, 243 insertions(+), 108 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index b05413965c..d829da43f8 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -2,7 +2,7 @@ PyTorch API =========== To use the PyTorch-specific APIs for SageMaker distributed model parallism, -you need to add the following import statement at the top of your training script. +import the ``smdistributed.modelparallel.torch`` package at the top of your training script. .. code:: python @@ -16,24 +16,33 @@ you need to add the following import statement at the top of your training scrip `_ to learn how to use the following API in your PyTorch training script. -.. class:: smp.DistributedModel +.. contents:: Topics + :depth: 3 + :local: + +smdistributed.modelparallel.torch.DistributedModel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: smdistributed.modelparallel.torch.DistributedModel A sub-class of ``torch.nn.Module`` which specifies the model to be partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is the model to be partitioned. The returned ``DistributedModel`` object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with - ``smp.DistributedModel``. + ``smdistributed.modelparallel.torch.DistributedModel``. **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = smp.DistributedModel(model) **Important**: The ``__call__`` and  ``backward`` method calls on the - ``smp.DistributedModel`` object (in the following example, the object - is \ ``model``) can only be made inside a ``smp.step``-decorated + ``smdistributed.modelparallel.torch.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smdistributed.modelparallel.torch.step``-decorated function. Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can @@ -78,7 +87,7 @@ you need to add the following import statement at the top of your training scrip In these examples, all ``__call__``  and ``backward`` method calls on the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside - a ``smp.step``-decorated function. + a ``smdistributed.modelparallel.torch.step``-decorated function. **Using DDP** @@ -89,7 +98,7 @@ you need to add the following import statement at the top of your training scrip Unlike the original DDP wrapper, when you use ``DistributedModel``, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the - ``smp.step``-decorated function when the partition is done. + ``smdistributed.modelparallel.torch.step``-decorated function when the partition is done. **Parameters** @@ -160,7 +169,7 @@ you need to add the following import statement at the top of your training scrip - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` otherwise. Initialized to ``False`` when ``DistributedModel`` is first created. It becomes be ``True`` during the first call - to ``smp.step``-decorated function. Once the model is partitioned, the + to ``smdistributed.modelparallel.torch.step``-decorated function. Once the model is partitioned, the local parameters or local ``state_dict`` can be fetched using the following methods. @@ -240,7 +249,7 @@ you need to add the following import statement at the top of your training scrip Registers a callable ``hook`` to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during - the first call to ``smp.step``, but before the actual execution of the + the first call to ``smdistributed.modelparallel.torch.step``, but before the actual execution of the first forward pass. Returns a ``RemovableHandle`` object ``handle``, which can be used to remove the hook by calling ``handle.remove()``. @@ -252,7 +261,7 @@ you need to add the following import statement at the top of your training scrip .. function:: join( ) A context manager to be used in conjunction with an instance of - ``smp.DistributedModel`` to be able to train with uneven inputs across + ``smdistributed.modelparallel.torch.DistributedModel`` to be able to train with uneven inputs across participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped ``DistributedDataParallel`` instance. For more information, see: `join `__ @@ -276,9 +285,9 @@ you need to add the following import statement at the top of your training scrip `register_comm_hook `__ in the PyTorch documentation. - **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism** + **Behavior of** ``smdistributed.modelparallel.torch.DistributedModel`` **with Tensor Parallelism** - When a model is wrapped by ``smp.DistributedModel``, the library + When a model is wrapped by ``smdistributed.modelparallel.torch.DistributedModel``, the library immediately traverses the modules of the model object, and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. If there are no other @@ -293,6 +302,8 @@ you need to add the following import statement at the top of your training scrip # register DistributedSubmodule as the distributed version of Submodule # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist) + import smdistributed.modelparallel.torch as smp + smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule) class MyModule(nn.Module): @@ -319,20 +330,20 @@ you need to add the following import statement at the top of your training scrip placement of model partitions into GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately. This is because it does not need to wait for the model - partition when ``smp.DistributedModel`` wrapper is called. For other + partition when ``smdistributed.modelparallel.torch.DistributedModel`` wrapper is called. For other cases with ``pipeline_parallel_degree`` greater than 1, the broadcast and device placement will be deferred until the first call of an - ``smp.step``-decorated function happens. This is because the first - ``smp.step``-decorated function call is when the model partitioning + ``smdistributed.modelparallel.torch.step``-decorated function happens. This is because the first + ``smdistributed.modelparallel.torch.step``-decorated function call is when the model partitioning happens if pipeline parallelism is enabled. - Because of the module replacement during the ``smp.DistributedModel`` + Because of the module replacement during the ``smdistributed.modelparallel.torch.DistributedModel`` call, any ``load_state_dict`` calls on the model, as well as any direct access to model parameters, such as during the optimizer creation, - should be done **after** the ``smp.DistributedModel`` call. + should be done **after** the ``smdistributed.modelparallel.torch.DistributedModel`` call. Since the broadcast of the model parameters and buffers happens - immediately during ``smp.DistributedModel`` call when the degree of + immediately during ``smdistributed.modelparallel.torch.DistributedModel`` call when the degree of pipeline parallelism is 1, using ``@smp.step`` decorators is not required when tensor parallelism is used by itself (without pipeline parallelism). @@ -340,9 +351,9 @@ you need to add the following import statement at the top of your training scrip For more information about the library's tensor parallelism APIs for PyTorch, see :ref:`smdmp-pytorch-tensor-parallel`. - **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + **Additional Methods of** ``smdistributed.modelparallel.torch.DistributedModel`` **for Tensor Parallelism** - The following are the new methods of ``smp.DistributedModel``, in + The following are the new methods of ``smdistributed.modelparallel.torch.DistributedModel``, in addition to the ones listed in the `documentation `__. @@ -383,24 +394,26 @@ you need to add the following import statement at the top of your training scrip - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``RDP_GROUP``. +smdistributed.modelparallel.torch.DistributedOptimizer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) -.. class:: smp.DistributedOptimizer + An optimizer wrapper for saving and loading optimizer states. - **Parameters** - - ``optimizer`` + :param optimizer: An optimizer object. + :type optimizer: object - An optimizer wrapper for saving/loading optimizer states. This wrapper - returns ``optimizer`` with the following methods overridden: + This wrapper returns ``optimizer`` with the following methods overridden: - .. function:: state_dict( ) + .. method:: state_dict( ) Returns the ``state_dict`` that contains optimizer state for the entire model. It first collects the ``local_state_dict`` and gathers and merges the ``local_state_dict`` from all ``mp_rank``s to create a full ``state_dict``. - .. function:: load_state_dict( ) + .. method:: load_state_dict( ) Same as the ``torch.optimizer.load_state_dict()`` , except: @@ -409,7 +422,7 @@ you need to add the following import statement at the top of your training scrip - The actual loading happens after the model partition so that each rank knows its local parameters. - .. function:: local_state_dict( ) + .. method:: local_state_dict( ) Returns the ``state_dict`` that contains the local optimizer state that belongs to the current \ ``mp_rank``. This @@ -418,34 +431,140 @@ you need to add the following import statement at the top of your training scrip ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - ​ -.. function:: smp.partition(index) - :noindex: + .. method:: save_optimizer_backcompat(cast_to_cpu=True, gather_if_shard=True, fp32_states_only=False) + + Gets the local optimizer states and FP16 states if FP16 training is enabled. + + :param cast_to_cpu: Whether to cast the optimizer states and FP16 states to CPU. + :type cast_to_cpu: boolean + :param gather_if_shard: (for smdistributed-modelparallel v1.10 only) + Whether to gather the optimizer states and FP16 states to the 0th + ``rdp_rank`` when using the `optimizer state sharding + `_ feature. + If you want to save optimizer and also further reduce CPU memory + utilization for better performance, turn it off by setting + ``gather_if_shard=False``. However, you need to make sure that you + save the states on all ``rdp_rank``. To handle both cases, + use the following example code. + + + :type gather_if_shard: boolean + :param fp32_states_only: Whether to return the FP32 optimizer states only. + :type fp32_states_only: boolean + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + # wrap optimizer + optimizer = torch.optim.Optimizer(...) + optimizer = smp.DistributedOptimizer(optimizer) + + # save optimizer + save_dict["optimizer"] = optimizer.save_optimizer_backcompat( + gather_if_shard=args.gather_if_shard + ) + if not args.gather_if_shard or smp.rdp_rank() == 0: + smp.save( + save_dict, output_save_file, partial=True, + v3=not args.gather_if_shard + ) + + The ``v3`` argument of the ``smp.save()`` function checks whether the value of + the ``gather_if_shard`` arg is ``True`` or ``False``. + If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint + files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix + to avoid overwriting checkpoint files. + + .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) + + Loads the saved optimizer states and FP16 states if FP16 training is enabled. + + :param state_dict: The ``state_dict`` to load. + :type state_dict: dict + :param gather_if_shard: Specify whether the optimizer state was saved with ``gather_if_shard=True`` + when using the :class:`smdistributed.modelparallel.torch.DistributedOptimizer.save_optimizer_backcompat()` method. + :type gather_if_shard: boolean + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + # load optimizer + checkpoint = smp.load(local_ckpt_path, partial=True) + optimizer.load_optimizer_backcompat( + checkpoint["optimizer"], gather_if_shard=args.gather_if_shard + ) - **Inputs** +smdistributed.modelparallel.torch Context Managers and Util Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - ``index`` (int) - The index of the partition. +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) + + Context manager to create a ``torch`` model. This API combines both the + :class:`smdistributed.modelparallel.torch.tensor_parallelism` and + ``smdistributed.modelparallel.torch.delay_param_initialization`` decorators + so user need to simply use a single context when creating the torch model. + + :param tensor_parallelism: Whether tensor parallel should be enabled during model creation. + :type tensor_parallelism: boolean + :param dtype: The dtype to use when creating the model. It has the following rules. + + * If dtype is specified, it will be used during model creation. + * If dtype is not specified, the default dtype will be used during model creation, + which is usually FP32. This is for the best performance on CPU. + * Any model that causes out-of-memory problems with FP32 initialization + is recommended to be created with + :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. + :type dtype: torch.dtype + :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. + This is not used if tensor_parallelism is False + :type tensor_parallel_config: dict + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + with smp.model_creation( + tensor_parallelism=smp.tp_size() > 1, + dtype=torch.float16 if args.fp16 else torch.get_default_dtype() + ): + model = MyModel(...) + +.. function:: smdistributed.modelparallel.torch.partition(index) + + :param index: The index of the partition. + :type index: int A context manager which places all modules defined inside into the partition with ID ``index``.  The ``index`` argument must be less than the number of partitions. - Use ``smp.partition`` to implement manual partitioning. + Use ``smdistributed.modelparallel.torch.partition`` to implement manual partitioning. If ``"auto_partition"`` is ``True``, then the - ``smp.partition`` contexts are ignored. Any module that is not placed in - any ``smp.partition`` context is placed in the + ``smdistributed.modelparallel.torch.partition`` contexts are ignored. Any module that is not placed in + any ``smdistributed.modelparallel.torch.partition`` context is placed in the ``default_partition`` defined through the SageMaker Python SDK. - When ``smp.partition`` contexts are nested, the innermost context + When ``smdistributed.modelparallel.torch.partition`` contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module \ ``__init__``, and the partition assignment applies to the modules that are *created* inside - the ``smp.partition`` context. + the ``smdistributed.modelparallel.torch.partition`` context. Example: .. code:: python + import smdistributed.modelparallel.torch as smp + class Model(torch.nn.Module):     def __init__(self):         with smp.partition(1): @@ -455,29 +574,40 @@ you need to add the following import statement at the top of your training scrip             self.child2 = Child2()            # child2 on partition 1         self.child3 = Child3()                # child3 on default_partition -.. function:: smp.get_world_process_group( ) +.. data:: smdistributed.modelparallel.torch.amp.GradScaler + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. function:: smdistributed.modelparallel.torch.delayed_parameter_initialization(enabled=True) + + If enabled, it delays the initialization of parameters + to save CPU memory; it initializes after the model is partitioned on GPU. + +.. function:: smdistributed.modelparallel.torch.get_world_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all processes, which can be used with the ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_mp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_mp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``MP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_dp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_dp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``DP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.is_initialized( ) +.. function:: smdistributed.modelparallel.torch.is_initialized( ) - Returns ``True`` if ``smp.init`` has already been called for the + Returns ``True`` if ``smdistributed.modelparallel.torch.init`` has already been called for the process, and ``False`` otherwise. .. function::smp.is_tracing( ) @@ -485,41 +615,35 @@ you need to add the following import statement at the top of your training scrip Returns ``True`` if the current process is running the tracing step, and ``False`` otherwise. -.. data:: smp.nn.FusedLayerNorm +.. data:: smdistributed.modelparallel.torch.nn.FusedLayerNorm `Apex Fused Layer Norm `__ is currently not - supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + supported by the library. ``smdistributed.modelparallel.torch.nn.FusedLayerNorm`` replaces ``apex`` ``FusedLayerNorm`` and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedNovoGrad +.. data:: smdistributed.modelparallel.torch.optimizers.FusedNovoGrad `Fused Novo Grad optimizer `__ is - currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + currently not supported by the library. ``smdistributed.modelparallel.torch.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedLamb +.. data:: smdistributed.modelparallel.torch.optimizers.FusedLamb `FusedLamb optimizer `__ - currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.optimizers.FusedLamb`` replaces ``apex`` ``FusedLamb`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.amp.GradScaler - - `Torch AMP Gradscaler `__ - currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces - ``torch.amp.GradScaler`` and provides the same functionality. - .. _pytorch_saving_loading: APIs for Saving and Loading ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smp.save( ) +.. function:: smdistributed.modelparallel.torch.save( ) Saves an object. This operation is similar to ``torch.save()``, except it has an additional keyword argument, ``partial``, and accepts only @@ -534,16 +658,18 @@ APIs for Saving and Loading - ``partial`` (bool, default= ``True``):  When set to ``True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` index to the saved file. If you want to be able to load - and further train a model that you save with ``smp.save()``, you must + and further train a model that you save with ``smdistributed.modelparallel.torch.save()``, you must set ``partial=True``. - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``): A module used for pickling metadata and objects. - ``pickle_protocol``  (int, default=2): Can be specified to override the defaultprotocol. + - ``v3`` (bool, default=``False``): When set to ``True``, save optimizer state checkpoints + in V3 file format to add all ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix. -.. function:: smp.load( ) +.. function:: smdistributed.modelparallel.torch.load( ) - Loads an object saved with ``smp.save()`` from a file. + Loads an object saved with ``smdistributed.modelparallel.torch.save()`` from a file. Similar to, `torch.load() `__, except it has an additional keyword argument, ``partial``, and accepts @@ -577,13 +703,13 @@ The library can save partial or full checkpoints. - For full checkpoints, the library saves a single checkpoint that contains entire model parameters. -When **saving** using ``smp.save()``, each rank only holds its own +When **saving** using ``smdistributed.modelparallel.torch.save()``, each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance. -When **loading** using ``smp.load()``, the library can load either partial or | +When **loading** using ``smdistributed.modelparallel.torch.load()``, the library can load either partial or | full checkpoints or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint. @@ -592,6 +718,7 @@ The following is an example of how you can save and load a checkpoint: .. code:: python + import smdistributed.modelparallel.torch as smp # Original model and optimizer model = MyModel(...) optimizer = MyOpt(...) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index e0ea1ba6c8..7a75b8f9f3 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -17,7 +17,7 @@ place on a best-effort basis for those module supported for tensor parallelism. Alternatively, you can directly import and use the library’s distributed modules in the model definition. -Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +Some of the supported modules (such as ``smdistributed.modelparallel.torch.nn.Transformer``) are high-level blocks that contain many operations. Because custom implementations (as opposed to the built-in PyTorch modules) are typically used for these high-level blocks, the library offers an API that you can use to register @@ -47,9 +47,9 @@ use is functionally equivalent to the distributed module. You can verify this by taking a look at the equivalent reference implementations in the :ref:`smdmp-tp-appendix`. These implementations are functionally equivalent to their distributed -versions in ``smp.nn`` module. +versions in ``smdistributed.modelparallel.torch.nn`` module. -.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. decorator:: @smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) - A class decorator that registers the ``dist_module`` class with the module class that it is attached to. The hooks can be used to @@ -57,10 +57,10 @@ versions in ``smp.nn`` module. ``forward`` methods. - **Arguments:** - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -89,6 +89,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + init_hook = lambda config: ((), config.to_dict()) # register smp.nn.DistributedTransformer @@ -101,7 +103,7 @@ versions in ``smp.nn`` module. def forward(self, hidden_states, attention_mask): ... -.. function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. function:: smdistributed.modelparallel.torch.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) - When you do not have direct access to model definition code, you can use this API to similarly register a distributed module with @@ -111,10 +113,10 @@ versions in ``smp.nn`` module. - ``module_cls``: The existing module class that will be distributed. - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -143,6 +145,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + from somelibrary import MyTransformer init_hook = lambda config: ((), config.to_dict()) @@ -160,13 +164,13 @@ Supported Modules for Tensor Parallelism The following modules are supported for tensor parallelism. -- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``) -- ``smp.nn.DistributedTransformerLMHead`` -- ``smp.nn.DistributedTransformer`` -- ``smp.nn.DistributedTransformerLayer`` -- ``smp.nn.DistributedAttentionLayer`` -- ``smp.nn.DistributedTransformerOutputLayer`` -- ``smp.nn.DistributedEmbedding`` +- ``smdistributed.modelparallel.torch.nn.DistributedLinear`` (implements ``nn.Linear``) +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformer`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedEmbedding`` .. contents:: Topics :depth: 3 @@ -177,7 +181,7 @@ parallelism. Tensor Parallelism Module APIs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. class:: smp.nn.DistributedLinear(in_features, out_features) +.. class:: smdistributed.modelparallel.torch.nn.DistributedLinear(in_features, out_features) - Tensor-parallel implementation of the ``nn.Linear`` class. Functionally equivalent to an ``nn.Linear`` module with the same @@ -191,7 +195,7 @@ Tensor Parallelism Module APIs - ``out_features``: The total number of output channels for the linear layer across all tensor-parallel ranks. -.. class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) - Constructs a distributed transformer model, including embeddings and a single LM head. A word embedding of size @@ -205,7 +209,7 @@ Tensor Parallelism Module APIs if ``add_lm_head`` is ``True``, the output passes through a single LM head, which is a linear module without bias whose weight is tied to the word embeddings. - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest of the arguments. - **Methods:** @@ -223,18 +227,18 @@ Tensor Parallelism Module APIs - ``attention_mask`` is assumed to be a 0-1 tensor of shape ``[N, S]``, where 1 represents a masked position. -.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose + - A sequence of ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``\ s, whose number is given by ``num_layers`` argument. For the other arguments and methods, refer to - ``smp.nn.DistributedTransformerLayer``. + ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``. - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, layer normalization is applied to both the input and the output of the ``DistributedTransformer``, in addition to the intermediate attention and transformer-output layers. -.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - Tensor-parallel implementation of a single transformer layer. Number of attention heads, hidden size, and intermediate size @@ -336,7 +340,7 @@ Tensor Parallelism Module APIs and the next three tensors are the same as the input arguments. -.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) - A distributed implementation for the attention block. Includes the computation of the self- or cross-attention (context layer), @@ -344,7 +348,7 @@ Tensor Parallelism Module APIs followed by the residual-connection and layer normalization. - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``cross_attention``: If ``True``, it computes the attentions with respect to the ``cross_states`` tensor of the ``forward`` @@ -383,10 +387,10 @@ Tensor Parallelism Module APIs - A single tensor that is the output of the attention layer. -.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) - Distributed implementation of a single transformer output layer. A - single :class:`smp.nn.DistributedTransformerLayer` with + single :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` with ``add_cross_attention=False`` consists of a single ``DistributedAttentionLayer`` immediately followed by a single ``DistributedTransformerOutputLayer``. The latter linearly maps @@ -394,19 +398,19 @@ Tensor Parallelism Module APIs ``intermediate_size``, and then maps it back to ``hidden_size``. - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow (NaN loss values) for large models with more than 100 billion parameters when using FP16. (Default: False) -.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) +.. class:: smdistributed.modelparallel.torch.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) - Distributed implementation of a single Embedding Layer. Currently only supports splitting across the embedding_dim. - **Arguments:** - - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` for descriptions of the arguments. .. _enabling-tp: @@ -417,7 +421,7 @@ Enabling Tensor Parallelism There are two ways tensor parallelism can be enabled. First, you can use -the distributed module implementations in ``smp.nn`` module directly in +the distributed module implementations in ``smdistributed.modelparallel.torch.nn`` module directly in your model definition. See :ref:`smdmp-supported-modules-for-tp` for a complete list of built-in distributed modules. Here is an example of how this can be done: @@ -446,7 +450,7 @@ of code, which will automatically enable tensor parallelism for the supported modules within that scope. To do this, you can use the following API: -.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) +.. decorator:: smdistributed.modelparallel.torch.tensor_parallelism(enabled=True, **kwargs) - A context manager that enables or disables tensor parallelism for any supported module that is created inside. If there are nested @@ -463,6 +467,8 @@ following API: .. code:: python + import smdistributed.modelparallel.torch as smp + with smp.tensor_parallelism(): self.m0 = nn.Linear(20, 20) # will be distributed with smp.tensor_parallelism(enabled=False): @@ -472,7 +478,7 @@ following API: the distributed modules created inside the context. If a keyword argument provided through it matches any ``__init__`` method arguments of a ``DistributedModule`` that substitutes a module created inside - the ``smp.tensor_parallelism`` context, this keyword will override + the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - (*For v1.7.0 and later*) Through the following additional keyword arguments, @@ -481,21 +487,21 @@ following API: - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. By default, it is set to ``True``. You can deactivate it by setting - ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + ``fused_softmax=False`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. By default, it is set to ``False``. You can activate it by setting - ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + ``fused_bias_gelu=True`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. -.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) +.. function:: smdistributed.modelparallel.torch.set_tensor_parallelism(module, enabled=True, **kwargs) - Enables or disables tensor parallelism for the supported submodules of ``module``. If enabling, the outermost supported modules will be distributed. If disabling, tensor parallelism will be disabled for the entire module subtree of ``module``. Unlike the context manager, this API can be used after the model creation - (but before wrapping with :class:`smp.DistributedModel`), so direct + (but before wrapping with :class:`smdistributed.modelparallel.torch.DistributedModel`), so direct access to model definition code is not required. If a supported module shares weights with another (supported or unsupported) module, or if its hyperparameters do not support distribution @@ -504,14 +510,16 @@ following API: - Keyword arguments ``kwargs`` can be used to modify the configurations of the distributed modules created inside the context. If a keyword argument provided here matches any - ``__init__`` method arguments of a :class:`smp.DistributedModel` that - substitutes a module created inside the ``smp.tensor_parallelism`` + ``__init__`` method arguments of a :class:`smdistributed.modelparallel.torch.DistributedModel` that + substitutes a module created inside the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = MyModel() smp.set_tensor_parallelism(model.encoder, True) smp.set_tensor_parallelism(model.encoder.embedding, True) @@ -608,7 +616,7 @@ in the *SageMaker's Distributed Model Parallel developer guide*. any tuples received. If the checkpointed layer takes a tuple as input, then this needs to be set to True. -.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") +.. class:: smdistributed.modelparallel.torch.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") - This API is recommended when importing pretrained models from libraries, such as PyTorch and Hugging Face Transformers. This is @@ -673,8 +681,8 @@ parses the arguments to ``__init__`` methods and sets the relevant attributes of the module, such as ``hidden_size`` and ``num_attention_heads``. -``smp.nn.DistributedTransformer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -692,8 +700,8 @@ attributes of the module, such as ``hidden_size`` and def forward(self, inp): return self.seq_layers(inp) -``smp.nn.DistributedTransformerLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -727,8 +735,8 @@ attributes of the module, such as ``hidden_size`` and else: return output, attention_mask -``smp.nn.DistributedAttentionLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -812,8 +820,8 @@ attributes of the module, such as ``hidden_size`` and else: return self_attention -``smp.nn.DistributedTransformerOutputLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python From 052a719eff7ab6cb8ac2d0480065040d4a923ab4 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 16:54:07 -0700 Subject: [PATCH 04/23] add new params --- doc/api/training/smd_model_parallel_general.rst | 10 ++++++++++ .../latest/smd_model_parallel_pytorch.rst | 15 +++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/doc/api/training/smd_model_parallel_general.rst b/doc/api/training/smd_model_parallel_general.rst index a35e0d60bc..fbb99f5224 100644 --- a/doc/api/training/smd_model_parallel_general.rst +++ b/doc/api/training/smd_model_parallel_general.rst @@ -178,6 +178,16 @@ PyTorch-specific Parameters - 1 - The number of devices over which the tensor parallel modules will be distributed. If ``tensor_parallel_degree`` is greater than 1, then ``ddp`` must be set to ``True``. + * - ``fp16`` (**smdistributed-modelparallel**>=v1.10) + - bool + - ``False`` + - To run FP16 training, add ``"fp16"'": True`` to the smp configuration. + Other APIs remain the same between FP16 and FP32. + If ``fp16`` is enabled and when user calls ``smp.DistributedModel``, + the model will be wrapped with ``FP16_Module``, which converts the model + to FP16 dtype and deals with forward pass in FP16. + If ``fp16`` is enabled and when user calls ``smp.DistributedOptimizer``, + the optimizer will be wrapped with ``FP16_Optimizer``. * - ``fp16_params`` (**smdistributed-modelparallel**>=v1.6) - bool - ``False`` diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index d829da43f8..3340587641 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -397,12 +397,18 @@ smdistributed.modelparallel.torch.DistributedModel smdistributed.modelparallel.torch.DistributedOptimizer ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1, dynamic_loss_scale=False, **dynamic_loss_args) An optimizer wrapper for saving and loading optimizer states. :param optimizer: An optimizer object. :type optimizer: object + :param static_loss_scale: Available only for FP16 training. Set to ``1`` to use static loss scale. The default value is ``1``. + :type static_loss_scale: float + :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. + :type dynamic_loss_scale: boolean + :param dynamic_loss_args: Available only for FP16 training. If ``dynamic_loss_scale=True``, specify parameters for dynamic loss scale. + :type dynamic_loss_args: dict This wrapper returns ``optimizer`` with the following methods overridden: @@ -523,7 +529,7 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. :type dtype: torch.dtype :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. - This is not used if tensor_parallelism is False + This is not used if ``tensor_parallelism`` is ``False`` :type tensor_parallel_config: dict **Example Usage:** @@ -580,10 +586,11 @@ smdistributed.modelparallel.torch Context Managers and Util Functions currently doesn’t work with the library. ``smdistributed.modelparallel.torch.amp.GradScaler`` replaces ``torch.amp.GradScaler`` and provides the same functionality. -.. function:: smdistributed.modelparallel.torch.delayed_parameter_initialization(enabled=True) +.. function:: smdistributed.modelparallel.torch.delay_param_initialization(enabled=True) If enabled, it delays the initialization of parameters - to save CPU memory; it initializes after the model is partitioned on GPU. + to save CPU memory. That is, parameter initialization takes place + after the model is partitioned on GPUs. .. function:: smdistributed.modelparallel.torch.get_world_process_group( ) From c9101c3b0211614fc9aaab4822c7ffa9292620f3 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 17:41:57 -0700 Subject: [PATCH 05/23] add dynamic scale params, add reference --- .../latest/smd_model_parallel_pytorch.rst | 64 ++++++++++++++++--- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 3340587641..4eb4c7aeff 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -397,26 +397,73 @@ smdistributed.modelparallel.torch.DistributedModel smdistributed.modelparallel.torch.DistributedOptimizer ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1, dynamic_loss_scale=False, **dynamic_loss_args) +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, **dynamic_loss_args) An optimizer wrapper for saving and loading optimizer states. :param optimizer: An optimizer object. :type optimizer: object - :param static_loss_scale: Available only for FP16 training. Set to ``1`` to use static loss scale. The default value is ``1``. + :param static_loss_scale: Available only for FP16 training. The default value is ``1.0``. :type static_loss_scale: float :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. :type dynamic_loss_scale: boolean - :param dynamic_loss_args: Available only for FP16 training. If ``dynamic_loss_scale=True``, specify parameters for dynamic loss scale. + :param dynamic_loss_args: Available only for FP16 training. + If you set ``dynamic_loss_scale=True``, configure scale parameters for dynamic loss scale. + The following list shows available parameters. + + * ``"init_scale"``: Default is ``2**32`` + * ``"scale_factor"``: Default is ``2.`` + * ``"scale_window"``: Default is ``1000`` + * ``"min_scale"``: Default is ``1`` + * ``"delayed_shift"``: Default is ``1`` + * ``"consecutive_hysteresis"``: Default is ``False`` :type dynamic_loss_args: dict - This wrapper returns ``optimizer`` with the following methods overridden: + **Example Usage for an FP32 Optimizer:** + + .. code:: python + + optimizer = torch.optim.AdaDelta(model.parameters(), lr=4.0) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) + + + **Example Usage for an FP16 Optimizer:** + + .. code:: python + + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=None, + dynamic_loss_scale=True, + dynamic_loss_args={ + "scale_window": 1000, + "min_scale": 1, + "delayed_shift": 2 + } + ) + + .. tip:: + + After you modify training scripts with + :class:`smdistributed.modelparallel.torch.DistributedModel` and + :class:`smdistributed.modelparallel.torch.DistributedOptimizer`, + use the SageMaker PyTorch estimator's distribution configuration o enable FP16 training. + You simply need to add ``"fp16": True`` to the ``smp_options`` config dictionary's + ``"parameters"`` key as shown in + `Using the SageMaker TensorFlow and PyTorch Estimators + `_. + For more information about available parameters for the ``smp_options`` config, + see :ref:`sm-sdk-modelparallel-general`. + + + + This wrapper returns an ``optimizer`` object with the following methods overridden: .. method:: state_dict( ) Returns the ``state_dict`` that contains optimizer state for the entire model. It first collects the ``local_state_dict`` and gathers and merges - the ``local_state_dict`` from all ``mp_rank``s to create a full + the ``local_state_dict`` from all ``mp_rank``\ s to create a full ``state_dict``. .. method:: load_state_dict( ) @@ -450,10 +497,9 @@ smdistributed.modelparallel.torch.DistributedOptimizer If you want to save optimizer and also further reduce CPU memory utilization for better performance, turn it off by setting ``gather_if_shard=False``. However, you need to make sure that you - save the states on all ``rdp_rank``. To handle both cases, + save the states on all ``rdp_rank``\ s. To handle both cases, use the following example code. - :type gather_if_shard: boolean :param fp32_states_only: Whether to return the FP32 optimizer states only. :type fp32_states_only: boolean @@ -465,7 +511,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer import smdistributed.modelparallel.torch as smp # wrap optimizer - optimizer = torch.optim.Optimizer(...) + optimizer = torch.optim.AdaDelta(...) optimizer = smp.DistributedOptimizer(optimizer) # save optimizer @@ -482,7 +528,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer the ``gather_if_shard`` arg is ``True`` or ``False``. If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix - to avoid overwriting checkpoint files. + to avoid overwriting optimizer checkpoint files. .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) From 741a7f91bdd786f28ed9b2c12a27cf458efe3c46 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 18:00:56 -0700 Subject: [PATCH 06/23] minor fix --- .../training/smp_versions/latest/smd_model_parallel_pytorch.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 4eb4c7aeff..e36fbb9993 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -447,7 +447,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer After you modify training scripts with :class:`smdistributed.modelparallel.torch.DistributedModel` and :class:`smdistributed.modelparallel.torch.DistributedOptimizer`, - use the SageMaker PyTorch estimator's distribution configuration o enable FP16 training. + use the SageMaker PyTorch estimator's distribution configuration to enable FP16 training. You simply need to add ``"fp16": True`` to the ``smp_options`` config dictionary's ``"parameters"`` key as shown in `Using the SageMaker TensorFlow and PyTorch Estimators From cb0111974d58ad775441f9161b417626bfbb0423 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Tue, 28 Jun 2022 15:10:08 -0700 Subject: [PATCH 07/23] minor fixes --- .../latest/smd_model_parallel_pytorch.rst | 43 ++++--- ...model_parallel_pytorch_tensor_parallel.rst | 106 ++++++++++-------- 2 files changed, 90 insertions(+), 59 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index e36fbb9993..520542846a 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -403,12 +403,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer :param optimizer: An optimizer object. :type optimizer: object - :param static_loss_scale: Available only for FP16 training. The default value is ``1.0``. + :param static_loss_scale: Effective only for FP16 training. The default value is ``1.0``. :type static_loss_scale: float - :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. + :param dynamic_loss_scale: Effective only for FP16 training. Set to ``True`` to + use dynamic loss scale. The default value is ``False``. :type dynamic_loss_scale: boolean - :param dynamic_loss_args: Available only for FP16 training. - If you set ``dynamic_loss_scale=True``, configure scale parameters for dynamic loss scale. + :param dynamic_loss_args: Effective only for FP16 training. + If ``dynamic_loss_scale=True``, you can configure additional scale + parameters for dynamic loss scale. The following list shows available parameters. * ``"init_scale"``: Default is ``2**32`` @@ -423,14 +425,24 @@ smdistributed.modelparallel.torch.DistributedOptimizer .. code:: python - optimizer = torch.optim.AdaDelta(model.parameters(), lr=4.0) + optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) + **Example Usage for an FP16 Optimizer with static loss scale:** - **Example Usage for an FP16 Optimizer:** + .. code:: python + + optimizer = torch.optim.AdaDelta(...) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=1.0 + ) + + **Example Usage for an FP16 Optimizer with dynamic loss scale:** .. code:: python + optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( optimizer, static_loss_scale=None, @@ -455,8 +467,6 @@ smdistributed.modelparallel.torch.DistributedOptimizer For more information about available parameters for the ``smp_options`` config, see :ref:`sm-sdk-modelparallel-general`. - - This wrapper returns an ``optimizer`` object with the following methods overridden: .. method:: state_dict( ) @@ -555,14 +565,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, distribute_embedding=False, **tensor_parallel_config) Context manager to create a ``torch`` model. This API combines both the :class:`smdistributed.modelparallel.torch.tensor_parallelism` and - ``smdistributed.modelparallel.torch.delay_param_initialization`` decorators - so user need to simply use a single context when creating the torch model. + :class:`smdistributed.modelparallel.torch.delay_param_initialization` decorators, + so you can simply use this single context when creating the torch model. - :param tensor_parallelism: Whether tensor parallel should be enabled during model creation. + :param tensor_parallelism: Whether to enable tensor parallelism during model creation. :type tensor_parallelism: boolean :param dtype: The dtype to use when creating the model. It has the following rules. @@ -572,10 +582,12 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * Any model that causes out-of-memory problems with FP32 initialization is recommended to be created with :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. - * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the ``smp`` config. :type dtype: torch.dtype + :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models. + :type dtype: boolean :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. - This is not used if ``tensor_parallelism`` is ``False`` + This is not used if ``tensor_parallelism`` is ``False``. :type tensor_parallel_config: dict **Example Usage:** @@ -586,7 +598,8 @@ smdistributed.modelparallel.torch Context Managers and Util Functions with smp.model_creation( tensor_parallelism=smp.tp_size() > 1, - dtype=torch.float16 if args.fp16 else torch.get_default_dtype() + dtype=torch.float16 if args.fp16 else torch.get_default_dtype(), + distribute_embedding=False ): model = MyModel(...) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index 7a75b8f9f3..c101e0025d 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -49,9 +49,9 @@ this by taking a look at the equivalent reference implementations in the These implementations are functionally equivalent to their distributed versions in ``smdistributed.modelparallel.torch.nn`` module. -.. decorator:: @smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. class:: smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) - - A class decorator that registers the ``dist_module`` class with + - A decorator class that registers the ``dist_module`` class with the module class that it is attached to. The hooks can be used to adapt to different interfaces used with ``__init__`` and ``forward`` methods. @@ -161,16 +161,7 @@ versions in ``smdistributed.modelparallel.torch.nn`` module. Supported Modules for Tensor Parallelism ---------------------------------------- -The following modules are supported for tensor -parallelism. - -- ``smdistributed.modelparallel.torch.nn.DistributedLinear`` (implements ``nn.Linear``) -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformer`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedEmbedding`` +The following modules are supported for tensor parallelism. .. contents:: Topics :depth: 3 @@ -181,14 +172,27 @@ parallelism. Tensor Parallelism Module APIs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- :class:`smdistributed.modelparallel.torch.nn.DistributedLinear` (implements ``nn.Linear``) +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedAttentionLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` + .. class:: smdistributed.modelparallel.torch.nn.DistributedLinear(in_features, out_features) - - Tensor-parallel implementation of the ``nn.Linear`` class. - Functionally equivalent to an ``nn.Linear`` module with the same - ``in_features`` and ``out_features``. In other words, - ``in_features`` and ``out_features`` are the number of *global* - channels across tensor-parallel ranks. - - **Arguments:** + Tensor-parallel implementation of the ``nn.Linear`` class. + Functionally equivalent to an ``nn.Linear`` module with the same + ``in_features`` and ``out_features``. In other words, + ``in_features`` and ``out_features`` are the number of *global* + channels across tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + + + - **Arguments:** - ``in_features``: The total number of input channels for the linear layer across all tensor-parallel ranks. @@ -197,21 +201,22 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) - - Constructs a distributed transformer model, including embeddings - and a single LM head. A word embedding of size - ``(vocab_size, hidden_size)`` is created, as well as a positional - embedding of size ``(num_positions, hidden_size)``, and the - embeddings are added together. If ``num_token_types`` is larger - than 0, a separate embedding of size - ``(num_token_types, hidden_size)`` is created, and further added - on top. - - The embeddings are fed through a ``DistributedTransformer``, and - if ``add_lm_head`` is ``True``, the output passes through a single - LM head, which is a linear module without bias whose weight is - tied to the word embeddings. - - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest - of the arguments. - - **Methods:** + Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** - ``forward(self, inputs)`` @@ -229,10 +234,11 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - A sequence of ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``\ s, whose - number is given by ``num_layers`` argument. For the other - arguments and methods, refer to - ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``. + A sequence of :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`. + - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, layer normalization is applied to both the input and the output of the ``DistributedTransformer``, in addition to the intermediate @@ -240,9 +246,13 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - Tensor-parallel implementation of a single transformer layer. - Number of attention heads, hidden size, and intermediate size - refer to the global quantities across all tensor-parallel ranks. + Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - ``num_attention_heads``: The total number of attention heads @@ -342,10 +352,14 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) - - A distributed implementation for the attention block. Includes the - computation of the self- or cross-attention (context layer), - followed by a linear mapping and dropout, which is optionally - followed by the residual-connection and layer normalization. + A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the @@ -396,6 +410,10 @@ Tensor Parallelism Module APIs ``DistributedTransformerOutputLayer``. The latter linearly maps the last channel of the input tensor from ``hidden_size`` to ``intermediate_size``, and then maps it back to ``hidden_size``. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the From 82ca9c1f0d384bb8dd8376786dbbcc0ba325358f Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 13 Jul 2022 17:19:09 -0700 Subject: [PATCH 08/23] rm temp methods --- .../latest/smd_model_parallel_pytorch.rst | 69 +------------------ 1 file changed, 1 insertion(+), 68 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 520542846a..1d3ea83337 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -493,74 +493,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer a partial \ ``state_dict``, which indicates whether the ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - - .. method:: save_optimizer_backcompat(cast_to_cpu=True, gather_if_shard=True, fp32_states_only=False) - - Gets the local optimizer states and FP16 states if FP16 training is enabled. - - :param cast_to_cpu: Whether to cast the optimizer states and FP16 states to CPU. - :type cast_to_cpu: boolean - :param gather_if_shard: (for smdistributed-modelparallel v1.10 only) - Whether to gather the optimizer states and FP16 states to the 0th - ``rdp_rank`` when using the `optimizer state sharding - `_ feature. - If you want to save optimizer and also further reduce CPU memory - utilization for better performance, turn it off by setting - ``gather_if_shard=False``. However, you need to make sure that you - save the states on all ``rdp_rank``\ s. To handle both cases, - use the following example code. - - :type gather_if_shard: boolean - :param fp32_states_only: Whether to return the FP32 optimizer states only. - :type fp32_states_only: boolean - - **Example Usage:** - - .. code:: python - - import smdistributed.modelparallel.torch as smp - - # wrap optimizer - optimizer = torch.optim.AdaDelta(...) - optimizer = smp.DistributedOptimizer(optimizer) - - # save optimizer - save_dict["optimizer"] = optimizer.save_optimizer_backcompat( - gather_if_shard=args.gather_if_shard - ) - if not args.gather_if_shard or smp.rdp_rank() == 0: - smp.save( - save_dict, output_save_file, partial=True, - v3=not args.gather_if_shard - ) - - The ``v3`` argument of the ``smp.save()`` function checks whether the value of - the ``gather_if_shard`` arg is ``True`` or ``False``. - If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint - files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix - to avoid overwriting optimizer checkpoint files. - - .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) - - Loads the saved optimizer states and FP16 states if FP16 training is enabled. - - :param state_dict: The ``state_dict`` to load. - :type state_dict: dict - :param gather_if_shard: Specify whether the optimizer state was saved with ``gather_if_shard=True`` - when using the :class:`smdistributed.modelparallel.torch.DistributedOptimizer.save_optimizer_backcompat()` method. - :type gather_if_shard: boolean - - **Example Usage:** - - .. code:: python - - import smdistributed.modelparallel.torch as smp - - # load optimizer - checkpoint = smp.load(local_ckpt_path, partial=True) - optimizer.load_optimizer_backcompat( - checkpoint["optimizer"], gather_if_shard=args.gather_if_shard - ) + smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From fc94daaa59a746dda69a19202810062c57b99ff5 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 13 Jul 2022 21:06:03 -0700 Subject: [PATCH 09/23] add new checkpoint save/load functions, doc improvement --- .../latest/smd_model_parallel_pytorch.rst | 111 +++++++++++++++--- 1 file changed, 94 insertions(+), 17 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 1d3ea83337..cb1c37e400 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -17,7 +17,7 @@ import the ``smdistributed.modelparallel.torch`` package at the top of your trai to learn how to use the following API in your PyTorch training script. .. contents:: Topics - :depth: 3 + :depth: 1 :local: smdistributed.modelparallel.torch.DistributedModel @@ -421,14 +421,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer * ``"consecutive_hysteresis"``: Default is ``False`` :type dynamic_loss_args: dict - **Example Usage for an FP32 Optimizer:** + **Example usage of an FP32 Optimizer:** .. code:: python optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) - **Example Usage for an FP16 Optimizer with static loss scale:** + **Example usage of an FP16 Optimizer with static loss scale:** .. code:: python @@ -438,7 +438,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer static_loss_scale=1.0 ) - **Example Usage for an FP16 Optimizer with dynamic loss scale:** + **Example usage of an FP16 Optimizer with dynamic loss scale:** .. code:: python @@ -493,7 +493,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer a partial \ ``state_dict``, which indicates whether the ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - + smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -515,10 +515,15 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * Any model that causes out-of-memory problems with FP32 initialization is recommended to be created with :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. - * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the ``smp`` config. - :type dtype: torch.dtype + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled + with the ``smp`` config. For more inforamtion about FP16 training + in SageMaker with the model parallel library, see `FP16 Training + `_ + in the *Amazon SageMaker Developer Guide*. + + :type dtype: ``torch.dtype`` :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models. - :type dtype: boolean + :type distribute_embedding: boolean :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. This is not used if ``tensor_parallelism`` is ``False``. :type tensor_parallel_config: dict @@ -639,12 +644,13 @@ smdistributed.modelparallel.torch Context Managers and Util Functions .. _pytorch_saving_loading: -APIs for Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^ +smdistributed.modelparallel.torch APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smdistributed.modelparallel.torch.save( ) +.. function:: smdistributed.modelparallel.torch.save(obj, f, partial=True, pickel_module=picklemodule, pickle_protocol=2, ) - Saves an object. This operation is similar to ``torch.save()``, except + Saves an object. This operation is similar to `torch.save() + `_, except that it has an additional keyword argument, ``partial``, and accepts only string type for the argument ``f`` (file). If ``partial=True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` @@ -663,10 +669,8 @@ APIs for Saving and Loading A module used for pickling metadata and objects. - ``pickle_protocol``  (int, default=2): Can be specified to override the defaultprotocol. - - ``v3`` (bool, default=``False``): When set to ``True``, save optimizer state checkpoints - in V3 file format to add all ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix. -.. function:: smdistributed.modelparallel.torch.load( ) +.. function:: smdistributed.modelparallel.torch.load(f, map_location, pickle_module, pickle_load_args, partial=True) Loads an object saved with ``smdistributed.modelparallel.torch.save()`` from a file. @@ -690,10 +694,83 @@ APIs for Saving and Loading ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``. Should be used when loading a model trained with the library. +.. function:: smdistributed.modelparallel.torch.save_checkpoint(path, tag, partial=True, model=None, optimizer=None, user_content=None, translate_if_full=True, num_kept_partial_checkpoints=None) + + Saves a checkpoint. While :class:`smdistributed.modelparallel.torch.save` saves + model and optimizer objects, + this function checkpoints model and optimizer and saves the checkpoints as separate files. + It creates checkpoint folders in the following structure. + + .. code:: text + + - path + - ${tag}_partial (folder for partial checkpoint) + - model_rankinfo.pt + - optimizer_rankinfo.pt + - fp16_states_rankinfo.pt + - user_content.pt + - $tag (checkpoint file for full checkpoint) + - user_content_$tag (user_content file for full checkpoint) + - newest (a file that indicates the newest checkpoint) + + **Parameters** + + * ``path`` (str) (required): Path to save the checkpoint. The library creates + the directory if it does not already exist. + For example, ``/opt/ml/checkpoint/model_parallel``. + * ``tag`` (str) (required): A tag for the current checkpoint, usually the train + steps. Note: tag needs to be the same across all ranks (GPU workers). + When ``partial=False`` this will be the checkpoint file name. + * ``partial`` (boolean) (default: True): Whether to save the partial checkpoint. + * ``model`` (:class:`smdistributed.modelparallel.torch.DistributedModel`) + (default: None): The model to save. It needs to an ``smp.DistributedModel`` object. + * ``optimizer`` (:class:`smdistributed.modelparallel.torch.DistributedOptimizer`) + (default: None): The optimizer to save. It needs to be an ``smp.DistributedOptimizer`` object. + * ``user_content`` (any) (default: None): User-defined content to save. + * ``translate_if_full`` (boolean) (default: True): Whether to translate the + full ``state_dict`` to HF ``state_dict`` if possible. + * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number + of partial checkpoints to keep on disk. + +.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None) + + While :class:`smdistributed.modelparallel.torch.load` loads saved + model and optimizer objects, this function resumes from a saved checkpoint file. + + **Parameters** + + * ``path`` (str) (required): Path to load the checkpoint. + * ``tag`` (str) (default: None): Tag of the checkpoint to resume. If not provided, + the library tries to locate the newest checkpoint from the saved newest file. + * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint. + * ``strict`` (boolean) (default: True): Load with strict load, no extra key or + missing key is allowed. + * ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``. + * ``translate_function`` (function) (default: None): function to translate the full + checkpoint into smdistributed.modelparallel format. + For supported models, this is not required. + + **Example usage** + + .. code:: python + + # Save + smp.save_checkpoint( + checkpoint_dir, + tag=f"total_steps{total_steps}", + partial=True, + model=model, + optimizer=optimizer, + user_content=user_content + num_kept_partial_checkpoints=args.num_kept_checkpoints) + + # Load: this will automatically load the newest checkpoint + user_content = smp.resume_from_checkpoint(path, partial=partial) + .. _pytorch_saving_loading_instructions: -General Instruction For Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +General instruction on saving and loading +----------------------------------------- The library can save partial or full checkpoints. From 03d8abc0c12e316803eb90ffead7dc8b26855f70 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 14 Jul 2022 10:39:16 -0700 Subject: [PATCH 10/23] pass doc8 --- .../smd_model_parallel_pytorch_tensor_parallel.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index c101e0025d..de7d20aaa2 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -238,11 +238,11 @@ Tensor Parallelism Module APIs number is given by ``num_layers`` argument. For the other arguments and methods, refer to :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`. - - - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, - layer normalization is applied to both the input and the output of - the ``DistributedTransformer``, in addition to the intermediate - attention and transformer-output layers. + + If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) From a821a5cdc424c9b0dae60782b851abead6c68cba Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 14 Jul 2022 22:35:09 -0700 Subject: [PATCH 11/23] Trigger Build From 68518bc42a29aaecb7cd70d0d19524fdd2f924aa Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 22 Jun 2022 20:08:31 -0700 Subject: [PATCH 12/23] archive doc for past versions --- doc/api/training/smp_versions/archives.rst | 1 + doc/api/training/smp_versions/latest.rst | 2 +- .../v1.9.0/smd_model_parallel_common_api.rst | 538 +++++++++++ .../v1.9.0/smd_model_parallel_pytorch.rst | 678 ++++++++++++++ ...model_parallel_pytorch_tensor_parallel.rst | 875 ++++++++++++++++++ .../v1.9.0/smd_model_parallel_tensorflow.rst | 171 ++++ doc/api/training/smp_versions/v1_9_0.rst | 13 + 7 files changed, 2277 insertions(+), 1 deletion(-) create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst create mode 100644 doc/api/training/smp_versions/v1_9_0.rst diff --git a/doc/api/training/smp_versions/archives.rst b/doc/api/training/smp_versions/archives.rst index fe893928ef..8c87476e99 100644 --- a/doc/api/training/smp_versions/archives.rst +++ b/doc/api/training/smp_versions/archives.rst @@ -3,6 +3,7 @@ .. toctree:: :maxdepth: 1 + v1_9_0.rst v1_6_0.rst v1_5_0.rst v1_4_0.rst diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 49085d9347..ee606b8c34 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,7 +10,7 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 (Latest) +Version 1.10.0 (Latest) =========================================== To use the library, reference the Common API documentation alongside the framework specific API documentation. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst new file mode 100644 index 0000000000..b4713b2707 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst @@ -0,0 +1,538 @@ +Common API +========== + +The following SageMaker distribute model parallel APIs are common across all frameworks. + +.. contents:: Table of Contents + :depth: 3 + :local: + +The Library's Core APIs +----------------------- + +This API document assumes you use the following import statement in your training scripts. + +**TensorFlow** + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +**PyTorch** + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. function:: smp.init( ) + :noindex: + + Initialize the library. Must be called at the beginning of training script. + +.. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs]) + :noindex: + + A decorator that must be placed over a function that represents a single + forward and backward pass (for training use cases), or a single forward + pass (for evaluation use cases). Any computation that is defined inside + the ``smp.step``-decorated function is executed in a pipelined manner. + + By default, every tensor input to the function is split across its batch + dimension into a number of microbatches specified while launching the + training job. This behavior can be customized through the arguments to + ``smp.step``, described below. The library then orchestrates the execution of + each microbatch across all partitions, based on the chosen pipeline + type. + + In a typical use case, forward pass and back-propagation are executed + inside an \ ``smp.step``-decorated function and gradients, loss, and + other relevant metrics (such as accuracy, etc.) are returned from + ``smp.step``-decorated function. + + Any gradient post-processing operation, such as gradient clipping and + allreduce, as well as ``optimizer.apply_gradients`` calls (for TF) or + ``optimizer.step`` (for PT) should be applied on the gradients returned + from the ``smp.step`` function, and not inside the ``smp.step`` + function. This is because every operation inside ``smp.step`` is + executed once per microbatch, so having these operations inside + ``smp.step`` can either be inefficient (in the case of allreduce), or + lead to wrong results (in the case of ``apply_gradients`` / + ``optimizer.step``). + + If the objects returned from the ``smp.step``-decorated function contain + ``tf.Tensor``\ s / ``torch.Tensor``\ s, they are converted to + ``StepOutput`` objects. A ``StepOutput`` object encapsulates all + versions of the tensor across different microbatches + (see ``StepOutput`` entry for more information). + + The argument to ``smp.step`` decorated function should either be a tensor + or an instance of list, tuple, dict or set for it to be split across + microbatches. If your object doesn't fall into this category, you can make + the library split your object, by implementing ``smp_slice`` method. + + Below is an example of how to use it with PyTorch. + + .. code:: python + + class CustomType: + def __init__(self, tensor): + self.data = tensor + + # The library will call this to invoke slicing on the object passing in total microbatches (num_mb) + # and the current microbatch index (mb). + def smp_slice(self, num_mb, mb, axis): + dim_size = list(self.data.size())[axis] + + split_size = dim_size // num_mb + sliced_tensor = self.data.narrow(axis, mb * split_size, split_size) + return CustomType(sliced_tensor, self.other) + + custom_obj = CustomType(torch.ones(4,)) + + @smp.step() + def step(custom_obj): + loss = model(custom_obj) + model.backward(loss) + return loss + + + **Important:** ``smp.step`` splits the batch into microbatches, and + executes everything inside the decorated function once per microbatch. + This might affect the behavior of batch normalization, any operation + that explicitly uses the batch size information, or any other Python + code that is expected to run once. + + **TensorFlow-specific behavior** + + ``smp.step`` is a wrapper that + inherits from and extends the behavior of ``tf.function``, and as such, + all the caveats that apply to the use of ``tf.function``\ s also apply + to ``smp.step``. In particular, any operation that is inside + ``smp.step`` executes in graph mode, and not eager mode. + + In the first call, ``smp.step`` performs tracing of the wrapped function every time + one of the tensor arguments changes their shape or dtype, or for every + new value of a Python argument, if there is one. Tracing is expensive, + so such scenarios should be avoided as much as possible or, + alternatively, an ``input_signature`` argument must be provided. For + more information on the usage of ``tf.function``, refer to the + TensorFlow documentation: + + - https://www.tensorflow.org/api_docs/python/tf/function\ + - https://www.tensorflow.org/guide/function\ + + Each ``smp.step`` decorated function must have a return value that depends on the + output of ``smp.DistributedModel``. + + **Common parameters** + + - ``non_split_inputs`` (``list``): The list of arguments to the decorated function + that should not be split along the batch dimension. Should be used + for all input tensors that do not have a batch dimension. Should be a + list of argument names as ``str``, as they appear in the signature of + the ``smp.step``-decorated function. By default it is considered an + empty list. + + - ``input_split_axes`` (``dict``): A dict that maps the argument name to its batch + axis. The keys should be the argument names as ``str``, as they + appear in the signature of the ``smp.step``-decorated function.  By + default all batch axes are assumed to be the 0-axis. + + **TensorFlow-only parameters** + + - All arguments of ``tf.function``. Note: + The \ ``experimental_compile`` argument of ``tf.function`` may not + work as expected with ``smp.step``, since it interferes with + pipelining and model partitioning. To enable XLA with the library, you can + instead use \ ``tf.config.optimizer.set_jit(True)``. + + **PyTorch-only parameters** + + - ``detach_outputs`` (``bool``) : If ``True``, calls ``torch.Tensor.detach()`` on + all returned ``torch.Tensor`` outputs. Setting it to ``False`` + increases memory consumption, unless ``detach()`` is manually called + on the returned tensors, because the model graph is not cleared from + memory after the training step. Set to \ ``True`` by default. + + **Returns** + + - The same object(s) returned from the decorated function. All + returned \ ``tf.Tensor``, \ ``tf.Variable``  objects (for TF) or + ``torch.Tensor`` objects (for PT) are wrapped inside + a \ ``StepOutput`` object, even when they are inside a Python + ``list``, ``tuple``, or ``dict``. + + + +.. class:: StepOutput + :noindex: + + + A class that encapsulates all versions of a ``tf.Tensor`` + or \ ``torch.Tensor`` across all microbatches. + + When a particular ``tf.Tensor`` or ``torch.Tensor`` is computed inside + ``smp.step``, different versions of the tensor are computed for each + microbatch. + + When this tensor is returned from ``smp.step`` and is accessed outside + of the decorated function, it appears as a ``StepOutput`` object, which + contains all such versions. For example, + + - In the case of Tensorflow, the gradient for a particular + ``tf.Variable`` is computed on each microbatch individually, and if + this gradient is returned from ``smp.step``, all gradients for this + ``tf.Variable`` become part of the same ``StepOutput`` object. The + ``StepOutput`` class offers the following API for commonly-used + post-processing operations on such tensors. + - In the case of PyTorch, the loss for each microbatch is computed + individually and all the ``torch.Tensor``\ s that represent the loss + for different microbatches become part of same ``StepOutput`` object, + if loss is returned from the ``smp.step`` function. + + + The ``StepOutput`` class offers the following API for commonly-used + post-processing operations on tensors. + + .. data:: StepOutput.outputs + :noindex: + + Returns a list of the underlying tensors, indexed by microbatch. + + .. function:: StepOutput.reduce_mean( ) + :noindex: + + Returns a ``tf.Tensor``, ``torch.Tensor`` that averages the constituent ``tf.Tensor`` s + ``torch.Tensor`` s. This is commonly used for averaging loss and gradients across microbatches. + + .. function:: StepOutput.reduce_sum( ) + :noindex: + + Returns a ``tf.Tensor`` / + ``torch.Tensor`` that sums the constituent + ``tf.Tensor``\ s/\ ``torch.Tensor``\ s. + + .. function:: StepOutput.concat( ) + :noindex: + + Returns a + ``tf.Tensor``/``torch.Tensor`` that concatenates tensors along the + batch dimension using ``tf.concat`` / ``torch.cat``. + + .. function:: StepOutput.stack( ) + :noindex: + + Applies ``tf.stack`` / ``torch.stack`` + operation to the list of constituent ``tf.Tensor``\ s / + ``torch.Tensor``\ s. + + **TensorFlow-only methods** + + .. function:: StepOutput.merge( ) + :noindex: + + Returns a ``tf.Tensor`` that + concatenates the constituent ``tf.Tensor``\ s along the batch + dimension. This is commonly used for merging the model predictions + across microbatches. + + .. function:: StepOutput.accumulate(method="variable", var=None) + :noindex: + + Functionally the same as ``StepOutput.reduce_mean()``. However, it is + more memory-efficient, especially for large numbers of microbatches, + since it does not wait for all constituent \ ``tf.Tensor``\ s to be + ready to start averaging them, thereby saving memory. + + In some cases (XLA for example) ``StepOutput.reduce_mean()`` might end + up being more memory-efficient than ``StepOutput.accumulate()``. + + **Parameters** + + - ``method`` (``"add_n"`` or ``"accumulate_n"`` or ``"variable"``): + If ``"add_n"`` or ``"accumulate_n"``, the library uses + ``tf.add_n`` and ``tf.accumulate_n``, respectively, to implement + accumulation. If ``"variable"``, the library uses an internal ``tf.Variable`` + into which to accumulate the tensors. Default is \ ``"variable"``. + Note: Memory usage behavior of these choices can depend on the model + and implementation. + + - ``var``: A ``tf.Variable`` into which, if provided, the library uses to + accumulate the tensors. If \ ``None``, the library internally creates a + variable. If ``method`` is not ``"variable"``, this argument is + ignored. + +.. _mpi_basics: + :noindex: + +MPI Basics +---------- + +The library exposes the following basic MPI primitives to its Python API: + +**Global** + +- ``smp.rank()`` : The global rank of the current process. +- ``smp.size()`` : The total number of processes. +- ``smp.get_world_process_group()`` : + ``torch.distributed.ProcessGroup`` that contains all processes. +- ``smp.CommGroup.WORLD``: The communication group corresponding to all processes. +- ``smp.local_rank()``: The rank among the processes on the current instance. +- ``smp.local_size()``: The total number of processes on the current instance. +- ``smp.get_mp_group()``: The list of ranks over which the current model replica is partitioned. +- ``smp.get_dp_group()``: The list of ranks that hold different replicas of the same model partition. + +**Tensor Parallelism** + +- ``smp.tp_rank()`` : The rank of the process within its + tensor-parallelism group. +- ``smp.tp_size()`` : The size of the tensor-parallelism group. +- ``smp.get_tp_process_group()`` : Equivalent to + ``torch.distributed.ProcessGroup`` that contains the processes in the + current tensor-parallelism group. +- ``smp.CommGroup.TP_GROUP`` : The communication group corresponding to + the current tensor parallelism group. + +**Pipeline Parallelism** + +- ``smp.pp_rank()`` : The rank of the process within its + pipeline-parallelism group. +- ``smp.pp_size()`` : The size of the pipeline-parallelism group. +- ``smp.get_pp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current pipeline-parallelism group. +- ``smp.CommGroup.PP_GROUP`` : The communication group corresponding to + the current pipeline parallelism group. + +**Reduced-Data Parallelism** + +- ``smp.rdp_rank()`` : The rank of the process within its + reduced-data-parallelism group. +- ``smp.rdp_size()`` : The size of the reduced-data-parallelism group. +- ``smp.get_rdp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current reduced data parallelism + group. +- ``smp.CommGroup.RDP_GROUP`` : The communication group corresponding + to the current reduced data parallelism group. + +**Model Parallelism** + +- ``smp.mp_rank()`` : The rank of the process within its model-parallelism + group. +- ``smp.mp_size()`` : The size of the model-parallelism group. +- ``smp.get_mp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current model-parallelism group. +- ``smp.CommGroup.MP_GROUP`` : The communication group corresponding to + the current model parallelism group. + +**Data Parallelism** + +- ``smp.dp_rank()`` : The rank of the process within its data-parallelism + group. +- ``smp.dp_size()`` : The size of the data-parallelism group. +- ``smp.get_dp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current data-parallelism group. +- ``smp.CommGroup.DP_GROUP`` : The communication group corresponding to + the current data-parallelism group. + +.. _communication_api: + :noindex: + +Communication API +----------------- + +The library provides a few communication primitives which can be helpful while +developing the training script. These primitives use the following +``enum`` s as arguments to specify which processes the communication +should involve. +​ + +**Helper structures** + +.. data:: smp.CommGroup + :noindex: + + An ``enum`` that takes the values + ``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``. + These values can also be accessed as ``smp.WORLD``, ``smp.MP_GROUP``, + and ``smp.DP_GROUP`` respectively. + + - ``CommGroup.WORLD``: Represents the entire group of processes used in + training + - ``CommGroup.MP_GROUP``: Represents the group of processes that hold + the same model replica as the current process. The processes in a + single ``MP_GROUP`` collectively store an entire replica of the + model. + - ``CommGroup.DP_GROUP``: Represents the group of processes that hold + the same model partition as the current process. The processes in a + single ``DP_GROUP`` perform data parallelism/allreduce among + themselves. + +.. data:: smp.RankType + :noindex: + + An ``enum`` that takes the values + ``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``. + + - ``RankType.WORLD_RANK``: The associated rank is to be interpreted as + the rank of the process across all processes used in training. + - ``RankType.MP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``MP_GROUP``. + - ``RankType.DP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``DP_GROUP``. + + +**Communication primitives:** + +.. function:: smp.broadcast(obj, group) + :noindex: + + Sends the object to all processes in the + group. The receiving process must call ``smp.recv_from`` to receive the + sent object. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be broadcast. + + - ``group``: A ``CommGroup`` argument that represents to which group of + processes the object will be sent. + + **Notes** + + - When you use ``broadcast`` on the sender process, there needs + to be an accompanying ``smp.recv_from()`` call on the receiver + processes. + + - This is a synchronous call; the ``broadcast`` statement + returns only after all ranks participating in the call have made a + matching ``recv_from`` call. + + **Example** + + .. code:: python + + if smp.rank() == 0: +     smp.broadcast(something, group=smp.CommGroup.WORLD) + else: +     smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK) + +.. function:: smp.send(obj, dest_rank, rank_type) + :noindex: + + Sends the object ``obj`` to + ``dest_rank``, which is of a type specified by ``rank_type``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be sent. + + - ``dest_rank`` (``int``): An integer denoting the rank of the receiving process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. For example if ``dest_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then ``obj`` is sent to process + with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the current + process. + + **Notes** + + - Note: \ This is a synchronous call; the ``send`` statement returns + only after the destination rank has made a matching + ``recv_from`` call. + +.. function:: smp.recv_from(src_rank, rank_type) + :noindex: + + Receive an object from a peer process. Can be used with a matching + ``smp.send`` or a ``smp.broadcast`` call. + + **Inputs** + + - ``src_rank`` (``int``): An integer denoting rank of the sending process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. For example if ``src_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then the object is received from + the process with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the + current process. + + **Returns** + + Returns the python object that is sent by the peer process. + + **Notes** + + - Note: This is a synchronous call; the ``recv_from`` statement returns + only after the source rank has made a matching ``send`` or + ``broadcast`` call, and the object is received. + +.. function:: smp.allgather(obj, group) + :noindex: + + A collective call that gathers all the + submitted objects across all ranks in the specified ``group``. Returns a + list whose ``i``\ th index contains the object submitted by the + ``i``\ th rank in ``group``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be + allgathered. + + - ``group`` : A ``CommGroup`` argument that represents which group of + processes participate in ``allgather``. + + **Notes** + + - Note: This is a synchronous call; the ``allgather`` statement returns + only after all ranks participating in the call have made a matching + ``allgather`` call, and all the objects are received at the current + rank. + + **Examples** + + .. code:: python + + # assuming mp_size() == 2 + + if smp.mp_rank() == 0: +     out = smp.allgather(obj1, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + else: +     out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + +.. function:: smp.barrier(group=smp.WORLD) + :noindex: + + A statement that hangs until all + processes in the specified group reach the barrier statement, similar to + ``MPI_Barrier()``. + + **Inputs** + + - ``group``: An ``smp.CommGroup`` ``enum`` that specifies the group of + processes participating in the barrier call. Defaults to + ``smp.WORLD``. + + **Examples** + + - Assume there are 8 processes and 2 model partitions, and + therefore 4 \ ``mp_group``\ s, and 2 ``dp_group``\ s. If + the \ ``barrier`` call is passed the value ``smp.MP_GROUP`` for its + group argument, then each process only waits until the other process + of its own ``mp_group`` reaches that point. It does not wait for + processes outside that ``mp_group``. + +.. function:: smp.dp_barrier() + :noindex: + + Same as passing ``smp.DP_GROUP``\ to ``smp.barrier()``. + Waits for the processes in the same \ ``dp_group`` as + the current process to reach the same point in execution. + +.. function:: smp.mp_barrier() + :noindex: + + Same as passing ``smp.MP_GROUP`` to + ``smp.barrier()``. Waits for the processes in the same ``mp_group`` as + the current process to reach the same point in execution. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst new file mode 100644 index 0000000000..055f2b6dde --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst @@ -0,0 +1,678 @@ +PyTorch API +=========== + +To use the PyTorch-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. tip:: + + Refer to + `Modify a PyTorch Training Script + `_ + to learn how to use the following API in your PyTorch training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of ``torch.nn.Module`` which specifies the model to be + partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is + the model to be partitioned. The returned ``DistributedModel`` object + internally manages model parallelism and data parallelism. Only one + model in the training script can be wrapped with + ``smp.DistributedModel``. + + **Example:** + + .. code:: python + + model = smp.DistributedModel(model) + + **Important**: The ``__call__`` and  ``backward`` method calls on the + ``smp.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smp.step``-decorated + function. + + Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can + be performed by calling the \ ``DistributedModel`` object on the input + tensors. + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + For a backward pass, one needs to call the backward function on + the \ ``DistributedModel`` object, with tensors and gradients as + arguments, replacing the PyTorch operations \ ``torch.Tensor.backward`` + or ``torch.autograd.backward``. + + The API for ``model.backward`` is very similar to + ``torch.autograd.backward``. For example, the following + ``backward`` calls: + + .. code:: python + + torch.autograd.backward(loss) or loss.backward() + + should be replaced with: + + .. code:: python + + model.backward(loss) # loss is a tensor with only one element as its data + + Similarly, for non-scalar tensors, replace the following + ``backward`` call containing incoming gradient arguments: + + .. code:: python + + torch.autograd.backward(outputs, out_grads) + + with the following line: + + .. code:: python + + model.backward(outputs, out_grads) + + In these examples, all ``__call__``  and ``backward`` method calls on + the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside + a ``smp.step``-decorated function. + + **Using DDP** + + If DDP is enabled with the SageMaker model parallel library, do not not place a PyTorch + ``DistributedDataParallel`` wrapper around the ``DistributedModel`` because + the ``DistributedModel`` wrapper will also handle data parallelism. + + Unlike the original DDP wrapper, when you use ``DistributedModel``, + model parameters and buffers are not immediately broadcast across + processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the + ``smp.step``-decorated function when the partition is done. + + **Parameters** + + - ``module`` (``torch.nn.Module``): Module to be distributed (data parallelism and model parallelism). + + - ``trace_device`` (``"cpu"`` or ``"gpu"``) (default: ``"gpu"``) + Whether to perform the tracing step on the GPU or CPU. The tracing step gathers + information on the order of execution of modules, the shapes of + intermediate outputs, and execution times, to be used by the + partitioning algorithm. If ``trace_device`` is set to GPU, accurate + module execution times can be gathered during tracing for potentially + improved partitioning decision. However, if the model is too large to + fit in a single GPU, then ``trace_device`` should be set to CPU. + + - ``trace_execution_times`` (``bool``) (default: ``False``): If ``True``, + the library profiles the execution time of each module during tracing, and uses + it in the partitioning decision. This improves the partitioning + decision, but it might make the tracing slower. It may also introduce + some degree of non-determinism in partitioning results, because of the + inherent randomness in module execution times. Must be ``False`` if + ``trace_device`` is ``"cpu"``. + + - ``overlapping_allreduce`` (``bool``) (default: ``True``): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` while launching training). The library uses this flag + to decide whether to do overlapping allreduce whenever a parameter + gradients are ready. This leads to overlapping of communication and + computation and can improve performance. If this is set to ``False`` , + allreduce is performed at the end of the step. + + - ``backward_passes_per_step`` (``int``) (default: 1): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` in config). This parameter indicates the + number of backward passes to perform before calling allreduce on DDP. + This allows accumulating updates over multiple mini-batches before + reducing and applying them. + + - ``average_grads_across_microbatches`` (``bool``) (default: ``True``): + Whether or not the computed gradients should be averaged across + microbatches. If ``False``, the computed gradients will be summed across + microbatches, but not divided by the number of microbatches. In typical + use case where the computed loss is averaged over the mini-batch, this + should be left as ``True``. If you use a loss function that only sums + the per-sample loss across the batch (and not divide by the batch size), + then this must be set to ``False`` for correctness. + + - ``bucket_cap_mb`` (default: 25): \ ``DistributedDataParallel`` buckets + parameters into multiple buckets so that gradient reduction of each + bucket can potentially overlap with backward + computation. \ ``bucket_cap_mb``\ controls the bucket size in MegaBytes + (MB). + + - ``trace_memory_usage`` (default: False): When set to True, the library attempts + to measure memory usage per module during tracing. If this is disabled, + memory usage will be estimated through the sizes of tensors returned from + the module. + + - ``broadcast_buffers`` (default: True): Flag to be used with ``ddp=True``. + This parameter is forwarded to the underlying ``DistributedDataParallel`` wrapper. + Please see: `broadcast_buffer `__. + + - ``gradient_as_bucket_view`` (default: False): To be + used with ``ddp=True``. This parameter is forwarded to the underlying + ``DistributedDataParallel`` wrapper. Please see `gradient_as_bucket_view `__. + + **Properties** + + - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` + otherwise. Initialized to ``False`` when ``DistributedModel`` is first + created. It becomes be ``True`` during the first call + to ``smp.step``-decorated function. Once the model is partitioned, the + local parameters or local ``state_dict`` can be fetched using the + following methods. + + **Methods** + + .. function:: backward(tensors, grad_tensors) + :noindex: + + Triggers a distributed backward + pass across model partitions. Example usage provided in the previous + section. The API is very similar + to https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward. + ``retain_grad`` and ``create_graph``  flags are not supported. + + .. function:: local_buffers( ) + :noindex: + + Returns an iterator over buffers for the modules in + the partitioned model that have been assigned to the current process. + + .. function:: local_named_buffers( ) + :noindex: + + Returns an iterator over buffers for the + modules in the partitioned model that have been assigned to the current + process. This yields both the name of the buffer as well as the buffer + itself. + + .. function:: local_parameters( ) + :noindex: + + Returns an iterator over parameters for the + modules in the partitioned model that have been assigned to the current + process. + + .. function:: local_named_parameters( ) + :noindex: + + Returns an iterator over parameters for + the modules in the partitioned model that have been assigned to the + current process. This yields both the name of the parameter as well as + the parameter itself. + + .. function:: local_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. + + .. function:: local_named_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. This + yields both the name of the module as well as the module itself. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains local + parameters that belong to the current \ ``mp_rank``. This ``state_dict`` + contains a key \ ``_smp_is_partial`` to indicate this is a + partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains parameters + for the entire model. It first collects the \ ``local_state_dict``  and + gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to + create a full ``state_dict``. Please note that this needs to be called on all ranks with + ``dp_rank()==0`` to ensure the gather happens properly. + If it is only called on all such ranks, it can hang. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.module.load_state_dict()`` , + except: It first gathers and merges the ``state_dict``\ s across + ``mp_rank``\ s, if they are partial. The actual loading happens after the + model partition so that each rank knows its local parameters. + + .. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. Returns a ``RemovableHandle`` object ``handle``, + which can be used to remove the hook by calling ``handle.remove()``. + + .. function:: cpu( ) + :noindex: + + Allgathers parameters and buffers across all ``mp_rank``\ s and moves them + to the CPU. + + .. function:: join( ) + :noindex: + + A context manager to be used in conjunction with an instance of + ``smp.DistributedModel`` to be able to train with uneven inputs across + participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped + ``DistributedDataParallel`` instance. For more information, see: + `join `__ + in the PyTorch documentation. + + .. function:: register_comm_hook( state, callable ) + :noindex: + + **Available for PyTorch 1.8.1 only** + Registers a communication hook which is an enhancement that provides + a flexible hook ``callable`` to users where they can specify how + gradients are aggregated across multiple workers. This method will be called on the wrapped ``DistributedDataParallel`` instance. + + Please note that when you register a comm hook you have full control of how the gradients are processed. + When using only data parallelism with Torch DDP you are expected to average grads across data parallel replicas within the hook. + Similarly, when using DistributedModel you have to averaging grads across data parallel replicas within the hook. + In addition to that, you also have to average grads across microbatches within the hook unless you explicitly desire to not average based on your loss function. + See ``average_grads_across_microbatches`` for more information about averaging grads across microbatches. + + This is only supported when ``ddp=True`` and ``overlapping_allreduce=True`` (default). + For more information, see: + `register_comm_hook `__ + in the PyTorch documentation. + + **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism** + + When a model is wrapped by ``smp.DistributedModel``, the library + immediately traverses the modules of the model object, and replaces the + modules that are supported for tensor parallelism with their distributed + counterparts. This replacement happens in place. If there are no other + references to the original modules in the script, they are + garbage-collected. The module attributes that previously referred to the + original submodules now refer to the distributed versions of those + submodules. + + **Example:** + + .. code:: python + + # register DistributedSubmodule as the distributed version of Submodule + # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist) + smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule) + + class MyModule(nn.Module): + def __init__(self): + ... + + self.submodule = Submodule() + ... + + # enabling tensor parallelism for the entire model + with smp.tensor_parallelism(): + model = MyModule() + + # here model.submodule is still a Submodule object + assert isinstance(model.submodule, Submodule) + + model = smp.DistributedModel(model) + + # now model.submodule is replaced with an equivalent instance + # of smp.nn.DistributedSubmodule + assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule) + + If ``pipeline_parallel_degree`` (equivalently, ``partitions``) is 1, the + placement of model partitions into GPUs and the initial broadcast of + model parameters and buffers across data-parallel ranks take place + immediately. This is because it does not need to wait for the model + partition when ``smp.DistributedModel`` wrapper is called. For other + cases with ``pipeline_parallel_degree`` greater than 1, the broadcast + and device placement will be deferred until the first call of an + ``smp.step``-decorated function happens. This is because the first + ``smp.step``-decorated function call is when the model partitioning + happens if pipeline parallelism is enabled. + + Because of the module replacement during the ``smp.DistributedModel`` + call, any ``load_state_dict`` calls on the model, as well as any direct + access to model parameters, such as during the optimizer creation, + should be done **after** the ``smp.DistributedModel`` call. + + Since the broadcast of the model parameters and buffers happens + immediately during ``smp.DistributedModel`` call when the degree of + pipeline parallelism is 1, using ``@smp.step`` decorators is not + required when tensor parallelism is used by itself (without pipeline + parallelism). + + For more information about the library's tensor parallelism APIs for PyTorch, + see :ref:`smdmp-pytorch-tensor-parallel`. + + **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + + The following are the new methods of ``smp.DistributedModel``, in + addition to the ones listed in the + `documentation `__. + + .. function:: distributed_modules() + :noindex: + + - An iterator that runs over the set of distributed + (tensor-parallelized) modules in the model + + .. function:: is_distributed_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is distributed over + tensor-parallel ranks. + + .. function:: is_distributed_buffer(buf) + :noindex: + + - Returns ``True`` if the given buffer is distributed over + tensor-parallel ranks. + + .. function:: is_scaled_batch_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is operates on the + scaled batch (batch over the entire ``TP_GROUP``, and not only the + local batch). + + .. function:: is_scaled_batch_buffer(buf) + :noindex: + + - Returns ``True`` if the parameter corresponding to the given + buffer operates on the scaled batch (batch over the entire + ``TP_GROUP``, and not only the local batch). + + .. function:: default_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``DP_GROUP``. + + .. function:: scaled_batch_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``RDP_GROUP``. + + + +.. class:: smp.DistributedOptimizer + :noindex: + + **Parameters** + - ``optimizer`` + + An optimizer wrapper for saving/loading optimizer states. This wrapper + returns ``optimizer`` with the following methods overridden: + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains optimizer state for the entire model. + It first collects the ``local_state_dict`` and gathers and merges + the ``local_state_dict`` from all ``mp_rank``s to create a full + ``state_dict``. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.optimizer.load_state_dict()`` , except: + + - It first gathers and merges the local ``state_dict``\ s if they are + partial. + - The actual loading happens after the model partition so that each + rank knows its local parameters. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains the + local optimizer state that belongs to the current \ ``mp_rank``. This + ``state_dict`` contains a key \ ``_smp_is_partial`` to indicate this is + a partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + ​ +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (int) - The index of the partition. + + A context manager which places all modules defined inside into the + partition with ID ``index``.  The ``index`` argument must be less than + the number of partitions. + + Use ``smp.partition`` to implement manual partitioning. + If ``"auto_partition"`` is ``True``, then the + ``smp.partition`` contexts are ignored. Any module that is not placed in + any ``smp.partition`` context is placed in the + ``default_partition`` defined through the SageMaker Python SDK. + + When ``smp.partition`` contexts are nested, the innermost context + overrides the rest (see the following example). In PyTorch, manual + partitioning should be done inside the module \ ``__init__``, and the + partition assignment applies to the modules that are *created* inside + the ``smp.partition`` context. + + Example: + + .. code:: python + + class Model(torch.nn.Module): +     def __init__(self): +         with smp.partition(1): +             self.child0 = Child0()            # child0 on partition 1 +             with smp.partition(2): +                 self.child1 = Child1()        # child1 on partition 2 +             self.child2 = Child2()            # child2 on partition 1 +         self.child3 = Child3()                # child3 on default_partition + +.. function:: smp.get_world_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all + processes, which can be used with the ``torch.distributed`` API. + Requires ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_mp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``MP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_dp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``DP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.is_initialized( ) + :noindex: + + Returns ``True`` if ``smp.init`` has already been called for the + process, and ``False`` otherwise. + +.. function::smp.is_tracing( ) + :noindex: + + Returns ``True`` if the current process is running the tracing step, and + ``False`` otherwise. + +.. data:: smp.nn.FusedLayerNorm + :noindex: + + `Apex Fused Layer Norm `__ is currently not + supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + ``FusedLayerNorm`` and provides the same functionality. This requires + ``apex`` to be installed on the system. + +.. data:: smp.optimizers.FusedNovoGrad + :noindex: + + + `Fused Novo Grad optimizer `__ is + currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + optimizer and provides the same functionality. This requires ``apex`` to + be installed on the system. + +.. data:: smp.optimizers.FusedLamb + :noindex: + + + `FusedLamb optimizer `__ + currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + ``apex`` ``FusedLamb`` optimizer and provides the same functionality. + This requires ``apex`` to be installed on the system. + +.. data:: smp.amp.GradScaler + :noindex: + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. _pytorch_saving_loading: + :noindex: + +APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. function:: smp.save( ) + :noindex: + + Saves an object. This operation is similar to ``torch.save()``, except + it has an additional keyword argument, ``partial``, and accepts only + string type for the argument ``f`` (file). If ``partial=True``, each + ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` + index to your saved file. + + **Parameters** + + - ``obj`` (dict): A saved object. + - ``f`` (str): A string containing a file name. + - ``partial`` (bool, default= ``True``):  When set to ``True``, each + ``mp_rank`` saves a separate checkpoint file and the library adds an + ``mp_rank`` index to the saved file. If you want to be able to load + and further train a model that you save with ``smp.save()``, you must + set ``partial=True``. + - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``): + A module used for pickling metadata and objects. + - ``pickle_protocol``  (int, default=2): Can be specified to + override the defaultprotocol. + +.. function:: smp.load( ) + :noindex: + + Loads an object saved with ``smp.save()`` from a file. + + Similar to, `torch.load() `__, + except it has an additional keyword argument, ``partial``, and accepts + only string type for the argument ``f`` (file). If \ ``partial=True``, + then each ``mp_rank`` loads a separate checkpoint file. + + **Parameters** + + - ``f`` (string): A string containing a file name. + - ``map_location`` (function): A function + `torch.device `__, + a string, or a dict specifying how to remap storage locations. + - ``pickle_module`` (pickle module): A module used for unpickling + metadata and objects (has to match the \ ``pickle_module``\ used to + serialize file). + - ``pickle_load_args`` (Python 3 only): Optional keyword arguments + passed to ``pickle_module.load()`` and ``pickle_module.Unpickler()``. + - ``partial`` (bool, default= ``True``): When set to ``True``, each + ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``. + Should be used when loading a model trained with the library. + +.. _pytorch_saving_loading_instructions: + :noindex: + +General Instruction For Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The library can save partial or full checkpoints. + +- For partial checkpoints, each ``mp_rank`` saves its own checkpoint + file with only the parameters that belong to that rank. +- For full checkpoints, the library saves a single checkpoint that contains + entire model parameters. + +When **saving** using ``smp.save()``, each rank only holds its own +parameters. If you want to save the full model, there will be some +communication between the ranks to create the full model. If you save +checkpoints often, you should save partial checkpoints for best +performance. + +When **loading** using ``smp.load()``, the library can load either partial or | +full checkpoints or full checkpoints saved by a non-model-parallel model. If you +want to resume training with a non-model-parallel model or do inference, you need +a full checkpoint. + +The following is an example of how you can save and load a checkpoint: + +.. code:: python + + # Original model and optimizer + model = MyModel(...) + optimizer = MyOpt(...) + + # model parallel wrapper + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + + # To save, always save on dp_rank 0 to avoid data racing + if partial: +     # To save the partial model on each mp rank +     # the library will create `checkpoint.pt_{mprank}` for each mp rank +     if save_partial_model: +         if smp.dp_rank() == 0: +             model_dict = model.local_state_dict() # save the partial model +             opt_dict = optimizer.local_state_dict() # save the partial optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 f"/checkpoint.pt", +                 partial=True, +             ) + +     # To save the full model +     if save_full_model: +         if smp.dp_rank() == 0: +             model_dict = model.state_dict() # save the full model +             opt_dict = optimizer.state_dict() # save the full optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 "/checkpoint.pt", +                 partial=False, +             ) + + # To load, load on all ranks. + # The only difference for partial/full loading is the partial flag in smp.load + # Load partial checkpoint + if partial_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=True) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + # Load full checkpoint + if full_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=False) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst new file mode 100644 index 0000000000..851408b4b8 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst @@ -0,0 +1,875 @@ +.. _smdmp-pytorch-tensor-parallel: + :noindex: + +PyTorch API for Tensor Parallelism +================================== + +SageMaker distributed tensor parallelism works by replacing specific submodules +in the model with their distributed implementations. The distributed modules +have their parameters and optimizer states partitioned across tensor-parallel +ranks. This is to compute the same output as it would have been computed by +the original modules. Since tensor parallelism occurs across data-parallel +ranks, a rank might collect slices of the activations corresponding to the +data shards on other devices that are part of the same tensor parallelism group. + +You can enable or disable tensor parallelism for specific parts of the model. +Within the enabled parts, the replacements with distributed modules will take +place on a best-effort basis for those module supported for tensor parallelism. +Alternatively, you can directly import and use the library’s distributed +modules in the model definition. + +Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +blocks that contain many operations. Because custom implementations +(as opposed to the built-in PyTorch modules) are typically used for these +high-level blocks, the library offers an API that you can use to register +specific distributed versions with such custom modules (provided that they +are functionally equivalent). This allows the library to automatically replace +the occurrences of such PyTorch modules with their distributed counterparts +provided by the library. +For more information, see the following topics. + +.. contents:: Topics + :depth: 3 + :local: + +.. _registering-tp-modules: + :noindex: + +Registering Tensor Parallelism Distributed Modules +-------------------------------------------------- + +Although PyTorch natively provides some of the commonly used (and +tensor-parallelizable) building blocks such as Transformer, users often +use custom implementations for such higher-level modules. To distribute +such modules with tensor parallelism, you need to register the +distributed modules to the custom module implementation in your class, +so that the library knows how to distribute the custom module. When you +register the distributed modules, make sure the custom module that you +use is functionally equivalent to the distributed module. You can verify +this by taking a look at the equivalent reference implementations in the +:ref:`smdmp-tp-appendix`. +These implementations are functionally equivalent to their distributed +versions in ``smp.nn`` module. + +.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) + + - A class decorator that registers the ``dist_module`` class with + the module class that it is attached to. The hooks can be used to + adapt to different interfaces used with ``__init__`` and + ``forward`` methods. + - **Arguments:** + + - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + that implements the distributed version of the module class the + decorator is attached to. Any distributed module class defined + in ``smp.nn`` module can be used. + - ``init_hook``: A callable that translates the arguments of the + original module ``__init__`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``__init__`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``__init__`` method (including argument order and default + values), except it must exclude ``self``. + - ``forward_hook``: A callable that translates the arguments of + the original module ``forward`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``forward`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``forward`` method (including argument order and default + values), except it must exclude ``self``. + - ``return_hook``: A callable that translates the object returned + from the distributed module to the return object expected of + the original module. + + - **Example:** + + .. code:: python + + init_hook = lambda config: ((), config.to_dict()) + + # register smp.nn.DistributedTransformer + # as the distributed version of MyTransformer + @smp.tp_register(smp.nn.DistributedTransformer, init_hook=init_hook) + class MyTransformer(nn.Module): + def __init__(self, config): + ... + + def forward(self, hidden_states, attention_mask): + ... + +.. function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) + :noindex: + + - When you do not have direct access to model definition code, you + can use this API to similarly register a distributed module with + an existing module class. + + - **Arguments:** + + - ``module_cls``: The existing module class that will be + distributed. + - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + that implements the distributed version of the module class the + decorator is attached to. Any distributed module class defined + in ``smp.nn`` module can be used. + - ``init_hook``: A callable that translates the arguments of the + original module ``__init__`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``__init__`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``__init__`` method (including argument order and default + values), except it must exclude ``self``. + - ``forward_hook``: A callable that translates the arguments of + the original module ``forward`` method to an ``(args, kwargs)`` + tuple compatible with the arguments of the corresponding + distributed module ``forward`` method. Must return a tuple, + whose first element is an iterable representing the positional + arguments, and second element is a ``dict`` representing the + keyword arguments. The input signature of the ``init_hook`` + must **exactly** match the signature of the original + ``forward`` method (including argument order and default + values), except it must exclude ``self``. + - ``return_hook``: A callable that translates the object returned + from the distributed module to the return object expected of + the original module. + + - **Example:** + + .. code:: python + + from somelibrary import MyTransformer + + init_hook = lambda config: ((), config.to_dict()) + + # register smp.nn.DistributedTransformer as the distributed version of MyTransformer + smp.tp_register_with_module(MyTransformer, + smp.nn.DistributedTransformer, + init_hook=init_hook) + +.. _smdmp-supported-modules-for-tp: + :noindex: + +Supported Modules for Tensor Parallelism +---------------------------------------- + +The following modules are supported for tensor +parallelism. + +- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``) +- ``smp.nn.DistributedTransformerLMHead`` +- ``smp.nn.DistributedTransformer`` +- ``smp.nn.DistributedTransformerLayer`` +- ``smp.nn.DistributedAttentionLayer`` +- ``smp.nn.DistributedTransformerOutputLayer`` +- ``smp.nn.DistributedEmbedding`` + +.. contents:: Topics + :depth: 3 + :local: + +.. _tp-module-api: + :noindex: + +Tensor Parallelism Module APIs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. class:: smp.nn.DistributedLinear(in_features, out_features) + :noindex: + + - Tensor-parallel implementation of the ``nn.Linear`` class. + Functionally equivalent to an ``nn.Linear`` module with the same + ``in_features`` and ``out_features``. In other words, + ``in_features`` and ``out_features`` are the number of *global* + channels across tensor-parallel ranks. + - **Arguments:** + + - ``in_features``: The total number of input channels for the + linear layer across all tensor-parallel ranks. + - ``out_features``: The total number of output channels for the + linear layer across all tensor-parallel ranks. + +.. class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** + + - ``forward(self, inputs)`` + + - If ``add_cross_attention`` is ``True``, ``inputs`` must be a + tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, cross_states, cross_states, cross_mask, labels)``. + - Otherwise, ``inputs`` must be a tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, labels)``. + - If ``token_type_ids`` is ``None``, token type embedding will + not be used. + - ``input_ids`` is assumed to be of shape ``[N, S]``, where + ``N`` is the batch size and ``S`` is sequence length. + - ``attention_mask`` is assumed to be a 0-1 tensor of shape + ``[N, S]``, where 1 represents a masked position. + +.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + ``smp.nn.DistributedTransformerLayer``. + - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. + +.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. + - **Arguments:** + + - ``num_attention_heads``: The total number of attention heads + across tensor-parallel ranks + - ``attention_head_size``: The number of channels of a single + attention head. + - ``hidden_size``: The hidden dimension of the transformer. The + input tensor ``hidden_states`` is assumed to have its last + dimension size equal to ``hidden_size``. + - ``intermediate_size``: The number of output channels in the + first linear transformation of the transformer output layer. + ``DistributedTransformerOutputLayer`` first maps + ``hidden_size`` dimensions of its input tensor into + ``intermediate_size`` dimensions, and then maps it back into + ``hidden_size`` dimensions. + - ``attention_dropout_prob``: The dropout probability applied to + the attention probabilities. + - ``hidden_dropout_prob``: The dropout probability used in + dropout layers other than the one applied to the attention + probabilities. + - ``activation``: Choice of activation function to use at the + output layer. Must be ``"gelu"`` or ``"relu"``. + - ``layernorm_epsilon``: The epsilon added to the denominator of + layer normalization for numerical stability. + - ``initializer_range``: If ``use_normal_initialization`` is + ``True``, the standard deviation of the normal random variable + to initialize the weights with. + - ``use_normal_initialization``: If ``True``, the weights are + initialized with normal distribution with standard deviation + given by ``initializer_range``. Otherwise, default PyTorch + initialization is used. + - ``causal_mask_size``: If ``None``, no causal mask is used on + attentions. Otherwise, should be set to maximum sequence length + to apply a causal mask to the attention scores. This is used, + for instance, in GPT-2. + - ``add_cross_attention``: If ``True``, a cross-attention layer + will be added after the self-attention block. The + cross-attention layer computes the attention keys and values + based on the ``cross_states`` input (instead of + ``hidden_states`` input, as in self-attention. This is used in + the decoder block of encoder-decoder architectures. For + encoder-only architectures that only use self-attention, this + should be kept ``False``. + - ``pre_layernorm``: If ``True``, inserts layer normalization at + the input. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + - ``post_layernorm``: If ``True``, inserts layer normalization at + the output. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the transformer + layer. + + - **Arguments:** + + - If ``add_cross_attention=False``, ``inputs`` must be a + tuple ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the batch + size, and ``S`` is the sequence length. + - If ``add_cross_attention=True``, ``inputs`` must be a + tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is assumed to be a tensor of + dimensions ``[N, S_1, H]``, where ``N`` is batch size, + ``S_1`` is sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_1]``, where ``N`` is the batch + size, and ``S_1`` is the sequence length, and + ``cross_mask`` is assumed to be a tensor of size + ``[N, 1, 1, S_2]``. Keys and values for the attention + heads in the cross-attention layer (but not the + self-attention layer) are computed using + ``cross_states``, and ``cross_mask`` is applied as the + attention mask in the cross-attention layer (but not the + self-attention layer). + + - **Returns:** + + - If ``add_cross_attention=False``, a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is the output of the transformer, and + ``attention_mask`` is the same the ``attention_mask`` + argument. + - If ``add_cross_attention=True``, a tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is the output of the transformer, + and the next three tensors are the same as the input + arguments. + +.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``cross_attention``: If ``True``, it computes the attentions + with respect to the ``cross_states`` tensor of the ``forward`` + method input tuple. (Default: ``False``) + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the attention + layer. + + - **Arguments:** + + - If ``cross_attention=False``, ``inputs`` must be a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the + batch size, and ``S`` is the sequence length. + - If ``cross_attention=True``, ``inputs`` must be a tuple + ``(hidden_states, cross_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S_1, H]``, where ``N`` is batch size, ``S_1`` is + sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_2]``, where ``N`` is the batch + size, and ``S_2`` is the sequence length. Keys and values + for the attention heads are computed using + ``cross_states``. + + - **Returns:** + + - A single tensor that is the output of the attention + layer. + +.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) + :noindex: + + - Distributed implementation of a single transformer output layer. A + single :class:`smp.nn.DistributedTransformerLayer` with + ``add_cross_attention=False`` consists of a single + ``DistributedAttentionLayer`` immediately followed by a single + ``DistributedTransformerOutputLayer``. The latter linearly maps + the last channel of the input tensor from ``hidden_size`` to + ``intermediate_size``, and then maps it back to ``hidden_size``. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow + (NaN loss values) for large models with more than 100 billion parameters + when using FP16. (Default: False) + +.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) + :noindex: + + - Distributed implementation of a single Embedding Layer. Currently + only supports splitting across the embedding_dim. + - **Arguments:** + + - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + arguments. + +.. _enabling-tp: + :noindex: + +Enabling Tensor Parallelism +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are two ways tensor parallelism can be enabled. + +First, you can use +the distributed module implementations in ``smp.nn`` module directly in +your model definition. See :ref:`smdmp-supported-modules-for-tp` +for a complete list of built-in distributed modules. Here is an example +of how this can be done: + +.. code:: python + + import torch.nn as nn + import smdistributed.modelparallel.torch as smp + + class TransformerModel: + def __init__(self): + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # directly instantiate smp.nn.DistributedTransformer and use it + self.encoder = smp.nn.DistributedTransformer(num_layers, hidden_size, **kwargs) + + self.pooler = nn.Linear(hidden_size, hidden_size) + + def forward(self, hidden_states): + emb_out = self.embedding(hidden_states) + enc_out = self.encoder(emb_out) + return self.pooler(enc_out) + +Second, you can enable tensor parallelism for specific modules or blocks +of code, which will automatically enable tensor parallelism for the +supported modules within that scope. To do this, you can use the +following API: + +.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) + + - A context manager that enables or disables tensor parallelism for + any supported module that is created inside. If there are nested + contexts, the innermost overrides the rest. If there are + multiple supported modules created within the context, where one + is the submodule of the other, only the outermost module will be + distributed. If a supported module shares weights with another + (supported or unsupported) module, or if its hyperparameters do + not support distribution (e.g., not divisible by the tensor + parallelism degree), tensor parallelism will **not** be enabled + for this module even if this API is used. + + **Example:** + + .. code:: python + + with smp.tensor_parallelism(): + self.m0 = nn.Linear(20, 20) # will be distributed + with smp.tensor_parallelism(enabled=False): + self.m1 = nn.Linear(20, 20) # will not be distributed + + - ``kwargs`` - Keyword arguments that can be used to modify the configurations of + the distributed modules created inside the context. + If a keyword argument provided through it matches any ``__init__`` method arguments + of a ``DistributedModule`` that substitutes a module created inside + the ``smp.tensor_parallelism`` context, this keyword will override + the value defined in the ``init_hook``. + + - (*For v1.7.0 and later*) Through the following additional keyword arguments, + the library supports `NVIDIA Megatron’s fused kernels + `_ + + - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. + By default, it is set to ``True``. You can deactivate it by setting + ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. + By default, it is set to ``False``. You can activate it by setting + ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + + + +.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) + :noindex: + + - Enables or disables tensor parallelism for the supported + submodules of ``module``. If enabling, the outermost supported + modules will be distributed. If disabling, tensor parallelism will + be disabled for the entire module subtree of ``module``. Unlike + the context manager, this API can be used after the model creation + (but before wrapping with :class:`smp.DistributedModel`), so direct + access to model definition code is not required. If a supported + module shares weights with another (supported or unsupported) + module, or if its hyperparameters do not support distribution + (e.g., not divisible by the tensor parallelism degree), tensor + parallelism will **not** be enabled for this module. + - Keyword arguments ``kwargs`` can be used to modify the + configurations of the distributed modules created inside the + context. If a keyword argument provided here matches any + ``__init__`` method arguments of a :class:`smp.DistributedModel` that + substitutes a module created inside the ``smp.tensor_parallelism`` + context, this keyword will override the value defined in the + ``init_hook``. + - **Example:** + + .. code:: python + + model = MyModel() + smp.set_tensor_parallelism(model.encoder, True) + smp.set_tensor_parallelism(model.encoder.embedding, True) + + # outermost supported submodules in model.encoder will be distributed, except for + # model.encoder.embedding + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + +.. _activation-checkpointing-api: + :noindex: + +Activation Checkpointing APIs +----------------------------- + +``smdistributed.modelparallel`` provides three APIs to enable +activation checkpointing: one for checkpointing modules, +one for checkpointing sequential modules, and +one for checkpointing pretrained models. + +For a conceptual guide and examples, see +`Activation Checkpointing `_ +in the *SageMaker's Distributed Model Parallel developer guide*. + +.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint(module, *args, preserve_rng_state=True) + :noindex: + + - Checkpoints the module passed. Throws error if, during manual + partitioning, all children of module are not on same rank as the + module itself, i.e. the module tree is split across multiple + partitions. During auto-partitioning, if the module is split + across multiple partitions, then this call is ignored(with a + warning). Note that this call applies to the module instance only, + not to the module class. + + - **Arguments:** + + - ``module (Instance of nn.Module)``: The module to be + checkpointed. Note that unlike native checkpointing in + PyTorch’s, activation checkpointing in + ``smdistributed.modelparallel`` is at the granularity of a + module. A generic function cannot be passed here. + - ``args``: Tuple containing inputs to the module. + - ``preserve_rng_state (bool, default=True)``: Omit stashing and + restoring the RNG state during each checkpoint. + +.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint_sequential(sequential_module, input, strategy="each", preserve_rng_state=True, pack_args_as_tuple=False) + :noindex: + + - Checkpoints the modules inside + `nn.Sequential `__. + This can be used even if different layers that are part of the + sequential container lie on different partitions. Each layer part + of the sequential module that is checkpointed must lie completely + within one partition. If this is not the case during manual + partitioning, then an error will be thrown. If this is not the + case during auto partitioning, a warning will be raised and this + module will be run without checkpointing. + + - **Arguments** + + - ``sequential_module (nn.Sequential)``: the sequential module to + be checkpointed. + - ``input (torch.Tensor or a tuple of torch.Tensors)``: input to + the module, which can be a tensor or a tuple of tensors. If a + tuple is passed, then pack_args_as_tuple should be set to True. + - ``strategy (string, default=“each”)`` : Strategy determines how + many layers part of the sequential module need to be grouped + together for one checkpointing call. This determines how much + memory can be reduced. It can take the following values + + - ``each`` : The default is to checkpoint each module inside + the sequential separately. + - ``contiguous``: Groups consecutive layers on the same + partition together. For example, if a sequential consists of + [a, b, c, d] where a,b are on pp_rank0 and c,d are on + pp_rank 1, then this strategy would checkpoint a,b together + and then c,d together. This means effectively, inputs of a, + outputs of b, inputs of c, and outputs of d are in memory; + the reamining activations are recomputed. + - ``group_2, group_3, group_4, etc:`` More generally, + ``group_x`` where x is an integer. This strategy provides + more flexibility in how many layers to group together. + ``group_x`` groups x layers together on a best effort basis. + It can group x layers together if there are x layers + consecutively on the same partition. For example: + [a,b,c,d,e] where a,b are on pp_rank0 and c,d,e are on + pp_rank 1. If the strategy is ``group_3,`` then a,b are + checkpointed together on pp_rank0 and c,d,e are checkpointed + together on pp_rank1. + + - ``preserve_rng_state (bool, default=True)``: Set to ``False`` + to omit stashing and restoring the RNG state during each + checkpoint. + - ``pack_args_as_tuple (bool, default=False)``: To ensure that + backward works correctly, the autograd function has to unpack + any tuples received. If the checkpointed layer takes a tuple as + input, then this needs to be set to True. + +.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") + :noindex: + + - This API is recommended when importing pretrained models from + libraries, such as PyTorch and Hugging Face Transformers. This is + particularly useful when you don’t have access to the model + definition code and not be able to replace a module call with + checkpoint. + + - **Arguments**: + + - ``module (Instance of nn.Module or nn.Sequential)``: The module + to checkpoint. + - ``preserve_rng_state (bool, default=True)``: Set to ``False`` + to omit stashing and restoring the RNG state during each + checkpoint. + - ``pack_args_as_tuple (bool, default=False)``: *Can only be + passed when module is a sequential module.* To ensure that + backward works correctly, the autograd function has to unpack + any tuples received. If the layer checkpointed takes a tuple as + input, then this needs to be set to True. + - ``strategy: (string, default=“each”)``: *Can only be passed + when module is a sequential module.* Strategy determines how + many layers part of the sequential module need to be grouped + together for one checkpointing call. + - This determines how much memory can be reduced. It can take the + following values + + - ``each`` : The default is to checkpoint each module inside + the sequential separately. + - ``contiguous``: Groups consecutive layers on the same + partition together. For example if a sequential consists of + ``[a, b, c, d]`` where ``a, b`` are on ``pp_rank0`` and ``c, d`` are on + ``pp_rank 1``, then this strategy would checkpoint a,b together + and then ``c, d`` together. This means effectively, the inputs of + ``a``, outputs of ``b``, inputs of ``c``, and outputs of ``d`` are in + memory, and the rest of the activations are recomputed. + - ``group_2, group_3, group_4, etc:`` More generally, + ``group_x`` where x is an integer. This strategy provides + more flexibility in how many layers to group together. + ``group_x`` groups x number of layers together on a best + effort basis if there are x layers consecutively in the same + partition. **Example**: Assume a module with layers ``[a, b, + c, d, e]``. The layers a and b are on pp_rank0, and ``c``, ``d``, and + ``e`` are on ``pp_rank 1``. If the strategy is ``group_3,`` then ``a``, + ``b`` are checkpointed together on ``pp_rank0``, and ``c``, ``d``, ``e`` are + checkpointed together on ``pp_rank1``. + +.. _smdmp-tp-appendix: + :noindex: + +Appendix: Reference Implementations for Modules +----------------------------------------------- + +The following are reference implementations for transformer-related +modules. Note that this is not the actual ``smdistributed`` source code, +but the distributed implementations provided in the library are the +distributed versions of these reference implementations, and can be used +to determine whether the distributed modules perform the same operations +as the custom modules in your script. + +To keep the implementations simple, we only assume keyword arguments, +and assume the existence of a method ``parse_args(kwargs)``, which +parses the arguments to ``__init__`` methods and sets the relevant +attributes of the module, such as ``hidden_size`` and +``num_attention_heads``. + +``smp.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class Transformer(nn.Module): + def __init__(self, **kwargs): + super(Transformer, self).__init__() + self.parse_args(kwargs) + + self.layers = [] + for l in range(self.num_layers): + self.layers.append(TransformerLayer(**kwargs)) + + self.seq_layers = nn.Sequential(*self.layers) + + def forward(self, inp): + return self.seq_layers(inp) + +``smp.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class TransformerLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerLayer, self).__init__() + self.parse_args(kwargs) + + self.attention = AttentionLayer(**kwargs) + self.output = TransformerOutputLayer(**kwargs) + + if self.add_cross_attention: + self.cross_attention = AttentionLayer(cross_attention=True, **kwargs) + + def forward(self, inp): + if self.add_cross_attention: + hidden_states, cross_states, attention_mask, cross_mask = inp + else: + hidden_states, attention_mask = inp + + attention_output = self.attention((hidden_states, attention_mask)) + if self.add_cross_attention: + attention_output = self.cross_attention((attention_output, + cross_states, + cross_mask)) + + output = self.output(attention_output) + + if self.add_cross_attention: + return output, cross_states, attention_mask, cross_mask + else: + return output, attention_mask + +``smp.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class AttentionLayer(nn.Module): + def __init__(self, **kwargs): + super(AttentionLayer, self).__init__() + self.parse_args(kwargs) + self.attention_head_size = self.hidden_size // self.num_attention_heads + + self.query = nn.Linear(self.hidden_size, self.hidden_size) + self.key = nn.Linear(self.hidden_size, self.hidden_size) + self.value = nn.Linear(self.hidden_size, self.hidden_size) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout1 = nn.Dropout(self.attention_dropout_prob) + self.dropout2 = nn.Dropout(self.hidden_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def transpose(self, tensor, key=False): + shape = tensor.size()[:-1] + + (self.num_attention_heads, self.attention_head_size) + tensor = torch.reshape(tensor, shape) + if key: + return tensor.permute(0, 2, 3, 1) + else: + return tensor.permute(0, 2, 1, 3) + + def forward(self, inp): + if self.cross_attention: + hidden_states, cross_states, attention_mask = inp + else: + hidden_states, attention_mask = inp + + if self.pre_layernorm: + norm_states = self.pre_layernorm(hidden_states) + else: + norm_states = hidden_states + + query_layer = self.query(norm_states) + + if self.cross_attention: + key_layer = self.key(cross_states) + value_layer = self.value(cross_states) + else: + key_layer = self.key(norm_states) + value_layer = self.value(norm_states) + + query_layer = self.transpose(query_layer) + key_layer = self.transpose(key_layer, key=True) + value_layer = self.transpose(value_layer) + + attention_scores = torch.matmul(query_layer, key_layer) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if not self.cross_attention and self.causal_mask is not None: + attention_scores = self.apply_causal_mask(attention_scores) + + attention_scores = attention_scores + attention_mask + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout1(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.local_attention_size,) + context_layer = torch.reshape(context_layer, new_context_layer_shape) + + self_attention = self.dense(context_layer) + self_attention = self.dropout2(self_attention) + + if self.post_layernorm: + return self.layernorm(self_attention + hidden_states) + else: + return self_attention + +``smp.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class TransformerOutputLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerOutputLayer, self).__init__() + self.parse_args(kwargs) + + self.dense1 = nn.Linear(self.hidden_size, self.intermediate_size) + self.dense2 = nn.Linear(self.intermediate_size, self.hidden_size) + + self.dropout = nn.Dropout(self.attention_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def forward(self, inp): + if self.pre_layernorm: + norm_inp = self.pre_layernorm(inp) + else: + norm_inp = inp + + dense1_output = self.dense1(norm_inp) + if self.activation == "gelu": + act_output = F.gelu(dense1_output) + else: + act_output = F.relu(dense1_output) + + dense2_output = self.dense2(act_output) + output = self.dropout(dense2_output) + + if self.post_layernorm: + return self.layernorm(inp + output) + else: + return output diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst new file mode 100644 index 0000000000..54ec558fc5 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst @@ -0,0 +1,171 @@ +TensorFlow API +============== + +To use the TensorFlow-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +.. tip:: + + Refer to + `Modify a TensorFlow Training Script + `_ + to learn how to use the following APIs in your TensorFlow training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of the Keras \ ``Model`` class, which defines the model to + be partitioned. Model definition is done by sub-classing + ``smp.DistributedModel`` class, and implementing the ``call()`` method, + in the same way as the Keras model sub-classing API. Any operation that + is part of the \ ``smp.DistributedModel.call()`` method is subject to + partitioning, meaning that every operation placed inside executes in + exactly one of the devices (the operations outside run on all devices). + + + Similar to the regular Keras API, the forward pass is done by directly + calling the model object on the input tensors. For example: + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + However, ``model()`` calls can only be made inside a + ``smp.step``-decorated function. + + The outputs from a ``smp.DistributedModel`` are available in all ranks, + regardless of which rank computed the last operation. + + **Methods:** + + .. function:: save_model(save_path="/opt/ml/model") + :noindex: + + **Inputs** + - ``save_path`` (``string``): A path to save an unpartitioned model with latest training weights. + + Saves the entire, + unpartitioned model with the latest trained weights to ``save_path`` in + TensorFlow ``SavedModel`` format. Defaults to ``"/opt/ml/model"``, which + SageMaker monitors to upload the model artifacts to Amazon S3. + +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (``int``): The index of the partition. + + A context manager which places all operations defined inside into the + partition whose ID is equal to ``index``. When + ``smp.partition`` contexts are nested, the innermost context overrides + the rest. The ``index`` argument must be smaller than the number of + partitions. + + ``smp.partition`` is used in the manual partitioning API; + if \ ``"auto_partition"`` parameter is set to ``True`` while launching + training, then ``smp.partition`` contexts are ignored. Any operation + that is not placed in any ``smp.partition`` context is placed in the + ``default_partition``, as shown in the following example: + + .. code:: python + + # auto_partition: False + # default_partition: 0 + smp.init() + [...] + x = tf.constant(1.2)                     # placed in partition 0 + with smp.partition(1): +     y = tf.add(x, tf.constant(2.3))      # placed in partition 1 +     with smp.partition(3): +         z = tf.reduce_sum(y)             # placed in partition 3 + + +.. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. + + .. code:: python + + @smp.register_post_partition_hook + def test_eager(): + # All statements here will be executed right after partition but before the first forward pass + tf.print("Entered hook through eager context") + +.. class:: smp.CheckpointManager + :noindex: + + + A subclass of TensorFlow + `CheckpointManager `__, + which is used to manage checkpoints. The usage is similar to TensorFlow + ``CheckpointManager``. + + The following returns a ``CheckpointManager`` object. + + .. code:: python + + smp.CheckpointManager(checkpoint, +                       directory="/opt/ml/checkpoints", +                       max_to_keep=None, +                       checkpoint_name="ckpt") + + **Parameters** + + - ``checkpoint``: A `tf.train.Checkpoint + `__ instance + that represents a model checkpoint. + + - ``directory``: (``str``) The path to a directory in which to write + checkpoints. A file named "checkpoint" is also written to this + directory (in a human-readable text format) which contains the state + of the ``CheckpointManager``. Defaults to + ``"/opt/ml/checkpoints"``, which is the directory that SageMaker + monitors for uploading the checkpoints to Amazon S3. + - ``max_to_keep`` (``int``): The number of checkpoints to keep. If + ``None``, all checkpoints are kept. + - ``checkpoint_name`` (``str``): Custom name for the checkpoint file. + Defaults to ``"ckpt"``. + + + **Methods:** + + .. function:: save( ) + :noindex: + + Saves a new checkpoint in the specified directory. Internally uses ``tf.train.CheckpointManager.save()``. + + .. function:: restore( ) + :noindex: + + Restores the latest checkpoint in the specified directory. + Internally uses ``tf.train.CheckpointManager.restore()``. + + + **Examples:** + + .. code:: python + + checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) + ckpt_manager = smp.CheckpointManager(checkpoint, max_to_keep=5)  # use /opt/ml/checkpoints + + for inputs in train_ds: +     loss = train_step(inputs) +     # [...] +     ckpt_manager.save()  # save a new checkpoint in /opt/ml/checkpoints + + .. code:: python + + for step, inputs in enumerate(train_ds): +     if step == 0: +         ckpt_manager.restore() +     loss = train_step(inputs) diff --git a/doc/api/training/smp_versions/v1_9_0.rst b/doc/api/training/smp_versions/v1_9_0.rst new file mode 100644 index 0000000000..e2e9acd83a --- /dev/null +++ b/doc/api/training/smp_versions/v1_9_0.rst @@ -0,0 +1,13 @@ + +Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 +================================== + +To use the library, reference the Common API documentation alongside the framework specific API documentation. + +.. toctree:: + :maxdepth: 1 + + v1.9.0/smd_model_parallel_common_api + v1.9.0/smd_model_parallel_pytorch + v1.9.0/smd_model_parallel_pytorch_tensor_parallel + v1.9.0/smd_model_parallel_tensorflow From 7b4c553c42ca108a05f86a324e67ec51aafa8d6b Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 22 Jun 2022 20:18:25 -0700 Subject: [PATCH 13/23] fix indexing --- .../v1.9.0/smd_model_parallel_pytorch.rst | 19 +++++++++---------- ...model_parallel_pytorch_tensor_parallel.rst | 1 + .../v1.9.0/smd_model_parallel_tensorflow.rst | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst index 055f2b6dde..88d1a42165 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst @@ -17,7 +17,7 @@ you need to add the following import statement at the top of your training scrip to learn how to use the following API in your PyTorch training script. .. class:: smp.DistributedModel - :noindex: + :noindex: A sub-class of ``torch.nn.Module`` which specifies the model to be partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is @@ -362,45 +362,45 @@ you need to add the following import statement at the top of your training scrip `documentation `__. .. function:: distributed_modules() - :noindex: + :noindex: - An iterator that runs over the set of distributed (tensor-parallelized) modules in the model .. function:: is_distributed_parameter(param) - :noindex: + :noindex: - Returns ``True`` if the given ``nn.Parameter`` is distributed over tensor-parallel ranks. .. function:: is_distributed_buffer(buf) - :noindex: + :noindex: - Returns ``True`` if the given buffer is distributed over tensor-parallel ranks. .. function:: is_scaled_batch_parameter(param) - :noindex: + :noindex: - Returns ``True`` if the given ``nn.Parameter`` is operates on the scaled batch (batch over the entire ``TP_GROUP``, and not only the local batch). .. function:: is_scaled_batch_buffer(buf) - :noindex: + :noindex: - Returns ``True`` if the parameter corresponding to the given buffer operates on the scaled batch (batch over the entire ``TP_GROUP``, and not only the local batch). .. function:: default_reducer_named_parameters() - :noindex: + :noindex: - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``DP_GROUP``. .. function:: scaled_batch_reducer_named_parameters() - :noindex: + :noindex: - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``RDP_GROUP``. @@ -512,6 +512,7 @@ you need to add the following import statement at the top of your training scrip .. function::smp.is_tracing( ) :noindex: + :noindex: Returns ``True`` if the current process is running the tracing step, and ``False`` otherwise. @@ -527,7 +528,6 @@ you need to add the following import statement at the top of your training scrip .. data:: smp.optimizers.FusedNovoGrad :noindex: - `Fused Novo Grad optimizer `__ is currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` optimizer and provides the same functionality. This requires ``apex`` to @@ -536,7 +536,6 @@ you need to add the following import statement at the top of your training scrip .. data:: smp.optimizers.FusedLamb :noindex: - `FusedLamb optimizer `__ currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces ``apex`` ``FusedLamb`` optimizer and provides the same functionality. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst index 851408b4b8..c66595ddf2 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst @@ -460,6 +460,7 @@ supported modules within that scope. To do this, you can use the following API: .. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) + :noindex: - A context manager that enables or disables tensor parallelism for any supported module that is created inside. If there are nested diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst index 54ec558fc5..2c658b487c 100644 --- a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst @@ -102,7 +102,7 @@ you need to add the following import statement at the top of your training scrip tf.print("Entered hook through eager context") .. class:: smp.CheckpointManager - :noindex: + :noindex: A subclass of TensorFlow From 2bb3204f414bec4f4599eed99c0805956fbbbc3a Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 15:50:59 -0700 Subject: [PATCH 14/23] add new smp cpu memory apis --- .../latest/smd_model_parallel_pytorch.rst | 249 +++++++++++++----- ...model_parallel_pytorch_tensor_parallel.rst | 102 +++---- 2 files changed, 243 insertions(+), 108 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index b05413965c..d829da43f8 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -2,7 +2,7 @@ PyTorch API =========== To use the PyTorch-specific APIs for SageMaker distributed model parallism, -you need to add the following import statement at the top of your training script. +import the ``smdistributed.modelparallel.torch`` package at the top of your training script. .. code:: python @@ -16,24 +16,33 @@ you need to add the following import statement at the top of your training scrip `_ to learn how to use the following API in your PyTorch training script. -.. class:: smp.DistributedModel +.. contents:: Topics + :depth: 3 + :local: + +smdistributed.modelparallel.torch.DistributedModel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: smdistributed.modelparallel.torch.DistributedModel A sub-class of ``torch.nn.Module`` which specifies the model to be partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is the model to be partitioned. The returned ``DistributedModel`` object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with - ``smp.DistributedModel``. + ``smdistributed.modelparallel.torch.DistributedModel``. **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = smp.DistributedModel(model) **Important**: The ``__call__`` and  ``backward`` method calls on the - ``smp.DistributedModel`` object (in the following example, the object - is \ ``model``) can only be made inside a ``smp.step``-decorated + ``smdistributed.modelparallel.torch.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smdistributed.modelparallel.torch.step``-decorated function. Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can @@ -78,7 +87,7 @@ you need to add the following import statement at the top of your training scrip In these examples, all ``__call__``  and ``backward`` method calls on the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside - a ``smp.step``-decorated function. + a ``smdistributed.modelparallel.torch.step``-decorated function. **Using DDP** @@ -89,7 +98,7 @@ you need to add the following import statement at the top of your training scrip Unlike the original DDP wrapper, when you use ``DistributedModel``, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the - ``smp.step``-decorated function when the partition is done. + ``smdistributed.modelparallel.torch.step``-decorated function when the partition is done. **Parameters** @@ -160,7 +169,7 @@ you need to add the following import statement at the top of your training scrip - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` otherwise. Initialized to ``False`` when ``DistributedModel`` is first created. It becomes be ``True`` during the first call - to ``smp.step``-decorated function. Once the model is partitioned, the + to ``smdistributed.modelparallel.torch.step``-decorated function. Once the model is partitioned, the local parameters or local ``state_dict`` can be fetched using the following methods. @@ -240,7 +249,7 @@ you need to add the following import statement at the top of your training scrip Registers a callable ``hook`` to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during - the first call to ``smp.step``, but before the actual execution of the + the first call to ``smdistributed.modelparallel.torch.step``, but before the actual execution of the first forward pass. Returns a ``RemovableHandle`` object ``handle``, which can be used to remove the hook by calling ``handle.remove()``. @@ -252,7 +261,7 @@ you need to add the following import statement at the top of your training scrip .. function:: join( ) A context manager to be used in conjunction with an instance of - ``smp.DistributedModel`` to be able to train with uneven inputs across + ``smdistributed.modelparallel.torch.DistributedModel`` to be able to train with uneven inputs across participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped ``DistributedDataParallel`` instance. For more information, see: `join `__ @@ -276,9 +285,9 @@ you need to add the following import statement at the top of your training scrip `register_comm_hook `__ in the PyTorch documentation. - **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism** + **Behavior of** ``smdistributed.modelparallel.torch.DistributedModel`` **with Tensor Parallelism** - When a model is wrapped by ``smp.DistributedModel``, the library + When a model is wrapped by ``smdistributed.modelparallel.torch.DistributedModel``, the library immediately traverses the modules of the model object, and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. If there are no other @@ -293,6 +302,8 @@ you need to add the following import statement at the top of your training scrip # register DistributedSubmodule as the distributed version of Submodule # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist) + import smdistributed.modelparallel.torch as smp + smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule) class MyModule(nn.Module): @@ -319,20 +330,20 @@ you need to add the following import statement at the top of your training scrip placement of model partitions into GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately. This is because it does not need to wait for the model - partition when ``smp.DistributedModel`` wrapper is called. For other + partition when ``smdistributed.modelparallel.torch.DistributedModel`` wrapper is called. For other cases with ``pipeline_parallel_degree`` greater than 1, the broadcast and device placement will be deferred until the first call of an - ``smp.step``-decorated function happens. This is because the first - ``smp.step``-decorated function call is when the model partitioning + ``smdistributed.modelparallel.torch.step``-decorated function happens. This is because the first + ``smdistributed.modelparallel.torch.step``-decorated function call is when the model partitioning happens if pipeline parallelism is enabled. - Because of the module replacement during the ``smp.DistributedModel`` + Because of the module replacement during the ``smdistributed.modelparallel.torch.DistributedModel`` call, any ``load_state_dict`` calls on the model, as well as any direct access to model parameters, such as during the optimizer creation, - should be done **after** the ``smp.DistributedModel`` call. + should be done **after** the ``smdistributed.modelparallel.torch.DistributedModel`` call. Since the broadcast of the model parameters and buffers happens - immediately during ``smp.DistributedModel`` call when the degree of + immediately during ``smdistributed.modelparallel.torch.DistributedModel`` call when the degree of pipeline parallelism is 1, using ``@smp.step`` decorators is not required when tensor parallelism is used by itself (without pipeline parallelism). @@ -340,9 +351,9 @@ you need to add the following import statement at the top of your training scrip For more information about the library's tensor parallelism APIs for PyTorch, see :ref:`smdmp-pytorch-tensor-parallel`. - **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + **Additional Methods of** ``smdistributed.modelparallel.torch.DistributedModel`` **for Tensor Parallelism** - The following are the new methods of ``smp.DistributedModel``, in + The following are the new methods of ``smdistributed.modelparallel.torch.DistributedModel``, in addition to the ones listed in the `documentation `__. @@ -383,24 +394,26 @@ you need to add the following import statement at the top of your training scrip - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``RDP_GROUP``. +smdistributed.modelparallel.torch.DistributedOptimizer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) -.. class:: smp.DistributedOptimizer + An optimizer wrapper for saving and loading optimizer states. - **Parameters** - - ``optimizer`` + :param optimizer: An optimizer object. + :type optimizer: object - An optimizer wrapper for saving/loading optimizer states. This wrapper - returns ``optimizer`` with the following methods overridden: + This wrapper returns ``optimizer`` with the following methods overridden: - .. function:: state_dict( ) + .. method:: state_dict( ) Returns the ``state_dict`` that contains optimizer state for the entire model. It first collects the ``local_state_dict`` and gathers and merges the ``local_state_dict`` from all ``mp_rank``s to create a full ``state_dict``. - .. function:: load_state_dict( ) + .. method:: load_state_dict( ) Same as the ``torch.optimizer.load_state_dict()`` , except: @@ -409,7 +422,7 @@ you need to add the following import statement at the top of your training scrip - The actual loading happens after the model partition so that each rank knows its local parameters. - .. function:: local_state_dict( ) + .. method:: local_state_dict( ) Returns the ``state_dict`` that contains the local optimizer state that belongs to the current \ ``mp_rank``. This @@ -418,34 +431,140 @@ you need to add the following import statement at the top of your training scrip ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - ​ -.. function:: smp.partition(index) - :noindex: + .. method:: save_optimizer_backcompat(cast_to_cpu=True, gather_if_shard=True, fp32_states_only=False) + + Gets the local optimizer states and FP16 states if FP16 training is enabled. + + :param cast_to_cpu: Whether to cast the optimizer states and FP16 states to CPU. + :type cast_to_cpu: boolean + :param gather_if_shard: (for smdistributed-modelparallel v1.10 only) + Whether to gather the optimizer states and FP16 states to the 0th + ``rdp_rank`` when using the `optimizer state sharding + `_ feature. + If you want to save optimizer and also further reduce CPU memory + utilization for better performance, turn it off by setting + ``gather_if_shard=False``. However, you need to make sure that you + save the states on all ``rdp_rank``. To handle both cases, + use the following example code. + + + :type gather_if_shard: boolean + :param fp32_states_only: Whether to return the FP32 optimizer states only. + :type fp32_states_only: boolean + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + # wrap optimizer + optimizer = torch.optim.Optimizer(...) + optimizer = smp.DistributedOptimizer(optimizer) + + # save optimizer + save_dict["optimizer"] = optimizer.save_optimizer_backcompat( + gather_if_shard=args.gather_if_shard + ) + if not args.gather_if_shard or smp.rdp_rank() == 0: + smp.save( + save_dict, output_save_file, partial=True, + v3=not args.gather_if_shard + ) + + The ``v3`` argument of the ``smp.save()`` function checks whether the value of + the ``gather_if_shard`` arg is ``True`` or ``False``. + If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint + files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix + to avoid overwriting checkpoint files. + + .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) + + Loads the saved optimizer states and FP16 states if FP16 training is enabled. + + :param state_dict: The ``state_dict`` to load. + :type state_dict: dict + :param gather_if_shard: Specify whether the optimizer state was saved with ``gather_if_shard=True`` + when using the :class:`smdistributed.modelparallel.torch.DistributedOptimizer.save_optimizer_backcompat()` method. + :type gather_if_shard: boolean + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + # load optimizer + checkpoint = smp.load(local_ckpt_path, partial=True) + optimizer.load_optimizer_backcompat( + checkpoint["optimizer"], gather_if_shard=args.gather_if_shard + ) - **Inputs** +smdistributed.modelparallel.torch Context Managers and Util Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - ``index`` (int) - The index of the partition. +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) + + Context manager to create a ``torch`` model. This API combines both the + :class:`smdistributed.modelparallel.torch.tensor_parallelism` and + ``smdistributed.modelparallel.torch.delay_param_initialization`` decorators + so user need to simply use a single context when creating the torch model. + + :param tensor_parallelism: Whether tensor parallel should be enabled during model creation. + :type tensor_parallelism: boolean + :param dtype: The dtype to use when creating the model. It has the following rules. + + * If dtype is specified, it will be used during model creation. + * If dtype is not specified, the default dtype will be used during model creation, + which is usually FP32. This is for the best performance on CPU. + * Any model that causes out-of-memory problems with FP32 initialization + is recommended to be created with + :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. + :type dtype: torch.dtype + :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. + This is not used if tensor_parallelism is False + :type tensor_parallel_config: dict + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + with smp.model_creation( + tensor_parallelism=smp.tp_size() > 1, + dtype=torch.float16 if args.fp16 else torch.get_default_dtype() + ): + model = MyModel(...) + +.. function:: smdistributed.modelparallel.torch.partition(index) + + :param index: The index of the partition. + :type index: int A context manager which places all modules defined inside into the partition with ID ``index``.  The ``index`` argument must be less than the number of partitions. - Use ``smp.partition`` to implement manual partitioning. + Use ``smdistributed.modelparallel.torch.partition`` to implement manual partitioning. If ``"auto_partition"`` is ``True``, then the - ``smp.partition`` contexts are ignored. Any module that is not placed in - any ``smp.partition`` context is placed in the + ``smdistributed.modelparallel.torch.partition`` contexts are ignored. Any module that is not placed in + any ``smdistributed.modelparallel.torch.partition`` context is placed in the ``default_partition`` defined through the SageMaker Python SDK. - When ``smp.partition`` contexts are nested, the innermost context + When ``smdistributed.modelparallel.torch.partition`` contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module \ ``__init__``, and the partition assignment applies to the modules that are *created* inside - the ``smp.partition`` context. + the ``smdistributed.modelparallel.torch.partition`` context. Example: .. code:: python + import smdistributed.modelparallel.torch as smp + class Model(torch.nn.Module):     def __init__(self):         with smp.partition(1): @@ -455,29 +574,40 @@ you need to add the following import statement at the top of your training scrip             self.child2 = Child2()            # child2 on partition 1         self.child3 = Child3()                # child3 on default_partition -.. function:: smp.get_world_process_group( ) +.. data:: smdistributed.modelparallel.torch.amp.GradScaler + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. function:: smdistributed.modelparallel.torch.delayed_parameter_initialization(enabled=True) + + If enabled, it delays the initialization of parameters + to save CPU memory; it initializes after the model is partitioned on GPU. + +.. function:: smdistributed.modelparallel.torch.get_world_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all processes, which can be used with the ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_mp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_mp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``MP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_dp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_dp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``DP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.is_initialized( ) +.. function:: smdistributed.modelparallel.torch.is_initialized( ) - Returns ``True`` if ``smp.init`` has already been called for the + Returns ``True`` if ``smdistributed.modelparallel.torch.init`` has already been called for the process, and ``False`` otherwise. .. function::smp.is_tracing( ) @@ -485,41 +615,35 @@ you need to add the following import statement at the top of your training scrip Returns ``True`` if the current process is running the tracing step, and ``False`` otherwise. -.. data:: smp.nn.FusedLayerNorm +.. data:: smdistributed.modelparallel.torch.nn.FusedLayerNorm `Apex Fused Layer Norm `__ is currently not - supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + supported by the library. ``smdistributed.modelparallel.torch.nn.FusedLayerNorm`` replaces ``apex`` ``FusedLayerNorm`` and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedNovoGrad +.. data:: smdistributed.modelparallel.torch.optimizers.FusedNovoGrad `Fused Novo Grad optimizer `__ is - currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + currently not supported by the library. ``smdistributed.modelparallel.torch.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedLamb +.. data:: smdistributed.modelparallel.torch.optimizers.FusedLamb `FusedLamb optimizer `__ - currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.optimizers.FusedLamb`` replaces ``apex`` ``FusedLamb`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.amp.GradScaler - - `Torch AMP Gradscaler `__ - currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces - ``torch.amp.GradScaler`` and provides the same functionality. - .. _pytorch_saving_loading: APIs for Saving and Loading ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smp.save( ) +.. function:: smdistributed.modelparallel.torch.save( ) Saves an object. This operation is similar to ``torch.save()``, except it has an additional keyword argument, ``partial``, and accepts only @@ -534,16 +658,18 @@ APIs for Saving and Loading - ``partial`` (bool, default= ``True``):  When set to ``True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` index to the saved file. If you want to be able to load - and further train a model that you save with ``smp.save()``, you must + and further train a model that you save with ``smdistributed.modelparallel.torch.save()``, you must set ``partial=True``. - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``): A module used for pickling metadata and objects. - ``pickle_protocol``  (int, default=2): Can be specified to override the defaultprotocol. + - ``v3`` (bool, default=``False``): When set to ``True``, save optimizer state checkpoints + in V3 file format to add all ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix. -.. function:: smp.load( ) +.. function:: smdistributed.modelparallel.torch.load( ) - Loads an object saved with ``smp.save()`` from a file. + Loads an object saved with ``smdistributed.modelparallel.torch.save()`` from a file. Similar to, `torch.load() `__, except it has an additional keyword argument, ``partial``, and accepts @@ -577,13 +703,13 @@ The library can save partial or full checkpoints. - For full checkpoints, the library saves a single checkpoint that contains entire model parameters. -When **saving** using ``smp.save()``, each rank only holds its own +When **saving** using ``smdistributed.modelparallel.torch.save()``, each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance. -When **loading** using ``smp.load()``, the library can load either partial or | +When **loading** using ``smdistributed.modelparallel.torch.load()``, the library can load either partial or | full checkpoints or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint. @@ -592,6 +718,7 @@ The following is an example of how you can save and load a checkpoint: .. code:: python + import smdistributed.modelparallel.torch as smp # Original model and optimizer model = MyModel(...) optimizer = MyOpt(...) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index e0ea1ba6c8..7a75b8f9f3 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -17,7 +17,7 @@ place on a best-effort basis for those module supported for tensor parallelism. Alternatively, you can directly import and use the library’s distributed modules in the model definition. -Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +Some of the supported modules (such as ``smdistributed.modelparallel.torch.nn.Transformer``) are high-level blocks that contain many operations. Because custom implementations (as opposed to the built-in PyTorch modules) are typically used for these high-level blocks, the library offers an API that you can use to register @@ -47,9 +47,9 @@ use is functionally equivalent to the distributed module. You can verify this by taking a look at the equivalent reference implementations in the :ref:`smdmp-tp-appendix`. These implementations are functionally equivalent to their distributed -versions in ``smp.nn`` module. +versions in ``smdistributed.modelparallel.torch.nn`` module. -.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. decorator:: @smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) - A class decorator that registers the ``dist_module`` class with the module class that it is attached to. The hooks can be used to @@ -57,10 +57,10 @@ versions in ``smp.nn`` module. ``forward`` methods. - **Arguments:** - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -89,6 +89,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + init_hook = lambda config: ((), config.to_dict()) # register smp.nn.DistributedTransformer @@ -101,7 +103,7 @@ versions in ``smp.nn`` module. def forward(self, hidden_states, attention_mask): ... -.. function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. function:: smdistributed.modelparallel.torch.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) - When you do not have direct access to model definition code, you can use this API to similarly register a distributed module with @@ -111,10 +113,10 @@ versions in ``smp.nn`` module. - ``module_cls``: The existing module class that will be distributed. - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -143,6 +145,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + from somelibrary import MyTransformer init_hook = lambda config: ((), config.to_dict()) @@ -160,13 +164,13 @@ Supported Modules for Tensor Parallelism The following modules are supported for tensor parallelism. -- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``) -- ``smp.nn.DistributedTransformerLMHead`` -- ``smp.nn.DistributedTransformer`` -- ``smp.nn.DistributedTransformerLayer`` -- ``smp.nn.DistributedAttentionLayer`` -- ``smp.nn.DistributedTransformerOutputLayer`` -- ``smp.nn.DistributedEmbedding`` +- ``smdistributed.modelparallel.torch.nn.DistributedLinear`` (implements ``nn.Linear``) +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformer`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` +- ``smdistributed.modelparallel.torch.nn.DistributedEmbedding`` .. contents:: Topics :depth: 3 @@ -177,7 +181,7 @@ parallelism. Tensor Parallelism Module APIs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. class:: smp.nn.DistributedLinear(in_features, out_features) +.. class:: smdistributed.modelparallel.torch.nn.DistributedLinear(in_features, out_features) - Tensor-parallel implementation of the ``nn.Linear`` class. Functionally equivalent to an ``nn.Linear`` module with the same @@ -191,7 +195,7 @@ Tensor Parallelism Module APIs - ``out_features``: The total number of output channels for the linear layer across all tensor-parallel ranks. -.. class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) - Constructs a distributed transformer model, including embeddings and a single LM head. A word embedding of size @@ -205,7 +209,7 @@ Tensor Parallelism Module APIs if ``add_lm_head`` is ``True``, the output passes through a single LM head, which is a linear module without bias whose weight is tied to the word embeddings. - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest of the arguments. - **Methods:** @@ -223,18 +227,18 @@ Tensor Parallelism Module APIs - ``attention_mask`` is assumed to be a 0-1 tensor of shape ``[N, S]``, where 1 represents a masked position. -.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose + - A sequence of ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``\ s, whose number is given by ``num_layers`` argument. For the other arguments and methods, refer to - ``smp.nn.DistributedTransformerLayer``. + ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``. - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, layer normalization is applied to both the input and the output of the ``DistributedTransformer``, in addition to the intermediate attention and transformer-output layers. -.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - Tensor-parallel implementation of a single transformer layer. Number of attention heads, hidden size, and intermediate size @@ -336,7 +340,7 @@ Tensor Parallelism Module APIs and the next three tensors are the same as the input arguments. -.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) +.. class:: smdistributed.modelparallel.torch.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) - A distributed implementation for the attention block. Includes the computation of the self- or cross-attention (context layer), @@ -344,7 +348,7 @@ Tensor Parallelism Module APIs followed by the residual-connection and layer normalization. - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``cross_attention``: If ``True``, it computes the attentions with respect to the ``cross_states`` tensor of the ``forward`` @@ -383,10 +387,10 @@ Tensor Parallelism Module APIs - A single tensor that is the output of the attention layer. -.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) - Distributed implementation of a single transformer output layer. A - single :class:`smp.nn.DistributedTransformerLayer` with + single :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` with ``add_cross_attention=False`` consists of a single ``DistributedAttentionLayer`` immediately followed by a single ``DistributedTransformerOutputLayer``. The latter linearly maps @@ -394,19 +398,19 @@ Tensor Parallelism Module APIs ``intermediate_size``, and then maps it back to ``hidden_size``. - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow (NaN loss values) for large models with more than 100 billion parameters when using FP16. (Default: False) -.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) +.. class:: smdistributed.modelparallel.torch.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) - Distributed implementation of a single Embedding Layer. Currently only supports splitting across the embedding_dim. - **Arguments:** - - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` for descriptions of the arguments. .. _enabling-tp: @@ -417,7 +421,7 @@ Enabling Tensor Parallelism There are two ways tensor parallelism can be enabled. First, you can use -the distributed module implementations in ``smp.nn`` module directly in +the distributed module implementations in ``smdistributed.modelparallel.torch.nn`` module directly in your model definition. See :ref:`smdmp-supported-modules-for-tp` for a complete list of built-in distributed modules. Here is an example of how this can be done: @@ -446,7 +450,7 @@ of code, which will automatically enable tensor parallelism for the supported modules within that scope. To do this, you can use the following API: -.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) +.. decorator:: smdistributed.modelparallel.torch.tensor_parallelism(enabled=True, **kwargs) - A context manager that enables or disables tensor parallelism for any supported module that is created inside. If there are nested @@ -463,6 +467,8 @@ following API: .. code:: python + import smdistributed.modelparallel.torch as smp + with smp.tensor_parallelism(): self.m0 = nn.Linear(20, 20) # will be distributed with smp.tensor_parallelism(enabled=False): @@ -472,7 +478,7 @@ following API: the distributed modules created inside the context. If a keyword argument provided through it matches any ``__init__`` method arguments of a ``DistributedModule`` that substitutes a module created inside - the ``smp.tensor_parallelism`` context, this keyword will override + the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - (*For v1.7.0 and later*) Through the following additional keyword arguments, @@ -481,21 +487,21 @@ following API: - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. By default, it is set to ``True``. You can deactivate it by setting - ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + ``fused_softmax=False`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. By default, it is set to ``False``. You can activate it by setting - ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + ``fused_bias_gelu=True`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. -.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) +.. function:: smdistributed.modelparallel.torch.set_tensor_parallelism(module, enabled=True, **kwargs) - Enables or disables tensor parallelism for the supported submodules of ``module``. If enabling, the outermost supported modules will be distributed. If disabling, tensor parallelism will be disabled for the entire module subtree of ``module``. Unlike the context manager, this API can be used after the model creation - (but before wrapping with :class:`smp.DistributedModel`), so direct + (but before wrapping with :class:`smdistributed.modelparallel.torch.DistributedModel`), so direct access to model definition code is not required. If a supported module shares weights with another (supported or unsupported) module, or if its hyperparameters do not support distribution @@ -504,14 +510,16 @@ following API: - Keyword arguments ``kwargs`` can be used to modify the configurations of the distributed modules created inside the context. If a keyword argument provided here matches any - ``__init__`` method arguments of a :class:`smp.DistributedModel` that - substitutes a module created inside the ``smp.tensor_parallelism`` + ``__init__`` method arguments of a :class:`smdistributed.modelparallel.torch.DistributedModel` that + substitutes a module created inside the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = MyModel() smp.set_tensor_parallelism(model.encoder, True) smp.set_tensor_parallelism(model.encoder.embedding, True) @@ -608,7 +616,7 @@ in the *SageMaker's Distributed Model Parallel developer guide*. any tuples received. If the checkpointed layer takes a tuple as input, then this needs to be set to True. -.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") +.. class:: smdistributed.modelparallel.torch.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") - This API is recommended when importing pretrained models from libraries, such as PyTorch and Hugging Face Transformers. This is @@ -673,8 +681,8 @@ parses the arguments to ``__init__`` methods and sets the relevant attributes of the module, such as ``hidden_size`` and ``num_attention_heads``. -``smp.nn.DistributedTransformer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -692,8 +700,8 @@ attributes of the module, such as ``hidden_size`` and def forward(self, inp): return self.seq_layers(inp) -``smp.nn.DistributedTransformerLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -727,8 +735,8 @@ attributes of the module, such as ``hidden_size`` and else: return output, attention_mask -``smp.nn.DistributedAttentionLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -812,8 +820,8 @@ attributes of the module, such as ``hidden_size`` and else: return self_attention -``smp.nn.DistributedTransformerOutputLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python From 5e3f8f0b0888d7d03e0d86eda82ec5f2e1a8b158 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 16:54:07 -0700 Subject: [PATCH 15/23] add new params --- doc/api/training/smd_model_parallel_general.rst | 10 ++++++++++ .../latest/smd_model_parallel_pytorch.rst | 15 +++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/doc/api/training/smd_model_parallel_general.rst b/doc/api/training/smd_model_parallel_general.rst index a35e0d60bc..fbb99f5224 100644 --- a/doc/api/training/smd_model_parallel_general.rst +++ b/doc/api/training/smd_model_parallel_general.rst @@ -178,6 +178,16 @@ PyTorch-specific Parameters - 1 - The number of devices over which the tensor parallel modules will be distributed. If ``tensor_parallel_degree`` is greater than 1, then ``ddp`` must be set to ``True``. + * - ``fp16`` (**smdistributed-modelparallel**>=v1.10) + - bool + - ``False`` + - To run FP16 training, add ``"fp16"'": True`` to the smp configuration. + Other APIs remain the same between FP16 and FP32. + If ``fp16`` is enabled and when user calls ``smp.DistributedModel``, + the model will be wrapped with ``FP16_Module``, which converts the model + to FP16 dtype and deals with forward pass in FP16. + If ``fp16`` is enabled and when user calls ``smp.DistributedOptimizer``, + the optimizer will be wrapped with ``FP16_Optimizer``. * - ``fp16_params`` (**smdistributed-modelparallel**>=v1.6) - bool - ``False`` diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index d829da43f8..3340587641 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -397,12 +397,18 @@ smdistributed.modelparallel.torch.DistributedModel smdistributed.modelparallel.torch.DistributedOptimizer ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1, dynamic_loss_scale=False, **dynamic_loss_args) An optimizer wrapper for saving and loading optimizer states. :param optimizer: An optimizer object. :type optimizer: object + :param static_loss_scale: Available only for FP16 training. Set to ``1`` to use static loss scale. The default value is ``1``. + :type static_loss_scale: float + :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. + :type dynamic_loss_scale: boolean + :param dynamic_loss_args: Available only for FP16 training. If ``dynamic_loss_scale=True``, specify parameters for dynamic loss scale. + :type dynamic_loss_args: dict This wrapper returns ``optimizer`` with the following methods overridden: @@ -523,7 +529,7 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. :type dtype: torch.dtype :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. - This is not used if tensor_parallelism is False + This is not used if ``tensor_parallelism`` is ``False`` :type tensor_parallel_config: dict **Example Usage:** @@ -580,10 +586,11 @@ smdistributed.modelparallel.torch Context Managers and Util Functions currently doesn’t work with the library. ``smdistributed.modelparallel.torch.amp.GradScaler`` replaces ``torch.amp.GradScaler`` and provides the same functionality. -.. function:: smdistributed.modelparallel.torch.delayed_parameter_initialization(enabled=True) +.. function:: smdistributed.modelparallel.torch.delay_param_initialization(enabled=True) If enabled, it delays the initialization of parameters - to save CPU memory; it initializes after the model is partitioned on GPU. + to save CPU memory. That is, parameter initialization takes place + after the model is partitioned on GPUs. .. function:: smdistributed.modelparallel.torch.get_world_process_group( ) From e29620fd1a70fd62ef4696db504ab1be0f7e8122 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 17:41:57 -0700 Subject: [PATCH 16/23] add dynamic scale params, add reference --- .../latest/smd_model_parallel_pytorch.rst | 64 ++++++++++++++++--- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 3340587641..4eb4c7aeff 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -397,26 +397,73 @@ smdistributed.modelparallel.torch.DistributedModel smdistributed.modelparallel.torch.DistributedOptimizer ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1, dynamic_loss_scale=False, **dynamic_loss_args) +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, **dynamic_loss_args) An optimizer wrapper for saving and loading optimizer states. :param optimizer: An optimizer object. :type optimizer: object - :param static_loss_scale: Available only for FP16 training. Set to ``1`` to use static loss scale. The default value is ``1``. + :param static_loss_scale: Available only for FP16 training. The default value is ``1.0``. :type static_loss_scale: float :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. :type dynamic_loss_scale: boolean - :param dynamic_loss_args: Available only for FP16 training. If ``dynamic_loss_scale=True``, specify parameters for dynamic loss scale. + :param dynamic_loss_args: Available only for FP16 training. + If you set ``dynamic_loss_scale=True``, configure scale parameters for dynamic loss scale. + The following list shows available parameters. + + * ``"init_scale"``: Default is ``2**32`` + * ``"scale_factor"``: Default is ``2.`` + * ``"scale_window"``: Default is ``1000`` + * ``"min_scale"``: Default is ``1`` + * ``"delayed_shift"``: Default is ``1`` + * ``"consecutive_hysteresis"``: Default is ``False`` :type dynamic_loss_args: dict - This wrapper returns ``optimizer`` with the following methods overridden: + **Example Usage for an FP32 Optimizer:** + + .. code:: python + + optimizer = torch.optim.AdaDelta(model.parameters(), lr=4.0) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) + + + **Example Usage for an FP16 Optimizer:** + + .. code:: python + + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=None, + dynamic_loss_scale=True, + dynamic_loss_args={ + "scale_window": 1000, + "min_scale": 1, + "delayed_shift": 2 + } + ) + + .. tip:: + + After you modify training scripts with + :class:`smdistributed.modelparallel.torch.DistributedModel` and + :class:`smdistributed.modelparallel.torch.DistributedOptimizer`, + use the SageMaker PyTorch estimator's distribution configuration o enable FP16 training. + You simply need to add ``"fp16": True`` to the ``smp_options`` config dictionary's + ``"parameters"`` key as shown in + `Using the SageMaker TensorFlow and PyTorch Estimators + `_. + For more information about available parameters for the ``smp_options`` config, + see :ref:`sm-sdk-modelparallel-general`. + + + + This wrapper returns an ``optimizer`` object with the following methods overridden: .. method:: state_dict( ) Returns the ``state_dict`` that contains optimizer state for the entire model. It first collects the ``local_state_dict`` and gathers and merges - the ``local_state_dict`` from all ``mp_rank``s to create a full + the ``local_state_dict`` from all ``mp_rank``\ s to create a full ``state_dict``. .. method:: load_state_dict( ) @@ -450,10 +497,9 @@ smdistributed.modelparallel.torch.DistributedOptimizer If you want to save optimizer and also further reduce CPU memory utilization for better performance, turn it off by setting ``gather_if_shard=False``. However, you need to make sure that you - save the states on all ``rdp_rank``. To handle both cases, + save the states on all ``rdp_rank``\ s. To handle both cases, use the following example code. - :type gather_if_shard: boolean :param fp32_states_only: Whether to return the FP32 optimizer states only. :type fp32_states_only: boolean @@ -465,7 +511,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer import smdistributed.modelparallel.torch as smp # wrap optimizer - optimizer = torch.optim.Optimizer(...) + optimizer = torch.optim.AdaDelta(...) optimizer = smp.DistributedOptimizer(optimizer) # save optimizer @@ -482,7 +528,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer the ``gather_if_shard`` arg is ``True`` or ``False``. If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix - to avoid overwriting checkpoint files. + to avoid overwriting optimizer checkpoint files. .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) From 10ae69499ea44872e6be2e38bc73d1ffc585b579 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 23 Jun 2022 18:00:56 -0700 Subject: [PATCH 17/23] minor fix --- .../training/smp_versions/latest/smd_model_parallel_pytorch.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 4eb4c7aeff..e36fbb9993 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -447,7 +447,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer After you modify training scripts with :class:`smdistributed.modelparallel.torch.DistributedModel` and :class:`smdistributed.modelparallel.torch.DistributedOptimizer`, - use the SageMaker PyTorch estimator's distribution configuration o enable FP16 training. + use the SageMaker PyTorch estimator's distribution configuration to enable FP16 training. You simply need to add ``"fp16": True`` to the ``smp_options`` config dictionary's ``"parameters"`` key as shown in `Using the SageMaker TensorFlow and PyTorch Estimators From 53f28866d00074c5702699f531b2c533343815aa Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Tue, 28 Jun 2022 15:10:08 -0700 Subject: [PATCH 18/23] minor fixes --- .../latest/smd_model_parallel_pytorch.rst | 43 ++++--- ...model_parallel_pytorch_tensor_parallel.rst | 106 ++++++++++-------- 2 files changed, 90 insertions(+), 59 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index e36fbb9993..520542846a 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -403,12 +403,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer :param optimizer: An optimizer object. :type optimizer: object - :param static_loss_scale: Available only for FP16 training. The default value is ``1.0``. + :param static_loss_scale: Effective only for FP16 training. The default value is ``1.0``. :type static_loss_scale: float - :param dynamic_loss_scale: Available only for FP16 training. Set to ``True`` to use dynamic loss scale. + :param dynamic_loss_scale: Effective only for FP16 training. Set to ``True`` to + use dynamic loss scale. The default value is ``False``. :type dynamic_loss_scale: boolean - :param dynamic_loss_args: Available only for FP16 training. - If you set ``dynamic_loss_scale=True``, configure scale parameters for dynamic loss scale. + :param dynamic_loss_args: Effective only for FP16 training. + If ``dynamic_loss_scale=True``, you can configure additional scale + parameters for dynamic loss scale. The following list shows available parameters. * ``"init_scale"``: Default is ``2**32`` @@ -423,14 +425,24 @@ smdistributed.modelparallel.torch.DistributedOptimizer .. code:: python - optimizer = torch.optim.AdaDelta(model.parameters(), lr=4.0) + optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) + **Example Usage for an FP16 Optimizer with static loss scale:** - **Example Usage for an FP16 Optimizer:** + .. code:: python + + optimizer = torch.optim.AdaDelta(...) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=1.0 + ) + + **Example Usage for an FP16 Optimizer with dynamic loss scale:** .. code:: python + optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( optimizer, static_loss_scale=None, @@ -455,8 +467,6 @@ smdistributed.modelparallel.torch.DistributedOptimizer For more information about available parameters for the ``smp_options`` config, see :ref:`sm-sdk-modelparallel-general`. - - This wrapper returns an ``optimizer`` object with the following methods overridden: .. method:: state_dict( ) @@ -555,14 +565,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, distribute_embedding=False, **tensor_parallel_config) Context manager to create a ``torch`` model. This API combines both the :class:`smdistributed.modelparallel.torch.tensor_parallelism` and - ``smdistributed.modelparallel.torch.delay_param_initialization`` decorators - so user need to simply use a single context when creating the torch model. + :class:`smdistributed.modelparallel.torch.delay_param_initialization` decorators, + so you can simply use this single context when creating the torch model. - :param tensor_parallelism: Whether tensor parallel should be enabled during model creation. + :param tensor_parallelism: Whether to enable tensor parallelism during model creation. :type tensor_parallelism: boolean :param dtype: The dtype to use when creating the model. It has the following rules. @@ -572,10 +582,12 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * Any model that causes out-of-memory problems with FP32 initialization is recommended to be created with :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. - * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the smp config. + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the ``smp`` config. :type dtype: torch.dtype + :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models. + :type dtype: boolean :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. - This is not used if ``tensor_parallelism`` is ``False`` + This is not used if ``tensor_parallelism`` is ``False``. :type tensor_parallel_config: dict **Example Usage:** @@ -586,7 +598,8 @@ smdistributed.modelparallel.torch Context Managers and Util Functions with smp.model_creation( tensor_parallelism=smp.tp_size() > 1, - dtype=torch.float16 if args.fp16 else torch.get_default_dtype() + dtype=torch.float16 if args.fp16 else torch.get_default_dtype(), + distribute_embedding=False ): model = MyModel(...) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index 7a75b8f9f3..c101e0025d 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -49,9 +49,9 @@ this by taking a look at the equivalent reference implementations in the These implementations are functionally equivalent to their distributed versions in ``smdistributed.modelparallel.torch.nn`` module. -.. decorator:: @smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. class:: smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) - - A class decorator that registers the ``dist_module`` class with + - A decorator class that registers the ``dist_module`` class with the module class that it is attached to. The hooks can be used to adapt to different interfaces used with ``__init__`` and ``forward`` methods. @@ -161,16 +161,7 @@ versions in ``smdistributed.modelparallel.torch.nn`` module. Supported Modules for Tensor Parallelism ---------------------------------------- -The following modules are supported for tensor -parallelism. - -- ``smdistributed.modelparallel.torch.nn.DistributedLinear`` (implements ``nn.Linear``) -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformer`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` -- ``smdistributed.modelparallel.torch.nn.DistributedEmbedding`` +The following modules are supported for tensor parallelism. .. contents:: Topics :depth: 3 @@ -181,14 +172,27 @@ parallelism. Tensor Parallelism Module APIs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- :class:`smdistributed.modelparallel.torch.nn.DistributedLinear` (implements ``nn.Linear``) +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedAttentionLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` + .. class:: smdistributed.modelparallel.torch.nn.DistributedLinear(in_features, out_features) - - Tensor-parallel implementation of the ``nn.Linear`` class. - Functionally equivalent to an ``nn.Linear`` module with the same - ``in_features`` and ``out_features``. In other words, - ``in_features`` and ``out_features`` are the number of *global* - channels across tensor-parallel ranks. - - **Arguments:** + Tensor-parallel implementation of the ``nn.Linear`` class. + Functionally equivalent to an ``nn.Linear`` module with the same + ``in_features`` and ``out_features``. In other words, + ``in_features`` and ``out_features`` are the number of *global* + channels across tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + + + - **Arguments:** - ``in_features``: The total number of input channels for the linear layer across all tensor-parallel ranks. @@ -197,21 +201,22 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) - - Constructs a distributed transformer model, including embeddings - and a single LM head. A word embedding of size - ``(vocab_size, hidden_size)`` is created, as well as a positional - embedding of size ``(num_positions, hidden_size)``, and the - embeddings are added together. If ``num_token_types`` is larger - than 0, a separate embedding of size - ``(num_token_types, hidden_size)`` is created, and further added - on top. - - The embeddings are fed through a ``DistributedTransformer``, and - if ``add_lm_head`` is ``True``, the output passes through a single - LM head, which is a linear module without bias whose weight is - tied to the word embeddings. - - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest - of the arguments. - - **Methods:** + Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** - ``forward(self, inputs)`` @@ -229,10 +234,11 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - A sequence of ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``\ s, whose - number is given by ``num_layers`` argument. For the other - arguments and methods, refer to - ``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer``. + A sequence of :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`. + - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, layer normalization is applied to both the input and the output of the ``DistributedTransformer``, in addition to the intermediate @@ -240,9 +246,13 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - Tensor-parallel implementation of a single transformer layer. - Number of attention heads, hidden size, and intermediate size - refer to the global quantities across all tensor-parallel ranks. + Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - ``num_attention_heads``: The total number of attention heads @@ -342,10 +352,14 @@ Tensor Parallelism Module APIs .. class:: smdistributed.modelparallel.torch.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) - - A distributed implementation for the attention block. Includes the - computation of the self- or cross-attention (context layer), - followed by a linear mapping and dropout, which is optionally - followed by the residual-connection and layer normalization. + A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the @@ -396,6 +410,10 @@ Tensor Parallelism Module APIs ``DistributedTransformerOutputLayer``. The latter linearly maps the last channel of the input tensor from ``hidden_size`` to ``intermediate_size``, and then maps it back to ``hidden_size``. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the From e5c265d91655c8bc2dc83753e6beffae1a048961 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 13 Jul 2022 17:19:09 -0700 Subject: [PATCH 19/23] rm temp methods --- .../latest/smd_model_parallel_pytorch.rst | 69 +------------------ 1 file changed, 1 insertion(+), 68 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 520542846a..1d3ea83337 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -493,74 +493,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer a partial \ ``state_dict``, which indicates whether the ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - - .. method:: save_optimizer_backcompat(cast_to_cpu=True, gather_if_shard=True, fp32_states_only=False) - - Gets the local optimizer states and FP16 states if FP16 training is enabled. - - :param cast_to_cpu: Whether to cast the optimizer states and FP16 states to CPU. - :type cast_to_cpu: boolean - :param gather_if_shard: (for smdistributed-modelparallel v1.10 only) - Whether to gather the optimizer states and FP16 states to the 0th - ``rdp_rank`` when using the `optimizer state sharding - `_ feature. - If you want to save optimizer and also further reduce CPU memory - utilization for better performance, turn it off by setting - ``gather_if_shard=False``. However, you need to make sure that you - save the states on all ``rdp_rank``\ s. To handle both cases, - use the following example code. - - :type gather_if_shard: boolean - :param fp32_states_only: Whether to return the FP32 optimizer states only. - :type fp32_states_only: boolean - - **Example Usage:** - - .. code:: python - - import smdistributed.modelparallel.torch as smp - - # wrap optimizer - optimizer = torch.optim.AdaDelta(...) - optimizer = smp.DistributedOptimizer(optimizer) - - # save optimizer - save_dict["optimizer"] = optimizer.save_optimizer_backcompat( - gather_if_shard=args.gather_if_shard - ) - if not args.gather_if_shard or smp.rdp_rank() == 0: - smp.save( - save_dict, output_save_file, partial=True, - v3=not args.gather_if_shard - ) - - The ``v3`` argument of the ``smp.save()`` function checks whether the value of - the ``gather_if_shard`` arg is ``True`` or ``False``. - If ``gather_if_shard=False``, the ``v3`` arg helps collect optimizer checkpoint - files by adding ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix - to avoid overwriting optimizer checkpoint files. - - .. method:: load_optimizer_backcompat(state_dict, gather_if_shard=False) - - Loads the saved optimizer states and FP16 states if FP16 training is enabled. - - :param state_dict: The ``state_dict`` to load. - :type state_dict: dict - :param gather_if_shard: Specify whether the optimizer state was saved with ``gather_if_shard=True`` - when using the :class:`smdistributed.modelparallel.torch.DistributedOptimizer.save_optimizer_backcompat()` method. - :type gather_if_shard: boolean - - **Example Usage:** - - .. code:: python - - import smdistributed.modelparallel.torch as smp - - # load optimizer - checkpoint = smp.load(local_ckpt_path, partial=True) - optimizer.load_optimizer_backcompat( - checkpoint["optimizer"], gather_if_shard=args.gather_if_shard - ) + smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From ba10af3b15b25d00614f99d247cb8f932fee3c7e Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Wed, 13 Jul 2022 21:06:03 -0700 Subject: [PATCH 20/23] add new checkpoint save/load functions, doc improvement --- .../latest/smd_model_parallel_pytorch.rst | 111 +++++++++++++++--- 1 file changed, 94 insertions(+), 17 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index 1d3ea83337..cb1c37e400 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -17,7 +17,7 @@ import the ``smdistributed.modelparallel.torch`` package at the top of your trai to learn how to use the following API in your PyTorch training script. .. contents:: Topics - :depth: 3 + :depth: 1 :local: smdistributed.modelparallel.torch.DistributedModel @@ -421,14 +421,14 @@ smdistributed.modelparallel.torch.DistributedOptimizer * ``"consecutive_hysteresis"``: Default is ``False`` :type dynamic_loss_args: dict - **Example Usage for an FP32 Optimizer:** + **Example usage of an FP32 Optimizer:** .. code:: python optimizer = torch.optim.AdaDelta(...) optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) - **Example Usage for an FP16 Optimizer with static loss scale:** + **Example usage of an FP16 Optimizer with static loss scale:** .. code:: python @@ -438,7 +438,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer static_loss_scale=1.0 ) - **Example Usage for an FP16 Optimizer with dynamic loss scale:** + **Example usage of an FP16 Optimizer with dynamic loss scale:** .. code:: python @@ -493,7 +493,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer a partial \ ``state_dict``, which indicates whether the ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - + smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -515,10 +515,15 @@ smdistributed.modelparallel.torch Context Managers and Util Functions * Any model that causes out-of-memory problems with FP32 initialization is recommended to be created with :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. - * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled with the ``smp`` config. - :type dtype: torch.dtype + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled + with the ``smp`` config. For more inforamtion about FP16 training + in SageMaker with the model parallel library, see `FP16 Training + `_ + in the *Amazon SageMaker Developer Guide*. + + :type dtype: ``torch.dtype`` :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models. - :type dtype: boolean + :type distribute_embedding: boolean :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. This is not used if ``tensor_parallelism`` is ``False``. :type tensor_parallel_config: dict @@ -639,12 +644,13 @@ smdistributed.modelparallel.torch Context Managers and Util Functions .. _pytorch_saving_loading: -APIs for Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^ +smdistributed.modelparallel.torch APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smdistributed.modelparallel.torch.save( ) +.. function:: smdistributed.modelparallel.torch.save(obj, f, partial=True, pickel_module=picklemodule, pickle_protocol=2, ) - Saves an object. This operation is similar to ``torch.save()``, except + Saves an object. This operation is similar to `torch.save() + `_, except that it has an additional keyword argument, ``partial``, and accepts only string type for the argument ``f`` (file). If ``partial=True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` @@ -663,10 +669,8 @@ APIs for Saving and Loading A module used for pickling metadata and objects. - ``pickle_protocol``  (int, default=2): Can be specified to override the defaultprotocol. - - ``v3`` (bool, default=``False``): When set to ``True``, save optimizer state checkpoints - in V3 file format to add all ``pp_rank``, ``tp_rank``, and ``rdp_rank`` as postfix. -.. function:: smdistributed.modelparallel.torch.load( ) +.. function:: smdistributed.modelparallel.torch.load(f, map_location, pickle_module, pickle_load_args, partial=True) Loads an object saved with ``smdistributed.modelparallel.torch.save()`` from a file. @@ -690,10 +694,83 @@ APIs for Saving and Loading ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``. Should be used when loading a model trained with the library. +.. function:: smdistributed.modelparallel.torch.save_checkpoint(path, tag, partial=True, model=None, optimizer=None, user_content=None, translate_if_full=True, num_kept_partial_checkpoints=None) + + Saves a checkpoint. While :class:`smdistributed.modelparallel.torch.save` saves + model and optimizer objects, + this function checkpoints model and optimizer and saves the checkpoints as separate files. + It creates checkpoint folders in the following structure. + + .. code:: text + + - path + - ${tag}_partial (folder for partial checkpoint) + - model_rankinfo.pt + - optimizer_rankinfo.pt + - fp16_states_rankinfo.pt + - user_content.pt + - $tag (checkpoint file for full checkpoint) + - user_content_$tag (user_content file for full checkpoint) + - newest (a file that indicates the newest checkpoint) + + **Parameters** + + * ``path`` (str) (required): Path to save the checkpoint. The library creates + the directory if it does not already exist. + For example, ``/opt/ml/checkpoint/model_parallel``. + * ``tag`` (str) (required): A tag for the current checkpoint, usually the train + steps. Note: tag needs to be the same across all ranks (GPU workers). + When ``partial=False`` this will be the checkpoint file name. + * ``partial`` (boolean) (default: True): Whether to save the partial checkpoint. + * ``model`` (:class:`smdistributed.modelparallel.torch.DistributedModel`) + (default: None): The model to save. It needs to an ``smp.DistributedModel`` object. + * ``optimizer`` (:class:`smdistributed.modelparallel.torch.DistributedOptimizer`) + (default: None): The optimizer to save. It needs to be an ``smp.DistributedOptimizer`` object. + * ``user_content`` (any) (default: None): User-defined content to save. + * ``translate_if_full`` (boolean) (default: True): Whether to translate the + full ``state_dict`` to HF ``state_dict`` if possible. + * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number + of partial checkpoints to keep on disk. + +.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None) + + While :class:`smdistributed.modelparallel.torch.load` loads saved + model and optimizer objects, this function resumes from a saved checkpoint file. + + **Parameters** + + * ``path`` (str) (required): Path to load the checkpoint. + * ``tag`` (str) (default: None): Tag of the checkpoint to resume. If not provided, + the library tries to locate the newest checkpoint from the saved newest file. + * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint. + * ``strict`` (boolean) (default: True): Load with strict load, no extra key or + missing key is allowed. + * ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``. + * ``translate_function`` (function) (default: None): function to translate the full + checkpoint into smdistributed.modelparallel format. + For supported models, this is not required. + + **Example usage** + + .. code:: python + + # Save + smp.save_checkpoint( + checkpoint_dir, + tag=f"total_steps{total_steps}", + partial=True, + model=model, + optimizer=optimizer, + user_content=user_content + num_kept_partial_checkpoints=args.num_kept_checkpoints) + + # Load: this will automatically load the newest checkpoint + user_content = smp.resume_from_checkpoint(path, partial=partial) + .. _pytorch_saving_loading_instructions: -General Instruction For Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +General instruction on saving and loading +----------------------------------------- The library can save partial or full checkpoints. From 5d3eb7a822c77b04ed9cac593d784d4885291391 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 14 Jul 2022 10:39:16 -0700 Subject: [PATCH 21/23] pass doc8 --- .../smd_model_parallel_pytorch_tensor_parallel.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index c101e0025d..de7d20aaa2 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -238,11 +238,11 @@ Tensor Parallelism Module APIs number is given by ``num_layers`` argument. For the other arguments and methods, refer to :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`. - - - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, - layer normalization is applied to both the input and the output of - the ``DistributedTransformer``, in addition to the intermediate - attention and transformer-output layers. + + If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. .. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) From 060f68199ceeb1fbec6fd28380c78ac390e93038 Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 14 Jul 2022 22:35:09 -0700 Subject: [PATCH 22/23] Trigger Build From d981ec3d0400efe318fcc9ca7f7ca1b45d5f034c Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Fri, 15 Jul 2022 12:40:52 -0700 Subject: [PATCH 23/23] remove dist word embedding option --- .../smp_versions/latest/smd_model_parallel_pytorch.rst | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index cb1c37e400..f6d1db6f21 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -498,7 +498,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer smdistributed.modelparallel.torch Context Managers and Util Functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, distribute_embedding=False, **tensor_parallel_config) +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) Context manager to create a ``torch`` model. This API combines both the :class:`smdistributed.modelparallel.torch.tensor_parallelism` and @@ -522,8 +522,6 @@ smdistributed.modelparallel.torch Context Managers and Util Functions in the *Amazon SageMaker Developer Guide*. :type dtype: ``torch.dtype`` - :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models. - :type distribute_embedding: boolean :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. This is not used if ``tensor_parallelism`` is ``False``. :type tensor_parallel_config: dict @@ -536,8 +534,7 @@ smdistributed.modelparallel.torch Context Managers and Util Functions with smp.model_creation( tensor_parallelism=smp.tp_size() > 1, - dtype=torch.float16 if args.fp16 else torch.get_default_dtype(), - distribute_embedding=False + dtype=torch.float16 if args.fp16 else torch.get_default_dtype() ): model = MyModel(...)