Skip to content

Commit c03cf9f

Browse files
knikureJoseJuan98
authored andcommitted
feature: Graviton support for PyTorch and Tensorflow frameworks (aws#3432)
1 parent fa472ad commit c03cf9f

File tree

7 files changed

+221
-0
lines changed

7 files changed

+221
-0
lines changed

src/sagemaker/fw_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,13 @@
134134
"1.12.0",
135135
]
136136

137+
137138
TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
138139

140+
139141
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
140142

143+
141144
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
142145

143146

@@ -160,6 +163,12 @@ def validate_source_dir(script, directory):
160163
return True
161164

162165

166+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = ["c6g", "t4g", "r6g", "m6g"]
167+
168+
169+
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])
170+
171+
163172
def validate_source_code_input_against_pipeline_variables(
164173
entry_point: Optional[Union[str, PipelineVariable]] = None,
165174
source_dir: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/image_uri_config/pytorch.json

+45
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,51 @@
654654
}
655655
}
656656
},
657+
"inference_graviton": {
658+
"processors": [
659+
"cpu"
660+
],
661+
"version_aliases": {
662+
"1.12": "1.12.1"
663+
},
664+
"versions": {
665+
"1.12.1": {
666+
"py_versions": [
667+
"py38"
668+
],
669+
"registries": {
670+
"af-south-1": "626614931356",
671+
"ap-east-1": "871362719292",
672+
"ap-northeast-1": "763104351884",
673+
"ap-northeast-2": "763104351884",
674+
"ap-northeast-3": "364406365360",
675+
"ap-south-1": "763104351884",
676+
"ap-southeast-1": "763104351884",
677+
"ap-southeast-2": "763104351884",
678+
"ap-southeast-3": "907027046896",
679+
"ca-central-1": "763104351884",
680+
"cn-north-1": "727897471807",
681+
"cn-northwest-1": "727897471807",
682+
"eu-central-1": "763104351884",
683+
"eu-north-1": "763104351884",
684+
"eu-west-1": "763104351884",
685+
"eu-west-2": "763104351884",
686+
"eu-west-3": "763104351884",
687+
"eu-south-1": "692866216735",
688+
"me-south-1": "217643126080",
689+
"sa-east-1": "763104351884",
690+
"us-east-1": "763104351884",
691+
"us-east-2": "763104351884",
692+
"us-gov-west-1": "442386744353",
693+
"us-iso-east-1": "886529160074",
694+
"us-west-1": "763104351884",
695+
"us-west-2": "763104351884"
696+
},
697+
"repository": "pytorch-inference-graviton",
698+
"container_version": {"cpu": "ubuntu20.04"}
699+
}
700+
}
701+
},
657702
"training": {
658703
"processors": [
659704
"cpu",

src/sagemaker/image_uri_config/tensorflow.json

+45
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,51 @@
14711471
}
14721472
}
14731473
},
1474+
"inference_graviton": {
1475+
"processors": [
1476+
"cpu"
1477+
],
1478+
"version_aliases": {
1479+
"2.9": "2.9.1"
1480+
},
1481+
"versions": {
1482+
"2.9.1": {
1483+
"py_versions": [
1484+
"py38"
1485+
],
1486+
"registries": {
1487+
"af-south-1": "626614931356",
1488+
"ap-east-1": "871362719292",
1489+
"ap-northeast-1": "763104351884",
1490+
"ap-northeast-2": "763104351884",
1491+
"ap-northeast-3": "364406365360",
1492+
"ap-south-1": "763104351884",
1493+
"ap-southeast-1": "763104351884",
1494+
"ap-southeast-2": "763104351884",
1495+
"ap-southeast-3": "907027046896",
1496+
"ca-central-1": "763104351884",
1497+
"cn-north-1": "727897471807",
1498+
"cn-northwest-1": "727897471807",
1499+
"eu-central-1": "763104351884",
1500+
"eu-north-1": "763104351884",
1501+
"eu-west-1": "763104351884",
1502+
"eu-west-2": "763104351884",
1503+
"eu-west-3": "763104351884",
1504+
"eu-south-1": "692866216735",
1505+
"me-south-1": "217643126080",
1506+
"sa-east-1": "763104351884",
1507+
"us-east-1": "763104351884",
1508+
"us-east-2": "763104351884",
1509+
"us-gov-west-1": "442386744353",
1510+
"us-iso-east-1": "886529160074",
1511+
"us-west-1": "763104351884",
1512+
"us-west-2": "763104351884"
1513+
},
1514+
"repository": "tensorflow-inference-graviton",
1515+
"container_version": {"cpu": "ubuntu20.04"}
1516+
}
1517+
}
1518+
},
14741519
"training": {
14751520
"processors": [
14761521
"cpu",

src/sagemaker/image_uris.py

+14
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.jumpstart import artifacts
2626
from sagemaker.workflow import is_pipeline_variable
2727
from sagemaker.workflow.utilities import override_pipeline_parameter_var
28+
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -151,6 +152,7 @@ def retrieve(
151152
inference_tool = _get_inference_tool(inference_tool, instance_type)
152153
if inference_tool == "neuron":
153154
_framework = f"{framework}-{inference_tool}"
155+
image_scope = _get_image_scope_for_instance_type(_framework, instance_type, image_scope)
154156
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
155157

156158
original_version = version
@@ -216,6 +218,9 @@ def retrieve(
216218
else:
217219
tag_prefix = version_config.get("tag_prefix", version)
218220

221+
if repo == f"{framework}-inference-graviton":
222+
container_version = f"{container_version}-sagemaker"
223+
219224
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
220225

221226
if instance_type is not None and _should_auto_select_container_version(
@@ -287,6 +292,15 @@ def config_for_framework(framework):
287292
return json.load(f)
288293

289294

295+
def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
296+
"""Extract the image scope from instance type."""
297+
if framework in GRAVITON_ALLOWED_FRAMEWORKS and isinstance(instance_type, str):
298+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
299+
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
300+
return "inference_graviton"
301+
return image_scope
302+
303+
290304
def _get_inference_tool(inference_tool, instance_type):
291305
"""Extract the inference tool name from instance type."""
292306
if not inference_tool and instance_type:

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,16 @@ def huggingface_pytorch_latest_inference_py_version(huggingface_inference_pytorc
308308
)
309309

310310

311+
@pytest.fixture(scope="module")
312+
def graviton_tensorflow_version():
313+
return "2.9.1"
314+
315+
316+
@pytest.fixture(scope="module")
317+
def graviton_pytorch_version():
318+
return "1.12.1"
319+
320+
311321
@pytest.fixture(scope="module")
312322
def huggingface_tensorflow_latest_training_py_version():
313323
return "py38"

tests/unit/sagemaker/image_uris/expected_uris.py

+15
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,18 @@ def algo_uri(algo, account, region, version=1):
3838
def monitor_uri(account, region=REGION):
3939
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
4040
return MONITOR_URI_FORMAT.format(account, region, domain)
41+
42+
43+
def graviton_framework_uri(
44+
repo,
45+
fw_version,
46+
account,
47+
py_version="py38",
48+
processor="cpu",
49+
region=REGION,
50+
container_version="ubuntu20.04-sagemaker",
51+
):
52+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
53+
tag = "-".join(x for x in (fw_version, processor, py_version, container_version) if x)
54+
55+
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris
17+
18+
GRAVITON_ALGOS = ("tensorflow", "pytotch")
19+
GRAVITON_INSTANCE_TYPES = [
20+
"ml.c6g.4xlarge",
21+
"ml.t4g.2xlarge",
22+
"ml.r6g.2xlarge",
23+
"ml.m6g.4xlarge",
24+
]
25+
26+
ACCOUNTS = {
27+
"af-south-1": "626614931356",
28+
"ap-east-1": "871362719292",
29+
"ap-northeast-1": "763104351884",
30+
"ap-northeast-2": "763104351884",
31+
"ap-northeast-3": "364406365360",
32+
"ap-south-1": "763104351884",
33+
"ap-southeast-1": "763104351884",
34+
"ap-southeast-2": "763104351884",
35+
"ap-southeast-3": "907027046896",
36+
"ca-central-1": "763104351884",
37+
"cn-north-1": "727897471807",
38+
"cn-northwest-1": "727897471807",
39+
"eu-central-1": "763104351884",
40+
"eu-north-1": "763104351884",
41+
"eu-west-1": "763104351884",
42+
"eu-west-2": "763104351884",
43+
"eu-west-3": "763104351884",
44+
"eu-south-1": "692866216735",
45+
"me-south-1": "217643126080",
46+
"sa-east-1": "763104351884",
47+
"us-east-1": "763104351884",
48+
"us-east-2": "763104351884",
49+
"us-gov-west-1": "442386744353",
50+
"us-iso-east-1": "886529160074",
51+
"us-west-1": "763104351884",
52+
"us-west-2": "763104351884",
53+
}
54+
55+
GRAVITON_REGIONS = ACCOUNTS.keys()
56+
57+
58+
def _test_graviton_framework_uris(framework, version):
59+
for region in GRAVITON_REGIONS:
60+
for instance_type in GRAVITON_INSTANCE_TYPES:
61+
uri = image_uris.retrieve(
62+
framework, region, instance_type=instance_type, version=version
63+
)
64+
expected = _expected_graviton_framework_uri(framework, version, region=region)
65+
assert expected == uri
66+
67+
68+
def test_graviton_tensorflow(graviton_tensorflow_version):
69+
_test_graviton_framework_uris("tensorflow", graviton_tensorflow_version)
70+
71+
72+
def test_graviton_pytorch(graviton_pytorch_version):
73+
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
74+
75+
76+
def _expected_graviton_framework_uri(framework, version, region):
77+
return expected_uris.graviton_framework_uri(
78+
"{}-inference-graviton".format(framework),
79+
fw_version=version,
80+
py_version="py38",
81+
account=ACCOUNTS[region],
82+
region=region,
83+
)

0 commit comments

Comments
 (0)