
Commit c56e1e1

knikuremchoi8739 authored and committed
feature: Trainium Neuron support for PyTorch (aws#3423)
Co-authored-by: Miyoung Choi <[email protected]>
1 parent 995da32 commit c56e1e1

File tree

6 files changed: +284 -2 lines changed


doc/frameworks/pytorch/using_pytorch.rst

+115
@@ -293,6 +293,121 @@ using two ``ml.p4d.24xlarge`` instances:
    pt_estimator.fit("s3://bucket/path/to/training/data")

.. _distributed-pytorch-training-on-trainium:

Distributed Training with PyTorch Neuron on Trn1 instances
==========================================================

SageMaker Training supports Amazon EC2 Trn1 instances powered by
`AWS Trainium <https://aws.amazon.com/machine-learning/trainium/>`_ devices,
the second-generation purpose-built machine learning accelerator from AWS.
Each Trn1 instance consists of up to 16 Trainium devices, and each
Trainium device consists of two `NeuronCores
<https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#trainium-architecture>`_,
as described in the *AWS Neuron Documentation*.

You can run a distributed training job on Trn1 instances.
SageMaker supports the ``xla`` package through ``torchrun``.
With this, you do not need to manually pass ``RANK``,
``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``.
You can launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``torch_distributed`` option as the distribution strategy.

.. note::

    This ``torch_distributed`` support is available
    in the AWS Deep Learning Containers for PyTorch Neuron starting v1.11.0.
    To find a complete list of supported versions of PyTorch Neuron, see
    `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_
    in the *AWS Deep Learning Containers GitHub repository*.

.. note::

    SageMaker Debugger is currently not supported with Trn1 instances.

Adapt Your Training Script to Initialize with the XLA backend
--------------------------------------------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the ``xla`` backend as shown below.

.. code:: python

    import torch.distributed as dist

    dist.init_process_group('xla')

SageMaker takes care of ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` for you via ``torchrun``.

For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.

**Currently supported backends:**

- ``xla`` for Trainium (Trn1) instances

For up-to-date information on supported backends for Trn1 instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.

Launching a Distributed Training Job on Trainium
------------------------------------------------

You can run multi-node distributed PyTorch training jobs on Trn1 instances using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch Neuron, sets up the environment, and launches
the training job using the ``torchrun`` command on each worker with the given information.

**Examples**

The following examples show how to run a PyTorch training job using ``torch_distributed`` in SageMaker
on one ``ml.trn1.2xlarge`` instance and on two ``ml.trn1.32xlarge`` instances:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_torch_distributed.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_torch_distributed.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

*********************
Deploy PyTorch Models
*********************
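
The examples above pass ``entry_point="train_torch_distributed.py"``, a script that is not part of this change. A minimal sketch of what such an entry point could look like, assuming the ``torch_xla`` package that ships in the PyTorch Neuron training container (the model and training loop are placeholders):

.. code:: python

    import torch
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm


    def main():
        # torchrun (launched via the torch_distributed option) sets RANK, WORLD_SIZE,
        # MASTER_ADDR, and MASTER_PORT, so init_process_group needs no extra arguments.
        dist.init_process_group("xla")

        # Place a toy model on the XLA (Trainium) device and run a few steps.
        device = xm.xla_device()
        model = torch.nn.Linear(10, 1).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        for _ in range(10):
            optimizer.zero_grad()
            inputs = torch.randn(8, 10).to(device)
            loss = model(inputs).sum()
            loss.backward()
            # xm.optimizer_step performs the gradient synchronization XLA devices require.
            xm.optimizer_step(optimizer)


    if __name__ == "__main__":
        main()
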
@@ -0,0 +1,41 @@
{
    "training": {
        "processors": ["trn"],
        "version_aliases": {"1.11": "1.11.0"},
        "versions": {
            "1.11.0": {
                "py_versions": ["py38"],
                "repository": "pytorch-training-neuron",
                "registries": {
                    "af-south-1": "626614931356",
                    "ap-east-1": "871362719292",
                    "ap-northeast-1": "763104351884",
                    "ap-northeast-2": "763104351884",
                    "ap-northeast-3": "364406365360",
                    "ap-south-1": "763104351884",
                    "ap-southeast-1": "763104351884",
                    "ap-southeast-2": "763104351884",
                    "ca-central-1": "763104351884",
                    "cn-north-1": "727897471807",
                    "cn-northwest-1": "727897471807",
                    "eu-central-1": "763104351884",
                    "eu-north-1": "763104351884",
                    "eu-west-1": "763104351884",
                    "eu-west-2": "763104351884",
                    "eu-west-3": "763104351884",
                    "eu-south-1": "692866216735",
                    "me-south-1": "217643126080",
                    "sa-east-1": "763104351884",
                    "us-east-1": "763104351884",
                    "us-east-2": "763104351884",
                    "us-gov-west-1": "442386744353",
                    "us-iso-east-1": "886529160074",
                    "us-west-1": "763104351884",
                    "us-west-2": "763104351884"
                },
                "container_version": {"trn": "ubuntu20.04"},
                "sdk_versions": ["sdk2.4.0"]
            }
        }
    }
}

src/sagemaker/image_uris.py

+31 -2
@@ -33,6 +33,7 @@
 HUGGING_FACE_FRAMEWORK = "huggingface"
 XGBOOST_FRAMEWORK = "xgboost"
 SKLEARN_FRAMEWORK = "sklearn"
+TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"


 @override_pipeline_parameter_var

@@ -150,11 +151,12 @@ def retrieve(
         )
     else:
         _framework = framework
-        if framework == HUGGING_FACE_FRAMEWORK:
+        if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
             inference_tool = _get_inference_tool(inference_tool, instance_type)
             if inference_tool == "neuron":
                 _framework = f"{framework}-{inference_tool}"
         final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
+        _validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
         config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)

     original_version = version

@@ -186,6 +188,12 @@ def retrieve(
     if version_config.get("container_version"):
         container_version = version_config["container_version"][processor]

+    # Append sdk version in case of trainium instances
+    if repo in ["pytorch-training-neuron"]:
+        if not sdk_version:
+            sdk_version = _get_latest_versions(version_config["sdk_versions"])
+        container_version = sdk_version + "-" + container_version
+
     if framework == HUGGING_FACE_FRAMEWORK:
         pt_or_tf_version = (
             re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)

@@ -344,6 +352,16 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):
     return config if "scope" in config else config[image_scope]


+def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
+    """Validate if framework is supported for the instance_type"""
+    if (
+        instace_type is not None
+        and "trn" in instace_type
+        and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
+    ):
+        _validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
+
+
 def config_for_framework(framework):
     """Loads the JSON config for the given framework."""
     fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))

@@ -371,7 +389,7 @@ def _get_inference_tool(inference_tool, instance_type):
     """Extract the inference tool name from instance type."""
     if not inference_tool:
         instance_type_family = _get_instance_type_family(instance_type)
-        if instance_type_family.startswith("inf"):
+        if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
             return "neuron"
     return inference_tool

@@ -460,6 +478,8 @@ def _processor(instance_type, available_processors, serverless_inference_config=None):
             processor = family
         elif family.startswith("inf"):
             processor = "inf"
+        elif family.startswith("trn"):
+            processor = "trn"
         elif family[0] in ("g", "p"):
             processor = "gpu"
         else:

@@ -523,6 +543,15 @@ def _validate_arg(arg, available_options, arg_name):
         )


+def _validate_framework(framework, allowed_frameworks, arg_name):
+    """Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
+    if framework not in allowed_frameworks:
+        raise ValueError(
+            f"Unsupported {arg_name}: {framework}. "
+            f"Supported {arg_name}(s) for trainium instances: {allowed_frameworks}."
+        )
+
+
 def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
     """Creates a tag for the image URI."""
     if inference_tool:
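
Together with the new ``pytorch-training-neuron`` config, these changes let ``image_uris.retrieve`` resolve a Trainium training image directly from an ``ml.trn1.*`` instance type. A sketch of the expected call; the URI shown is inferred from the registry accounts above and the tag format asserted by the new unit test, not captured from a run:

.. code:: python

    from sagemaker import image_uris

    # For trn1 instance types the processor resolves to "trn", the inference tool
    # to "neuron", and the latest SDK version from the config is prepended to the
    # container version when the image tag is built.
    uri = image_uris.retrieve(
        framework="pytorch",
        region="us-west-2",
        version="1.11",
        py_version="py38",
        instance_type="ml.trn1.2xlarge",
        image_scope="training",
    )
    print(uri)
    # 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04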

tests/conftest.py

+5
@@ -358,6 +358,11 @@ def huggingface_neuron_latest_inference_py_version():
    return "py37"


@pytest.fixture(scope="module")
def pytorch_neuron_version():
    return "1.11"


@pytest.fixture(scope="module")
def pytorch_eia_py_version():
    return "py3"

tests/unit/sagemaker/image_uris/expected_uris.py

+18
@@ -30,6 +30,24 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
    return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)


def neuron_framework_uri(
    repo,
    fw_version,
    account,
    py_version=None,
    inference_tool="neuron",
    region=REGION,
    sdk_version="sdk2.4.0",
    container_version="ubuntu20.04",
):
    domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
    tag = "-".join(
        x for x in (fw_version, inference_tool, py_version, sdk_version, container_version) if x
    )

    return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)


def algo_uri(algo, account, region, version=1):
    domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
    return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
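
A quick usage sketch of the new helper, using its defaults for ``sdk_version`` and ``container_version``; the tag it composes is what the Trainium test below compares against:

.. code:: python

    from tests.unit.sagemaker.image_uris import expected_uris

    uri = expected_uris.neuron_framework_uri(
        repo="pytorch-training-neuron",
        fw_version="1.11.0",
        account="763104351884",
        py_version="py38",
        region="us-west-2",
    )
    # Joins the non-empty parts into the tag
    # "1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04" on the pytorch-training-neuron repository.
    print(uri)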
@@ -0,0 +1,74 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

ACCOUNTS = {
    "af-south-1": "626614931356",
    "ap-east-1": "871362719292",
    "ap-northeast-1": "763104351884",
    "ap-northeast-2": "763104351884",
    "ap-northeast-3": "364406365360",
    "ap-south-1": "763104351884",
    "ap-southeast-1": "763104351884",
    "ap-southeast-2": "763104351884",
    "ca-central-1": "763104351884",
    "cn-north-1": "727897471807",
    "cn-northwest-1": "727897471807",
    "eu-central-1": "763104351884",
    "eu-north-1": "763104351884",
    "eu-west-1": "763104351884",
    "eu-west-2": "763104351884",
    "eu-west-3": "763104351884",
    "eu-south-1": "692866216735",
    "me-south-1": "217643126080",
    "sa-east-1": "763104351884",
    "us-east-1": "763104351884",
    "us-east-2": "763104351884",
    "us-gov-west-1": "442386744353",
    "us-iso-east-1": "886529160074",
    "us-west-1": "763104351884",
    "us-west-2": "763104351884",
}

TRAINIUM_REGIONS = ACCOUNTS.keys()


def _expected_trainium_framework_uri(
    framework, version, region="us-west-2", inference_tool="neuron"
):
    return expected_uris.neuron_framework_uri(
        "{}-neuron".format(framework),
        fw_version=version,
        py_version="py38",
        account=ACCOUNTS[region],
        region=region,
        inference_tool=inference_tool,
    )


def _test_trainium_framework_uris(framework, version):
    for region in TRAINIUM_REGIONS:
        uri = image_uris.retrieve(
            framework, region, instance_type="ml.trn1.xlarge", version=version
        )
        expected = _expected_trainium_framework_uri(
            "{}-training".format(framework), version, region=region, inference_tool="neuron"
        )
        assert expected == uri


def test_trainium_pytorch(pytorch_neuron_version):
    _test_trainium_framework_uris("pytorch", pytorch_neuron_version)

0 commit comments
