Skip to content

Commit 5e6372c

Browse files
committed
Trainium Neuron support for PyTorch
1 parent f2d5e41 commit 5e6372c

File tree

8 files changed

+195
-3
lines changed

8 files changed

+195
-3
lines changed

src/sagemaker/estimator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
UploadedCode,
4545
_region_supports_debugger,
4646
_region_supports_profiler,
47+
_instance_type_supports_profiler,
4748
get_mp_parameters,
4849
tar_and_upload_dir,
4950
validate_source_dir,
@@ -592,7 +593,9 @@ def __init__(
592593

593594
self.max_retry_attempts = max_retry_attempts
594595

595-
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
596+
if not _region_supports_profiler(
597+
self.sagemaker_session.boto_region_name
598+
) or not _instance_type_supports_profiler(self.instance_type):
596599
self.disable_profiler = True
597600

598601
self.profiler_rule_configs = None

src/sagemaker/fw_utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,22 @@ def _region_supports_profiler(region_name):
904904
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
905905

906906

907+
def _instance_type_supports_profiler(instance_type):
908+
"""Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
909+
910+
Args:
911+
instance_type (str): Name of the instance_type to check against.
912+
913+
Returns:
914+
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
915+
"""
916+
if isinstance(instance_type, str):
917+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
918+
if match and match[1].startswith("trn"):
919+
return False
920+
return True
921+
922+
907923
def validate_version_or_image_args(framework_version, py_version, image_uri):
908924
"""Checks if version or image arguments are specified.
909925
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"training": {
3+
"processors": ["trn"],
4+
"version_aliases": {"1.11": "1.11.0"},
5+
"versions": {
6+
"1.11.0": {
7+
"py_versions": ["py38"],
8+
"repository": "pytorch-training-neuron",
9+
"registries": {
10+
"af-south-1": "626614931356",
11+
"ap-east-1": "871362719292",
12+
"ap-northeast-1": "763104351884",
13+
"ap-northeast-2": "763104351884",
14+
"ap-northeast-3": "364406365360",
15+
"ap-south-1": "763104351884",
16+
"ap-southeast-1": "763104351884",
17+
"ap-southeast-2": "763104351884",
18+
"ca-central-1": "763104351884",
19+
"cn-north-1": "727897471807",
20+
"cn-northwest-1": "727897471807",
21+
"eu-central-1": "763104351884",
22+
"eu-north-1": "763104351884",
23+
"eu-west-1": "763104351884",
24+
"eu-west-2": "763104351884",
25+
"eu-west-3": "763104351884",
26+
"eu-south-1": "692866216735",
27+
"me-south-1": "217643126080",
28+
"sa-east-1": "763104351884",
29+
"us-east-1": "763104351884",
30+
"us-east-2": "763104351884",
31+
"us-gov-west-1": "442386744353",
32+
"us-iso-east-1": "886529160074",
33+
"us-west-1": "763104351884",
34+
"us-west-2": "763104351884"
35+
},
36+
"container_version": {"trn": "ubuntu20.04"},
37+
"sdk_versions": ["sdk2.3.0"]
38+
}
39+
}
40+
}
41+
}

src/sagemaker/image_uris.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3232
HUGGING_FACE_FRAMEWORK = "huggingface"
33+
# Kept as a tuple (not a bare string) so that ``framework in TRAINIUM_ALLOWED_FRAMEWORKS``
# is an exact-name membership test rather than a substring match on "pytorch".
TRAINIUM_ALLOWED_FRAMEWORKS = ("pytorch",)
3334

3435

3536
@override_pipeline_parameter_var
@@ -147,10 +148,11 @@ def retrieve(
147148
)
148149
else:
149150
_framework = framework
150-
if framework == HUGGING_FACE_FRAMEWORK:
151+
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
151152
inference_tool = _get_inference_tool(inference_tool, instance_type)
152153
if inference_tool == "neuron":
153154
_framework = f"{framework}-{inference_tool}"
155+
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
154156
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
155157

156158
original_version = version
@@ -182,6 +184,12 @@ def retrieve(
182184
if version_config.get("container_version"):
183185
container_version = version_config["container_version"][processor]
184186

187+
# Append sdk version in case of trainium instances
188+
if repo in ["pytorch-training-neuron"]:
189+
if not sdk_version:
190+
sdk_version = _get_latest_versions(version_config["sdk_versions"])
191+
container_version = sdk_version + "-" + container_version
192+
185193
if framework == HUGGING_FACE_FRAMEWORK:
186194
pt_or_tf_version = (
187195
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
@@ -280,6 +288,16 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
280288
return config if "scope" in config else config[image_scope]
281289

282290

291+
def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
    """Validate if framework is supported for the instance_type.

    Nothing is checked unless ``instace_type`` is set and contains ``"trn"``.
    For such instance types, a framework that is not in
    ``TRAINIUM_ALLOWED_FRAMEWORKS`` is handed to ``_validate_framework``.
    """
    # NOTE(review): the function name ("suppported") and parameter name
    # ("instace_type") are misspelled; they are kept as-is for existing callers.
    if instace_type is None:
        return
    if "trn" not in instace_type:
        return
    if framework in TRAINIUM_ALLOWED_FRAMEWORKS:
        return
    _validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
299+
300+
283301
def config_for_framework(framework):
284302
"""Loads the JSON config for the given framework."""
285303
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
@@ -291,7 +309,7 @@ def _get_inference_tool(inference_tool, instance_type):
291309
"""Extract the inference tool name from instance type."""
292310
if not inference_tool and instance_type:
293311
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
294-
if match and match[1].startswith("inf"):
312+
if match and (match[1].startswith("inf") or match[1].startswith("trn")):
295313
return "neuron"
296314
return inference_tool
297315

@@ -382,6 +400,8 @@ def _processor(instance_type, available_processors, serverless_inference_config=
382400
processor = family
383401
elif family.startswith("inf"):
384402
processor = "inf"
403+
elif family.startswith("trn"):
404+
processor = "trn"
385405
elif family[0] in ("g", "p"):
386406
processor = "gpu"
387407
else:
@@ -446,6 +466,15 @@ def _validate_arg(arg, available_options, arg_name):
446466
)
447467

448468

469+
def _validate_framework(framework, allowed_frameworks, arg_name):
470+
"""Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
471+
if framework not in allowed_frameworks:
472+
raise ValueError(
473+
f"Unsupported {arg_name}: {framework}. "
474+
f"Supported {arg_name}(s) for trainium instances: {allowed_frameworks}."
475+
)
476+
477+
449478
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
450479
"""Creates a tag for the image URI."""
451480
if inference_tool:

tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ def huggingface_neuron_latest_inference_py_version():
328328
return "py37"
329329

330330

331+
@pytest.fixture(scope="module")
def pytorch_neuron_version():
    """Module-scoped fixture giving the PyTorch version string for the Neuron tests."""
    version = "1.11"
    return version
334+
335+
331336
@pytest.fixture(scope="module")
332337
def pytorch_eia_py_version():
333338
return "py3"

tests/unit/sagemaker/image_uris/expected_uris.py

+18
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
3030
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
3131

3232

33+
def neuron_framework_uri(
    repo,
    fw_version,
    account,
    py_version=None,
    inference_tool="neuron",
    region=REGION,
    sdk_version="sdk2.3.0",
    container_version="ubuntu20.04",
):
    """Build the expected image URI for a Neuron framework image.

    The tag is the dash-joined sequence of framework version, inference tool,
    Python version, SDK version and container version; falsy parts are omitted.
    """
    tag_parts = (fw_version, inference_tool, py_version, sdk_version, container_version)
    tag = "-".join(filter(None, tag_parts))

    # Regions listed in ALTERNATE_DOMAINS use their own domain; the rest use DOMAIN.
    domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
    return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
49+
50+
3351
def algo_uri(algo, account, region, version=1):
3452
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
3553
return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris
17+
18+
ACCOUNTS = {
19+
"af-south-1": "626614931356",
20+
"ap-east-1": "871362719292",
21+
"ap-northeast-1": "763104351884",
22+
"ap-northeast-2": "763104351884",
23+
"ap-northeast-3": "364406365360",
24+
"ap-south-1": "763104351884",
25+
"ap-southeast-1": "763104351884",
26+
"ap-southeast-2": "763104351884",
27+
"ca-central-1": "763104351884",
28+
"cn-north-1": "727897471807",
29+
"cn-northwest-1": "727897471807",
30+
"eu-central-1": "763104351884",
31+
"eu-north-1": "763104351884",
32+
"eu-west-1": "763104351884",
33+
"eu-west-2": "763104351884",
34+
"eu-west-3": "763104351884",
35+
"eu-south-1": "692866216735",
36+
"me-south-1": "217643126080",
37+
"sa-east-1": "763104351884",
38+
"us-east-1": "763104351884",
39+
"us-east-2": "763104351884",
40+
"us-gov-west-1": "442386744353",
41+
"us-iso-east-1": "886529160074",
42+
"us-west-1": "763104351884",
43+
"us-west-2": "763104351884",
44+
}
45+
46+
TRAINIUM_REGIONS = ACCOUNTS.keys()
47+
48+
49+
def _expected_trainium_framework_uri(
    framework, version, region="us-west-2", inference_tool="neuron"
):
    """Return the URI expected for ``framework`` at ``version`` in ``region``.

    The repository name is ``framework`` with a ``-neuron`` suffix, and the
    registry account is looked up in ``ACCOUNTS`` by region.
    """
    repo = f"{framework}-neuron"
    return expected_uris.neuron_framework_uri(
        repo,
        fw_version=version,
        py_version="py38",
        account=ACCOUNTS[region],
        region=region,
        inference_tool=inference_tool,
    )
60+
61+
62+
def _test_trainium_framework_uris(framework, version):
    """Assert the retrieved ``ml.trn1.xlarge`` URI matches the expected one in every region."""
    training_framework = "{}-training".format(framework)
    for region in TRAINIUM_REGIONS:
        actual = image_uris.retrieve(
            framework, region, instance_type="ml.trn1.xlarge", version=version
        )
        expected = _expected_trainium_framework_uri(
            training_framework, version, region=region, inference_tool="neuron"
        )
        assert expected == actual
71+
72+
73+
def test_trainium_pytorch(pytorch_neuron_version):
    """Check Trainium image URIs for PyTorch at the fixture-provided version."""
    _test_trainium_framework_uris(framework="pytorch", version=pytorch_neuron_version)

tests/unit/test_fw_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,9 @@ def test_validate_pytorchddp_raises():
946946
py_version="py2",
947947
image_uri=None,
948948
)
949+
950+
951+
def test_instance_type_supports_profiler():
    """Profiling is reported unsupported for ml.trn1.xlarge and supported for the others."""
    expectations = (
        ("ml.trn1.xlarge", False),
        ("ml.m4.xlarge", True),
        ("local", True),
    )
    for instance_type, supported in expectations:
        assert fw_utils._instance_type_supports_profiler(instance_type) is supported

0 commit comments

Comments
 (0)