@@ -498,7 +498,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer

 smdistributed.modelparallel.torch Context Managers and Util Functions
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, distribute_embedding=False, **tensor_parallel_config)
+.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config)

    Context manager to create a ``torch`` model. This API combines both the
    :class:`smdistributed.modelparallel.torch.tensor_parallelism` and
@@ -522,8 +522,6 @@ smdistributed.modelparallel.torch Context Managers and Util Functions

    in the *Amazon SageMaker Developer Guide*.

    :type dtype: ``torch.dtype``
-   :param distribute_embedding: Whether to enable vocabulary parallelism for NLP models.
-   :type distribute_embedding: boolean
    :param tensor_parallel_config: kwargs to specify other tensor parallel configs.
       This is not used if ``tensor_parallelism`` is ``False``.
    :type tensor_parallel_config: dict
@@ -536,8 +534,7 @@ smdistributed.modelparallel.torch Context Managers and Util Functions

       with smp.model_creation(
           tensor_parallelism=smp.tp_size() > 1,
-          dtype=torch.float16 if args.fp16 else torch.get_default_dtype(),
-          distribute_embedding=False
+          dtype=torch.float16 if args.fp16 else torch.get_default_dtype()
       ):
           model = MyModel(...)
0 commit comments