@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
729
729
* ``num_kept_partial_checkpoints `` (int) (default: None): The maximum number
730
730
of partial checkpoints to keep on disk.
731
731
732
- .. function :: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states =True, translate_function=None)
732
+ .. function :: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state =True, translate_function=None)
733
733
734
734
While :class: `smdistributed.modelparallel.torch.load ` loads saved
735
735
model and optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
742
742
* ``partial `` (boolean) (default: True): Whether to load the partial checkpoint.
743
743
* ``strict `` (boolean) (default: True): Load with strict load, no extra key or
744
744
missing key is allowed.
745
- * ``load_optimizer_states `` (boolean) (default: True): Whether to load ``optimizer_states ``.
745
+ * ``load_optimizer `` (boolean) (default: True): Whether to load ``optimizer ``.
746
+ * ``load_sharded_optimizer_state `` (boolean) (default: True): Whether to load
747
+ the sharded optimizer state of a model.
748
+ It can be used only when you activate
749
+ the `sharded data parallelism
750
+ <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html> `_
751
+ feature of the SageMaker model parallel library.
752
+ When this is ``False ``, the library only loads the FP16
753
+ states, such as FP32 master parameters and the loss scaling factor,
754
+ not the sharded optimizer states.
746
755
* ``translate_function `` (function) (default: None): function to translate the full
747
756
checkpoint into smdistributed.modelparallel format.
748
757
For supported models, this is not required.
0 commit comments