@@ -293,6 +293,121 @@ using two ``ml.p4d.24xlarge`` instances:
pt_estimator.fit("s3://bucket/path/to/training/data")
+ .. _distributed-pytorch-training-on-trainium:
+
+ Distributed Training with PyTorch Neuron on Trn1 instances
+ ==========================================================
+
+ SageMaker Training supports Amazon EC2 Trn1 instances powered by
+ `AWS Trainium <https://aws.amazon.com/machine-learning/trainium/>`_,
+ the second-generation, purpose-built machine learning accelerator from AWS.
+ Each Trn1 instance consists of up to 16 Trainium devices, and each
+ Trainium device consists of two `NeuronCores
+ <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#trainium-architecture>`_,
+ as described in the *AWS Neuron Documentation*.
+
+ You can run distributed training jobs on Trn1 instances.
+ SageMaker supports the ``xla`` package through ``torchrun``.
+ With this, you do not need to manually pass ``RANK``,
+ ``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``.
+ You can launch the training job using the
+ :class:`sagemaker.pytorch.estimator.PyTorch` estimator class
+ with the ``torch_distributed`` option as the distribution strategy.
+
+ .. note::
+
+   This ``torch_distributed`` support is available
+   in the AWS Deep Learning Containers for PyTorch Neuron starting with v1.11.0.
+   To find a complete list of supported versions of PyTorch Neuron, see
+   `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_
+   in the *AWS Deep Learning Containers GitHub repository*.
+
+ .. note::
+
+   SageMaker Debugger is currently not supported with Trn1 instances.
+
+ Adapt Your Training Script to Initialize with the XLA backend
+ -------------------------------------------------------------
+
+ To initialize distributed training in your script, call
+ `torch.distributed.init_process_group
+ <https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
+ with the ``xla`` backend as shown below.
+
+ .. code:: python
+
+     import torch.distributed as dist
+
+     dist.init_process_group('xla')
+
+ SageMaker takes care of ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` for you via ``torchrun``.
+
+ For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
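+
+ To see how those pieces fit together, the following is a minimal, hypothetical
+ sketch of a data-parallel training step on Trainium. It assumes the ``torch_xla``
+ package that ships in the PyTorch Neuron containers is available, and the model,
+ data, and hyperparameters are placeholders; see the Neuron tutorial linked above
+ for the official walkthrough.
+
+ .. code:: python
+
+     import torch
+     import torch.distributed as dist
+     import torch_xla.core.xla_model as xm  # XLA utilities shipped with PyTorch Neuron
+
+     # torchrun (launched by SageMaker) provides rank and world size through
+     # environment variables, so no arguments are needed here.
+     dist.init_process_group('xla')
+
+     device = xm.xla_device()                    # NeuronCore assigned to this worker
+     model = torch.nn.Linear(32, 2).to(device)   # placeholder model
+     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+     loss_fn = torch.nn.CrossEntropyLoss()
+
+     for step in range(10):
+         # Random placeholder batch; a real script would use a DataLoader
+         # with a DistributedSampler.
+         inputs = torch.randn(64, 32).to(device)
+         labels = torch.randint(0, 2, (64,)).to(device)
+
+         optimizer.zero_grad()
+         loss = loss_fn(model(inputs), labels)
+         loss.backward()
+         xm.optimizer_step(optimizer)  # all-reduce gradients, then apply the update
+         xm.mark_step()                # trigger execution of the accumulated XLA graph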
+
+ **Currently supported backends:**
+
+ - ``xla`` for Trainium (Trn1) instances
+
+ For up-to-date information on supported backends for Trn1 instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.
+
+ Launching a Distributed Training Job on Trainium
+ ------------------------------------------------
+
+ You can run multi-node distributed PyTorch training jobs on Trn1 instances using the
+ :class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
+ With ``instance_count=1``, the estimator submits a
+ single-node training job to SageMaker; with ``instance_count`` greater
+ than one, a multi-node training job is launched.
+
+ With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
+ training container for PyTorch Neuron, sets up the environment, and launches
+ the training job using the ``torchrun`` command on each worker with the given information.
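+
+ As a rough illustration of what that per-worker environment contains, the hypothetical
+ snippet below reads the variables that ``torchrun`` exports for every process. A typical
+ script does not need to read them directly, because ``init_process_group('xla')`` picks
+ them up automatically; they are shown here only for debugging or logging.
+
+ .. code:: python
+
+     import os
+     import torch.distributed as dist
+
+     rank = int(os.environ["RANK"])              # global rank of this process
+     local_rank = int(os.environ["LOCAL_RANK"])  # rank within the current instance
+     world_size = int(os.environ["WORLD_SIZE"])  # total number of processes in the job
+
+     dist.init_process_group('xla')
+     print(f"worker {rank}/{world_size} (local rank {local_rank}) initialized")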
+
+ **Examples**
+
+ The following examples show how to run a PyTorch training job with ``torch_distributed`` in SageMaker
+ on one ``ml.trn1.2xlarge`` instance and on two ``ml.trn1.32xlarge`` instances:
+
+ .. code:: python
+
+     from sagemaker.pytorch import PyTorch
+
+     pt_estimator = PyTorch(
+         entry_point="train_torch_distributed.py",
+         role="SageMakerRole",
+         framework_version="1.11.0",
+         py_version="py38",
+         instance_count=1,
+         instance_type="ml.trn1.2xlarge",
+         distribution={
+             "torch_distributed": {
+                 "enabled": True
+             }
+         }
+     )
+
+     pt_estimator.fit("s3://bucket/path/to/training/data")
+
+ .. code:: python
+
+     from sagemaker.pytorch import PyTorch
+
+     pt_estimator = PyTorch(
+         entry_point="train_torch_distributed.py",
+         role="SageMakerRole",
+         framework_version="1.11.0",
+         py_version="py38",
+         instance_count=2,
+         instance_type="ml.trn1.32xlarge",
+         distribution={
+             "torch_distributed": {
+                 "enabled": True
+             }
+         }
+     )
+
+     pt_estimator.fit("s3://bucket/path/to/training/data")
+
*********************
Deploy PyTorch Models
*********************