Skip to content

Commit dfdec6a

Browse files
feature: Add support for torch_distributed distribution for Trainium training instances
1 parent 907f4ff commit dfdec6a

File tree

13 files changed

+688
-22
lines changed

13 files changed

+688
-22
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

+126
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,132 @@ When model repacking is needed, :class:`sagemaker.workflow.model_step.ModelStep`
954954
955955
:class:`sagemaker.workflow.model_step.ModelStep` uses the provided inputs to automatically detect if a repack is needed. If a repack is needed, :class:`sagemaker.workflow.steps.TrainingStep` is added to the step collection for that repack. Then, either :class:`sagemaker.workflow.steps.CreateModelStep` or :class:`sagemaker.workflow.step_collections.RegisterModelStep` will be chained after it.
956956
957+
MonitorBatchTransform Step
958+
===========================
959+
960+
MonitorBatchTransformStep is a new step type that allows customers to use SageMaker Model Monitor with batch transform jobs that are a part of their pipeline. Using this step, customers can set up the following monitors for their batch transform job: data quality, model quality, model bias, and feature attribution.
961+
962+
963+
When configuring this step, customers have the flexibility to run the monitoring job before or after the transform job executes. There is an additional flag called :code:`fail_on_violation` which will fail the step if set to true and there is a monitoring violation, or will continue to execute the step if set to false.
964+
965+
Here is an example showing you how to configure a :class:`sagemaker.workflow.monitor_batch_transform_step.MonitorBatchTransformStep` with a Data Quality monitor.
966+
967+
.. code-block:: python
968+
969+
from sagemaker.workflow.pipeline_context import PipelineSession
970+
971+
from sagemaker.transformer import Transformer
972+
from sagemaker.model_monitor import DefaultModelMonitor
973+
from sagemaker.model_monitor.dataset_format import DatasetFormat
974+
from sagemaker.workflow.check_job_config import CheckJobConfig
975+
from sagemaker.workflow.quality_check_step import DataQualityCheckConfig
976+
977+
from sagemaker.workflow.parameters import ParameterString
978+
979+
pipeline_session = PipelineSession()
980+
981+
transform_input_param = ParameterString(
982+
name="transform_input",
983+
default_value=f"s3://my-bucket/my-prefix/my-transform-input",
984+
)
985+
986+
# the resource configuration for the monitoring job
987+
job_config = CheckJobConfig(
988+
role=role,
989+
instance_count=1,
990+
instance_type="ml.m5.xlarge",
991+
...
992+
)
993+
994+
The following code sample demonstrates how to set up an on-demand batch transform *data quality* monitor:
995+
996+
.. code-block:: python
997+
998+
# configure your transformer
999+
transformer = Transformer(..., sagemaker_session=pipeline_session)
1000+
transform_arg = transformer.transform(
1001+
transform_input_param,
1002+
content_type="text/csv",
1003+
split_type="Line",
1004+
...
1005+
)
1006+
1007+
data_quality_config = DataQualityCheckConfig(
1008+
baseline_dataset=transform_input_param,
1009+
dataset_format=DatasetFormat.csv(header=False),
1010+
output_s3_uri="s3://my-report-path",
1011+
)
1012+
1013+
from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
1014+
1015+
transform_and_monitor_step = MonitorBatchTransformStep(
1016+
name="MyMonitorBatchTransformStep",
1017+
transform_step_args=transform_arg,
1018+
monitor_configuration=data_quality_config,
1019+
check_job_configuration=job_config,
1020+
# since data quality only looks at the inputs,
1021+
# so there is no need to wait for the transform output.
1022+
monitor_before_transform=True,
1023+
# if violation is detected in the monitoring, and you want to skip it
1024+
# and continue running batch transform, you can set fail_on_violation
1025+
# to false.
1026+
fail_on_violation=False,
1027+
supplied_baseline_statistics="s3://my-baseline-statistics.json",
1028+
supplied_baseline_constraints="s3://my-baseline-constraints.json",
1029+
)
1030+
...
1031+
1032+
The same example can be extended for model quality, bias, and feature attribute monitoring.
1033+
1034+
.. warning::
1035+
Note that to run on-demand model quality, you will need to have the ground truth data ready. When running the transform job, include the ground truth inside your transform input, and join the transform inference input and output. Then you can indicate which attribute or column name/index points to the ground truth when run the monitoring job.
1036+
1037+
.. code-block:: python
1038+
1039+
transformer = Transformer(..., sagemaker_session=pipeline_session)
1040+
1041+
transform_arg = transformer.transform(
1042+
transform_input_param,
1043+
content_type="text/csv",
1044+
split_type="Line",
1045+
# Note that we need to join both the inference input and output
1046+
# into transform outputs. The inference input needs to have the ground truth.
1047+
# details can be found here
1048+
# https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html
1049+
join_source="Input",
1050+
# We need to exclude the ground truth inside the inference input
1051+
# before passing it to the prediction model.
1052+
# Assume the first column of our csv file is the ground truth
1053+
input_filter="$[1:]",
1054+
...
1055+
)
1056+
1057+
model_quality_config = ModelQualityCheckConfig(
1058+
baseline_dataset=transformer.output_path,
1059+
problem_type="BinaryClassification",
1060+
dataset_format=DatasetFormat.csv(header=False),
1061+
output_s3_uri="s3://my-output",
1062+
# assume the model output is at column idx 10
1063+
inference_attribute="_c10",
1064+
# As pointed out previously, the first column is the ground truth.
1065+
ground_truth_attribute="_c0",
1066+
)
1067+
from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
1068+
1069+
transform_and_monitor_step = MonitorBatchTransformStep(
1070+
name="MyMonitorBatchTransformStep",
1071+
transform_step_args=transform_arg,
1072+
monitor_configuration=data_quality_config,
1073+
check_job_configuration=job_config,
1074+
# model quality job needs the transform outputs, therefore
1075+
# monitor_before_transform can not be true for model quality
1076+
monitor_before_transform=False,
1077+
fail_on_violation=True,
1078+
supplied_baseline_statistics="s3://my-baseline-statistics.json",
1079+
supplied_baseline_constraints="s3://my-baseline-constraints.json",
1080+
)
1081+
...
1082+
9571083
=================
9581084
Example Notebooks
9591085
=================

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

