Skip to content

feature: Graviton support for PyTorch and Tensorflow frameworks #3432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@
"1.12.0",
]


TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]


TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]


SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]


Expand All @@ -160,6 +163,12 @@ def validate_source_dir(script, directory):
return True


GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = ["c6g", "t4g", "r6g", "m6g"]


GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])


def validate_source_code_input_against_pipeline_variables(
entry_point: Optional[Union[str, PipelineVariable]] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
Expand Down
45 changes: 45 additions & 0 deletions src/sagemaker/image_uri_config/pytorch.json
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,51 @@
}
}
},
"inference_graviton": {
"processors": [
"cpu"
],
"version_aliases": {
"1.12": "1.12.1"
},
"versions": {
"1.12.1": {
"py_versions": [
"py38"
],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "pytorch-inference-graviton",
"container_version": {"cpu": "ubuntu20.04"}
}
}
},
"training": {
"processors": [
"cpu",
Expand Down
45 changes: 45 additions & 0 deletions src/sagemaker/image_uri_config/tensorflow.json
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,51 @@
}
}
},
"inference_graviton": {
"processors": [
"cpu"
],
"version_aliases": {
"2.9": "2.9.1"
},
"versions": {
"2.9.1": {
"py_versions": [
"py38"
],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "tensorflow-inference-graviton",
"container_version": {"cpu": "ubuntu20.04"}
}
}
},
"training": {
"processors": [
"cpu",
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sagemaker.jumpstart import artifacts
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.utilities import override_pipeline_parameter_var
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -151,6 +152,7 @@ def retrieve(
inference_tool = _get_inference_tool(inference_tool, instance_type)
if inference_tool == "neuron":
_framework = f"{framework}-{inference_tool}"
image_scope = _get_image_scope_for_instance_type(_framework, instance_type, image_scope)
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)

original_version = version
Expand Down Expand Up @@ -216,6 +218,9 @@ def retrieve(
else:
tag_prefix = version_config.get("tag_prefix", version)

if repo == f"{framework}-inference-graviton":
container_version = f"{container_version}-sagemaker"

tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)

if instance_type is not None and _should_auto_select_container_version(
Expand Down Expand Up @@ -287,6 +292,15 @@ def config_for_framework(framework):
return json.load(f)


def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
"""Extract the image scope from instance type."""
if framework in GRAVITON_ALLOWED_FRAMEWORKS and isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
return "inference_graviton"
return image_scope


def _get_inference_tool(inference_tool, instance_type):
"""Extract the inference tool name from instance type."""
if not inference_tool and instance_type:
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,16 @@ def huggingface_pytorch_latest_inference_py_version(huggingface_inference_pytorc
)


@pytest.fixture(scope="module")
def graviton_tensorflow_version():
return "2.9.1"


@pytest.fixture(scope="module")
def graviton_pytorch_version():
return "1.12.1"


@pytest.fixture(scope="module")
def huggingface_tensorflow_latest_training_py_version():
return "py38"
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/sagemaker/image_uris/expected_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,18 @@ def algo_uri(algo, account, region, version=1):
def monitor_uri(account, region=REGION):
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
return MONITOR_URI_FORMAT.format(account, region, domain)


def graviton_framework_uri(
repo,
fw_version,
account,
py_version="py38",
processor="cpu",
region=REGION,
container_version="ubuntu20.04-sagemaker",
):
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
tag = "-".join(x for x in (fw_version, processor, py_version, container_version) if x)

return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
83 changes: 83 additions & 0 deletions tests/unit/sagemaker/image_uris/test_graviton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

GRAVITON_ALGOS = ("tensorflow", "pytotch")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
GRAVITON_ALGOS = ("tensorflow", "pytotch")
GRAVITON_ALGOS = ("tensorflow", "pytorch")

GRAVITON_INSTANCE_TYPES = [
"ml.c6g.4xlarge",
"ml.t4g.2xlarge",
"ml.r6g.2xlarge",
"ml.m6g.4xlarge",
]

ACCOUNTS = {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884",
}

GRAVITON_REGIONS = ACCOUNTS.keys()


def _test_graviton_framework_uris(framework, version):
for region in GRAVITON_REGIONS:
for instance_type in GRAVITON_INSTANCE_TYPES:
uri = image_uris.retrieve(
framework, region, instance_type=instance_type, version=version
)
expected = _expected_graviton_framework_uri(framework, version, region=region)
assert expected == uri


def test_graviton_tensorflow(graviton_tensorflow_version):
_test_graviton_framework_uris("tensorflow", graviton_tensorflow_version)


def test_graviton_pytorch(graviton_pytorch_version):
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)


def _expected_graviton_framework_uri(framework, version, region):
return expected_uris.graviton_framework_uri(
"{}-inference-graviton".format(framework),
fw_version=version,
py_version="py38",
account=ACCOUNTS[region],
region=region,
)