Skip to content

Commit 98599e4

Browse files
mchoi8739mufiAmazon
authored andcommitted
documentation: fix kwargs and descriptions of the smdmp checkpoint function (aws#3410)
1 parent 4ee30e2 commit 98599e4

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

+11-2
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
729729
* ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
730730
of partial checkpoints to keep on disk.
731731

732-
.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None)
732+
.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state=True, translate_function=None)
733733

734734
While :class:`smdistributed.modelparallel.torch.load` loads saved
735735
model and optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
742742
* ``partial`` (boolean) (default: True): Whether to load the partial checkpoint.
743743
* ``strict`` (boolean) (default: True): Load with strict load, no extra key or
744744
missing key is allowed.
745-
* ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``.
745+
* ``load_optimizer`` (boolean) (default: True): Whether to load ``optimizer``.
746+
* ``load_sharded_optimizer_state`` (boolean) (default: True): Whether to load
747+
the sharded optimizer state of a model.
748+
It can be used only when you activate
749+
the `sharded data parallelism
750+
<https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html>`_
751+
feature of the SageMaker model parallel library.
752+
When this is ``False``, the library only loads the FP16
753+
states, such as FP32 master parameters and the loss scaling factor,
754+
not the sharded optimizer states.
746755
* ``translate_function`` (function) (default: None): function to translate the full
747756
checkpoint into smdistributed.modelparallel format.
748757
For supported models, this is not required.

doc/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ sphinx-rtd-theme==0.5.0
33
docutils==0.15.2
44
packaging==20.9
55
jinja2<3.1
6+
schema

0 commit comments

Comments
 (0)