+11-2
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
729729
* ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
730730
of partial checkpoints to keep on disk.
731731

732-
.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None)
732+
.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state=True, translate_function=None)
733733

734734
While :class:`smdistributed.modelparallel.torch.load` loads saved
735735
model and optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
742742
* ``partial`` (boolean) (default: True): Whether to load the partial checkpoint.
743743
* ``strict`` (boolean) (default: True): Load with strict load, no extra key or
744744
missing key is allowed.
745-
* ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``.
745+
* ``load_optimizer`` (boolean) (default: True): Whether to load ``optimizer``.
746+
* ``load_sharded_optimizer_state`` (boolean) (default: True): Whether to load
747+
the sharded optimizer state of a model.
748+
It can be used only when you activate
749+
the `sharded data parallelism
750+
<https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html>`_
751+
feature of the SageMaker model parallel library.
752+
When this is ``False``, the library only loads the FP16
753+
states, such as FP32 master parameters and the loss scaling factor,
754+
not the sharded optimizer states.
746755
* ``translate_function`` (function) (default: None): function to translate the full
747756
checkpoint into smdistributed.modelparallel format.
748757
For supported models, this is not required.

doc/frameworks/pytorch/using_pytorch.rst

+101-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ during the PyTorch DDP initialization.
262262

263263
.. note::
264264

265-
The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
265+
The SageMaker PyTorch estimator operates ``mpirun`` in the backend.
266+
It doesn’t use ``torchrun`` for distributed training.
266267

