Skip to content

Commit 86f56c3

Browse files
author
Mark Bunday
committed
feature: Graviton support for XGB and SKLearn frameworks
1 parent 5adc2d3 commit 86f56c3

File tree

4 files changed

+100
-6
lines changed

4 files changed

+100
-6
lines changed

src/sagemaker/fw_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,43 @@ def validate_source_dir(script, directory):
169169
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])
170170

171171

172+
GRAVITON_ALLOWED_REGIONS = {
173+
"af-south-1",
174+
"ap-east-1",
175+
"ap-northeast-1",
176+
"ap-northeast-2",
177+
"ap-northeast-3",
178+
"ap-south-1",
179+
"ap-southeast-1",
180+
"ap-southeast-2",
181+
"ap-southeast-3",
182+
"ca-central-1",
183+
"cn-north-1",
184+
"cn-northwest-1",
185+
"eu-central-1",
186+
"eu-north-1",
187+
"eu-south-1",
188+
"eu-west-1",
189+
"eu-west-2",
190+
"eu-west-3",
191+
"me-south-1",
192+
"sa-east-1",
193+
"us-east-1",
194+
"us-east-2",
195+
"us-gov-west-1",
196+
"us-iso-east-1",
197+
"us-west-1",
198+
"us-west-2",
199+
}
200+
201+
202+
# TODO: Consider combining with GRAVITON_ALLOWED_FRAMEWORKS into a dictionary
203+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.5-1", "1.3-1"]
204+
205+
206+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.0-1"]
207+
208+
172209
def validate_source_code_input_against_pipeline_variables(
173210
entry_point: Optional[Union[str, PipelineVariable]] = None,
174211
source_dir: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/image_uris.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,21 @@
2525
from sagemaker.jumpstart import artifacts
2626
from sagemaker.workflow import is_pipeline_variable
2727
from sagemaker.workflow.utilities import override_pipeline_parameter_var
28-
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
28+
from sagemaker.fw_utils import (
29+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
30+
GRAVITON_ALLOWED_FRAMEWORKS,
31+
GRAVITON_ALLOWED_REGIONS,
32+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS,
33+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS,
34+
)
2935

3036
logger = logging.getLogger(__name__)
3137

3238
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3339
HUGGING_FACE_FRAMEWORK = "huggingface"
40+
XGBOOST_FRAMEWORK = "xgboost"
41+
SKLEARN_FRAMEWORK = "sklearn"
42+
INSTANCE_TYPE_REGEX = r"^ml[\._]([a-z\d]+)\.?\w*$"
3443

3544

3645
@override_pipeline_parameter_var
@@ -244,6 +253,18 @@ def retrieve(
244253
if key in container_versions:
245254
tag = "-".join([tag, container_versions[key]])
246255

256+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
257+
match = re.match(INSTANCE_TYPE_REGEX, instance_type)
258+
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
259+
_validate_arg(region, GRAVITON_ALLOWED_REGIONS, "Graviton region")
260+
arg_name = f"{framework} version for Graviton instances"
261+
if framework == XGBOOST_FRAMEWORK:
262+
_validate_arg(version, XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS, arg_name)
263+
tag = f"{version}-arm64"
264+
else:
265+
_validate_arg(version, SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS, arg_name)
266+
tag = f"{version}-arm64-cpu-py3"
267+
247268
if tag:
248269
repo += ":{}".format(tag)
249270

@@ -295,7 +316,7 @@ def config_for_framework(framework):
295316
def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
296317
"""Extract the image scope from instance type."""
297318
if framework in GRAVITON_ALLOWED_FRAMEWORKS and isinstance(instance_type, str):
298-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
319+
match = re.match(INSTANCE_TYPE_REGEX, instance_type)
299320
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
300321
return "inference_graviton"
301322
return image_scope
@@ -304,7 +325,7 @@ def _get_image_scope_for_instance_type(framework, instance_type, image_scope):
304325
def _get_inference_tool(inference_tool, instance_type):
305326
"""Extract the inference tool name from instance type."""
306327
if not inference_tool and instance_type:
307-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
328+
match = re.match(INSTANCE_TYPE_REGEX, instance_type)
308329
if match and match[1].startswith("inf"):
309330
return "neuron"
310331
return inference_tool
@@ -385,7 +406,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
385406
processor = "neuron"
386407
else:
387408
# looks for either "ml.<family>.<size>" or "ml_<family>"
388-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
409+
match = re.match(INSTANCE_TYPE_REGEX, instance_type)
389410
if match:
390411
family = match[1]
391412

@@ -415,7 +436,7 @@ def _should_auto_select_container_version(instance_type, distribution):
415436
p4d = False
416437
if instance_type:
417438
# looks for either "ml.<family>.<size>" or "ml_<family>"
418-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
439+
match = re.match(INSTANCE_TYPE_REGEX, instance_type)
419440
if match:
420441
family = match[1]
421442
p4d = family == "p4d"

tests/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botocore.config import Config
2323
from packaging.version import Version
2424

25-
from sagemaker import Session, image_uris, utils
25+
from sagemaker import Session, fw_utils, image_uris, utils
2626
from sagemaker.local import LocalSession
2727
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession
2828

@@ -318,6 +318,16 @@ def graviton_pytorch_version():
318318
return "1.12.1"
319319

320320

321+
@pytest.fixture(scope="module")
322+
def graviton_xgboost_versions():
323+
return fw_utils.XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS
324+
325+
326+
@pytest.fixture(scope="module")
327+
def graviton_sklearn_versions():
328+
return fw_utils.SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS
329+
330+
321331
@pytest.fixture(scope="module")
322332
def huggingface_tensorflow_latest_training_py_version():
323333
return "py38"

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,32 @@ def test_graviton_pytorch(graviton_pytorch_version):
7373
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
7474

7575

76+
def test_graviton_xgboost(graviton_xgboost_versions):
77+
for xgboost_version in graviton_xgboost_versions:
78+
for instance_type in GRAVITON_INSTANCE_TYPES:
79+
uri = image_uris.retrieve(
80+
"xgboost", "us-west-2", version=xgboost_version, instance_type=instance_type
81+
)
82+
expected = (
83+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
84+
f"{xgboost_version}-arm64"
85+
)
86+
assert expected == uri
87+
88+
89+
def test_graviton_sklearn(graviton_sklearn_versions):
90+
for sklearn_version in graviton_sklearn_versions:
91+
for instance_type in GRAVITON_INSTANCE_TYPES:
92+
uri = image_uris.retrieve(
93+
"sklearn", "us-west-2", version=sklearn_version, instance_type=instance_type
94+
)
95+
expected = (
96+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
97+
f"{sklearn_version}-arm64-cpu-py3"
98+
)
99+
assert expected == uri
100+
101+
76102
def _expected_graviton_framework_uri(framework, version, region):
77103
return expected_uris.graviton_framework_uri(
78104
"{}-inference-graviton".format(framework),

0 commit comments

Comments
 (0)