From bbb46b1ecdd00da03363dd893adc2de39f2cbbfe Mon Sep 17 00:00:00 2001
From: Miyoung Choi
Date: Wed, 12 Oct 2022 12:46:44 -0700
Subject: [PATCH 1/3] fix kwargs and descriptions

---
 .../latest/smd_model_parallel_pytorch.rst        | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
index f6d1db6f21..2c2db7c5e2 100644
--- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
+++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
    * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
      of partial checkpoints to keep on disk.
 
-.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None)
+.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_optimizer_states=True, translate_function=None)
 
    While :class:`smdistributed.modelparallel.torch.load` loads saved model and
    optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
    * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint.
    * ``strict`` (boolean) (default: True): Load with strict load, no extra key
      or missing key is allowed.
-   * ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``.
+   * ``load_optimizer`` (boolean) (default: True): Whether to load ``optimizer``.
+   * ``load_sharded_optimizer_state`` (boolean) (default: True): Whether to load
+     the sharded optimizer state of a model.
+     It can be used only when you activate
+     the `sharded data parallelism
+     `_
+     feature of the SageMaker model parallel library.
+     When this is ``False``, the library loads only the mixed-precision (FP16)
+     training states, such as the FP32 master parameters and the loss
+     scaling factor, not the sharded optimizer states.
    * ``translate_function`` (function) (default: None): function to translate
      the full checkpoint into smdistributed.modelparallel format. For
      supported models, this is not required.

From b3bdff0519cd3cb9e08116a8daa779d0ad37960b Mon Sep 17 00:00:00 2001
From: Miyoung Choi
Date: Wed, 12 Oct 2022 13:05:07 -0700
Subject: [PATCH 2/3] fix arg

---
 .../training/smp_versions/latest/smd_model_parallel_pytorch.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
index 2c2db7c5e2..a12f1108e4 100644
--- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
+++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
    * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
      of partial checkpoints to keep on disk.
 
-.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_optimizer_states=True, translate_function=None)
+.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state=True, translate_function=None)
 
    While :class:`smdistributed.modelparallel.torch.load` loads saved model and
    optimizer objects, this function resumes from a saved checkpoint file.
From e25d718c6f00653403febb4bfc78f858117bdf98 Mon Sep 17 00:00:00 2001
From: Miyoung Choi
Date: Wed, 12 Oct 2022 13:30:04 -0700
Subject: [PATCH 3/3] add schema to requirements.txt for sphinx

---
 doc/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doc/requirements.txt b/doc/requirements.txt
index 21c94775d5..f8490ee933 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -3,3 +3,4 @@ sphinx-rtd-theme==0.5.0
 docutils==0.15.2
 packaging==20.9
 jinja2<3.1
+schema
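For reviewers, a minimal usage sketch of the ``resume_from_checkpoint`` signature that patches 1 and 2 document. This is not part of the patch series; the checkpoint path is the conventional SageMaker checkpoint directory and the surrounding training-job setup is assumed, so the snippet guards on the library being importable and skips otherwise:

```python
# Sketch of resuming training with the documented signature:
#   resume_from_checkpoint(path, tag=None, partial=True, strict=True,
#                          load_optimizer=True, load_sharded_optimizer_state=True,
#                          translate_function=None)
# Only meaningful inside a SageMaker training job where
# smdistributed.modelparallel is installed; elsewhere it skips gracefully.
import importlib.util

if importlib.util.find_spec("smdistributed") is None:
    print("smdistributed.modelparallel not available; skipping sketch")
else:
    import smdistributed.modelparallel.torch as smp

    smp.init()
    smp.resume_from_checkpoint(
        "/opt/ml/checkpoints",             # conventional SageMaker checkpoint dir
        tag=None,                          # None -> resume from newest checkpoint
        partial=True,                      # load partial (per-rank) checkpoints
        load_optimizer=True,               # restore optimizer state
        load_sharded_optimizer_state=True, # only applies with sharded data parallelism
    )
```

Note that ``load_sharded_optimizer_state`` is the new keyword name that patch 2 settles on after patch 1 briefly left the old ``load_optimizer_states`` in the signature line.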