Skip to content

Commit a53fbc1

Browse files
committed
Trainium Neuron support for PyTorch
1 parent 8dc17fb commit a53fbc1

File tree

5 files changed

+169
-2
lines changed

5 files changed

+169
-2
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"training": {
3+
"processors": ["trn"],
4+
"version_aliases": {"1.11": "1.11.0"},
5+
"versions": {
6+
"1.11.0": {
7+
"py_versions": ["py38"],
8+
"repository": "pytorch-training-neuron",
9+
"registries": {
10+
"af-south-1": "626614931356",
11+
"ap-east-1": "871362719292",
12+
"ap-northeast-1": "763104351884",
13+
"ap-northeast-2": "763104351884",
14+
"ap-northeast-3": "364406365360",
15+
"ap-south-1": "763104351884",
16+
"ap-southeast-1": "763104351884",
17+
"ap-southeast-2": "763104351884",
18+
"ca-central-1": "763104351884",
19+
"cn-north-1": "727897471807",
20+
"cn-northwest-1": "727897471807",
21+
"eu-central-1": "763104351884",
22+
"eu-north-1": "763104351884",
23+
"eu-west-1": "763104351884",
24+
"eu-west-2": "763104351884",
25+
"eu-west-3": "763104351884",
26+
"eu-south-1": "692866216735",
27+
"me-south-1": "217643126080",
28+
"sa-east-1": "763104351884",
29+
"us-east-1": "763104351884",
30+
"us-east-2": "763104351884",
31+
"us-gov-west-1": "442386744353",
32+
"us-iso-east-1": "886529160074",
33+
"us-west-1": "763104351884",
34+
"us-west-2": "763104351884"
35+
},
36+
"container_version": {"trn": "ubuntu20.04"},
37+
"sdk_versions": ["sdk2.3.0"]
38+
}
39+
}
40+
}
41+
}

src/sagemaker/image_uris.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
HUGGING_FACE_FRAMEWORK = "huggingface"
3434
XGBOOST_FRAMEWORK = "xgboost"
3535
SKLEARN_FRAMEWORK = "sklearn"
36+
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
3637

3738

3839
@override_pipeline_parameter_var
@@ -150,11 +151,12 @@ def retrieve(
150151
)
151152
else:
152153
_framework = framework
153-
if framework == HUGGING_FACE_FRAMEWORK:
154+
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
154155
inference_tool = _get_inference_tool(inference_tool, instance_type)
155156
if inference_tool == "neuron":
156157
_framework = f"{framework}-{inference_tool}"
157158
final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
159+
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
158160
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)
159161