267268
For more information about setting up PyTorch DDP in your training script,
268269
see `Getting Started with Distributed Data Parallel
@@ -292,7 +293,106 @@ using two ``ml.p4d.24xlarge`` instances:
292293
293294
pt_estimator.fit("s3://bucket/path/to/training/data")
294295
296+
.. _distributed-pytorch-training-on-trainium:
295297

298+
Distributed PyTorch Training on Trainium
299+
========================================
300+
301+
SageMaker Training on Trainium instances now supports the ``xla``
302+
package through ``torchrun``. With this, you do not need to manually pass RANK,
303+
WORLD_SIZE, MASTER_ADDR, and MASTER_PORT. You can launch the training job using the
304+
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
305+
with the ``torch_distributed`` option as the distribution strategy.
306+
307+
.. note::
308+
309+
This ``torch_distributed`` support is available
310+
in the SageMaker Trainium (trn1) PyTorch Deep Learning Containers starting v1.11.0.
311+
To find a complete list of supported versions of PyTorch Neuron, see `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_ in the *AWS Deep Learning Containers GitHub repository*.
312+
313+
SageMaker Debugger and Profiler are currently not supported with Trainium instances.
314+
315+
Adapt Your Training Script to Initialize with the XLA backend
316+
-------------------------------------------------------------
317+
318+
To initialize distributed training in your script, call
319+
`torch.distributed.init_process_group
320+
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
321+
with the ``xla`` backend as shown below.
322+
323+
.. code:: python
324+
325+
import torch.distributed as dist
326+
327+
dist.init_process_group('xla')
328+
329+
SageMaker takes care of ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` for you via ``torchrun``
330+
331+
For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
332+
333+
**Currently Supported backends:**
334+
335+
- ``xla`` for Trainium (Trn1) instances
336+
337+
For up-to-date information on supported backends for Trainium instances, see `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.
338+
339+
Launching a Distributed Training Job on Trainium
340+
------------------------------------------------
341+
342+
You can run multi-node distributed PyTorch training jobs on Trainium instances using the
343+
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
344+
With ``instance_count=1``, the estimator submits a
345+
single-node training job to SageMaker; with ``instance_count`` greater
346+
than one, a multi-node training job is launched.
347+
348+
With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
349+
training container for PyTorch Neuron, sets up the environment, and launches
350+
the training job using the ``torchrun`` command on each worker with the given information.
351+
352+
**Examples**
353+
354+
The following examples show how to run a PyTorch training using ``torch_distributed`` in SageMaker
355+
on one ``ml.trn1.2xlarge`` instance and two ``ml.trn1.32xlarge`` instances:
356+
357+
.. code:: python
358+
359+
from sagemaker.pytorch import PyTorch
360+
361+
pt_estimator = PyTorch(
362+
entry_point="train_ptddp.py",
363+
role="SageMakerRole",
364+
framework_version="1.11.0",
365+
py_version="py38",
366+
instance_count=1,
367+
instance_type="ml.trn1.2xlarge",
368+
distribution={
369+
"torch_distributed": {
370+
"enabled": True
371+
}
372+
}
373+
)
374+
375+
pt_estimator.fit("s3://bucket/path/to/training/data")
376+
377+
.. code:: python
378+
379+
from sagemaker.pytorch import PyTorch
380+
381+
pt_estimator = PyTorch(
382+
entry_point="train_ptddp.py",
383+
role="SageMakerRole",
384+
framework_version="1.11.0",
385+
py_version="py38",
386+
instance_count=2,
387+
instance_type="ml.trn1.32xlarge",
388+
distribution={
389+
"torch_distributed": {
390+
"enabled": True
391+
}
392+
}
393+
)
394+
395+
pt_estimator.fit("s3://bucket/path/to/training/data")
296396
297397
*********************
298398
Deploy PyTorch Models

doc/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ sphinx-rtd-theme==0.5.0
33
docutils==0.15.2
44
packaging==20.9
55
jinja2<3.1
6+
schema==0.7.5

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ Step Collections
132132

133133
.. autoclass:: sagemaker.workflow.model_step.ModelStep
134134

135+
.. autoclass:: sagemaker.workflow.monitor_batch_transform_step.MonitorBatchTransformStep
136+
135137
Steps
136138
-----
137139

requirements/extras/test_requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ contextlib2==21.6.0
1111
awslogs==0.14.0
1212
black==22.3.0
1313
stopit==1.1.2
14-
apache-airflow==2.4.0
14+
apache-airflow==2.4.1
1515
apache-airflow-providers-amazon==4.0.0
1616
attrs==22.1.0
1717
fabric==2.6.0

0 commit comments

Comments
 (0)