160162
original_version = version
@@ -186,6 +188,12 @@ def retrieve(
186188
if version_config.get("container_version"):
187189
container_version = version_config["container_version"][processor]
188190

191+
# Append sdk version in case of trainium instances
192+
if repo in ["pytorch-training-neuron"]:
193+
if not sdk_version:
194+
sdk_version = _get_latest_versions(version_config["sdk_versions"])
195+
container_version = sdk_version + "-" + container_version
196+
189197
if framework == HUGGING_FACE_FRAMEWORK:
190198
pt_or_tf_version = (
191199
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
@@ -344,6 +352,16 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
344352
return config if "scope" in config else config[image_scope]
345353

346354

355+
def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
    """Validate if framework is supported for the instance_type.

    Only instance types whose name contains ``"trn"`` are checked. For those,
    ``framework`` is validated against ``TRAINIUM_ALLOWED_FRAMEWORKS`` by
    ``_validate_framework``, which raises ``ValueError`` on a mismatch.
    Any other instance type, including ``None``, passes without checks.

    Args:
        framework (str): Name of the requested framework.
        instace_type (str): Instance type the image is requested for, or ``None``.

    Raises:
        ValueError: If the instance type contains ``"trn"`` and the framework
            is rejected by ``_validate_framework``.
    """
    # Nothing to validate unless a Trainium ("trn") instance type was given.
    if instace_type is None or "trn" not in instace_type:
        return

    # The helper performs the membership test itself and raises on failure.
    _validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
363+
364+
347365
def config_for_framework(framework):
348366
"""Loads the JSON config for the given framework."""
349367
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
@@ -371,7 +389,7 @@ def _get_inference_tool(inference_tool, instance_type):
371389
"""Extract the inference tool name from instance type."""
372390
if not inference_tool:
373391
instance_type_family = _get_instance_type_family(instance_type)
374-
if instance_type_family.startswith("inf"):
392+
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
375393
return "neuron"
376394
return inference_tool
377395

@@ -460,6 +478,8 @@ def _processor(instance_type, available_processors, serverless_inference_config=
460478
processor = family
461479
elif family.startswith("inf"):
462480
processor = "inf"
481+
elif family.startswith("trn"):
482+
processor = "trn"
463483
elif family[0] in ("g", "p"):
464484
processor = "gpu"
465485
else:
@@ -523,6 +543,15 @@ def _validate_arg(arg, available_options, arg_name):
523543
)
524544

525545

546+
def _validate_framework(framework, allowed_frameworks, arg_name):
    """Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not.

    Args:
        framework (str): Framework name to check.
        allowed_frameworks (str or collection of str): Supported framework
            name(s). A bare string is treated as a single framework name,
            not as a sequence of characters.
        arg_name (str): Name of the argument, used in the error message.

    Raises:
        ValueError: If ``framework`` is not one of the allowed frameworks.
    """
    # ``x in "pytorch"`` is a substring test, so values such as "torch" or ""
    # would be accepted. Wrap a bare string so the comparison is by whole name.
    if isinstance(allowed_frameworks, str):
        allowed_names = (allowed_frameworks,)
    else:
        allowed_names = allowed_frameworks

    if framework not in allowed_names:
        # The message formats the caller's original value so its text is unchanged.
        raise ValueError(
            f"Unsupported {arg_name}: {framework}. "
            f"Supported {arg_name}(s) for trainium instances: {allowed_frameworks}."
        )
553+
554+
526555
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
527556
"""Creates a tag for the image URI."""
528557
if inference_tool:

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ def huggingface_neuron_latest_inference_py_version():
358358
return "py37"
359359

360360

361+
@pytest.fixture(scope="module")
def pytorch_neuron_version():
    """Return the PyTorch version string used for Neuron image URI tests."""
    return "1.11"
364+
365+
361366
@pytest.fixture(scope="module")
362367
def pytorch_eia_py_version():
363368
return "py3"

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
3030
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
3131

3232

33+
def neuron_framework_uri(
    repo,
    fw_version,
    account,
    py_version=None,
    inference_tool="neuron",
    region=REGION,
    sdk_version="sdk2.3.0",
    container_version="ubuntu20.04",
):
    """Build the expected image URI for a Neuron framework image.

    The tag is made of ``fw_version``, ``inference_tool``, ``py_version``,
    ``sdk_version`` and ``container_version`` (in that order) joined with
    ``"-"``; falsy components are left out. The registry domain is looked up
    in ``ALTERNATE_DOMAINS`` by region and falls back to ``DOMAIN``.
    """
    registry_domain = ALTERNATE_DOMAINS.get(region, DOMAIN)

    tag_parts = (fw_version, inference_tool, py_version, sdk_version, container_version)
    image_tag = "-".join(filter(None, tag_parts))

    return IMAGE_URI_FORMAT.format(account, region, registry_domain, repo, image_tag)
49+
50+
3351
def algo_uri(algo, account, region, version=1):
3452
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
3553
return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris
17+
18+
# Registry account ID shared by most of the regions in the table below.
_SHARED_ACCOUNT = "763104351884"
# Registry account ID shared by the two China regions.
_CN_ACCOUNT = "727897471807"

# Region name -> registry account ID expected in the image URI for that region.
ACCOUNTS = {
    "af-south-1": "626614931356",
    "ap-east-1": "871362719292",
    "ap-northeast-1": _SHARED_ACCOUNT,
    "ap-northeast-2": _SHARED_ACCOUNT,
    "ap-northeast-3": "364406365360",
    "ap-south-1": _SHARED_ACCOUNT,
    "ap-southeast-1": _SHARED_ACCOUNT,
    "ap-southeast-2": _SHARED_ACCOUNT,
    "ca-central-1": _SHARED_ACCOUNT,
    "cn-north-1": _CN_ACCOUNT,
    "cn-northwest-1": _CN_ACCOUNT,
    "eu-central-1": _SHARED_ACCOUNT,
    "eu-north-1": _SHARED_ACCOUNT,
    "eu-west-1": _SHARED_ACCOUNT,
    "eu-west-2": _SHARED_ACCOUNT,
    "eu-west-3": _SHARED_ACCOUNT,
    "eu-south-1": "692866216735",
    "me-south-1": "217643126080",
    "sa-east-1": _SHARED_ACCOUNT,
    "us-east-1": _SHARED_ACCOUNT,
    "us-east-2": _SHARED_ACCOUNT,
    "us-gov-west-1": "442386744353",
    "us-iso-east-1": "886529160074",
    "us-west-1": _SHARED_ACCOUNT,
    "us-west-2": _SHARED_ACCOUNT,
}

# Every region in the table is exercised by the Trainium URI tests.
TRAINIUM_REGIONS = ACCOUNTS.keys()
47+
48+
49+
def _expected_trainium_framework_uri(
    framework, version, region="us-west-2", inference_tool="neuron"
):
    """Return the URI expected for the ``<framework>-neuron`` repository in ``region``.

    The account is taken from ``ACCOUNTS`` for the region and the Python
    version is fixed to ``py38``.
    """
    repository = f"{framework}-neuron"
    return expected_uris.neuron_framework_uri(
        repository,
        fw_version=version,
        py_version="py38",
        account=ACCOUNTS[region],
        region=region,
        inference_tool=inference_tool,
    )
60+
61+
62+
def _test_trainium_framework_uris(framework, version):
    """Check the retrieved ``ml.trn1.xlarge`` image URI in every Trainium region.

    For each region in ``TRAINIUM_REGIONS`` the URI returned by
    ``image_uris.retrieve`` must equal the expected URI built for the
    ``<framework>-training`` repository prefix.
    """
    training_repo_prefix = f"{framework}-training"
    for region in TRAINIUM_REGIONS:
        actual_uri = image_uris.retrieve(
            framework,
            region,
            instance_type="ml.trn1.xlarge",
            version=version,
        )
        expected_uri = _expected_trainium_framework_uri(
            training_repo_prefix,
            version,
            region=region,
            inference_tool="neuron",
        )
        assert expected_uri == actual_uri
71+
72+
73+
def test_trainium_pytorch(pytorch_neuron_version):
    """Trainium image URIs for PyTorch match the expected URIs in every region."""
    _test_trainium_framework_uris(framework="pytorch", version=pytorch_neuron_version)

0 commit comments

Comments
 (0)