Skip to content

Commit 61905cc

Browse files
author
Mark Bunday
committed
feature: Graviton support for XGB and SKLearn frameworks
1 parent 5adc2d3 commit 61905cc

File tree

4 files changed

+101
-2
lines changed

4 files changed

+101
-2
lines changed

src/sagemaker/fw_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,43 @@ def validate_source_dir(script, directory):
169169
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])
170170

171171

172+
GRAVITON_ALLOWED_REGIONS = {
173+
"af-south-1",
174+
"ap-east-1",
175+
"ap-northeast-1",
176+
"ap-northeast-2",
177+
"ap-northeast-3",
178+
"ap-south-1",
179+
"ap-southeast-1",
180+
"ap-southeast-2",
181+
"ap-southeast-3",
182+
"ca-central-1",
183+
"cn-north-1",
184+
"cn-northwest-1",
185+
"eu-central-1",
186+
"eu-north-1",
187+
"eu-south-1",
188+
"eu-west-1",
189+
"eu-west-2",
190+
"eu-west-3",
191+
"me-south-1",
192+
"sa-east-1",
193+
"us-east-1",
194+
"us-east-2",
195+
"us-gov-west-1",
196+
"us-iso-east-1",
197+
"us-west-1",
198+
"us-west-2",
199+
}
200+
201+
202+
# TODO: Consider combining with GRAVITON_ALLOWED_FRAMEWORKS into a dictionary
203+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.5-1", "1.3-1"]
204+
205+
206+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.0-1"]
207+
208+
172209
def validate_source_code_input_against_pipeline_variables(
173210
entry_point: Optional[Union[str, PipelineVariable]] = None,
174211
source_dir: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/image_uris.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,20 @@
2525
from sagemaker.jumpstart import artifacts
2626
from sagemaker.workflow import is_pipeline_variable
2727
from sagemaker.workflow.utilities import override_pipeline_parameter_var
28-
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
28+
from sagemaker.fw_utils import (
29+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
30+
GRAVITON_ALLOWED_FRAMEWORKS,
31+
GRAVITON_ALLOWED_REGIONS,
32+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS,
33+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS,
34+
)
2935

3036
logger = logging.getLogger(__name__)
3137

3238
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3339
HUGGING_FACE_FRAMEWORK = "huggingface"
40+
XGBOOST_FRAMEWORK = "xgboost"
41+
SKLEARN_FRAMEWORK = "sklearn"
3442

3543

3644
@override_pipeline_parameter_var
@@ -244,6 +252,25 @@ def retrieve(
244252
if key in container_versions:
245253
tag = "-".join([tag, container_versions[key]])
246254

255+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
256+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
257+
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
258+
_validate_arg(region, GRAVITON_ALLOWED_REGIONS, "Graviton region")
259+
if framework == XGBOOST_FRAMEWORK:
260+
_validate_arg(
261+
version,
262+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS,
263+
"xgboost version for Graviton instances"
264+
)
265+
tag = f"{version}-arm64"
266+
else:
267+
_validate_arg(
268+
version,
269+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS,
270+
"sklearn version for Graviton instances"
271+
)
272+
tag = f"{version}-arm64-cpu-py3"
273+
247274
if tag:
248275
repo += ":{}".format(tag)
249276

tests/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botocore.config import Config
2323
from packaging.version import Version
2424

25-
from sagemaker import Session, image_uris, utils
25+
from sagemaker import Session, fw_utils, image_uris, utils
2626
from sagemaker.local import LocalSession
2727
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession
2828

@@ -318,6 +318,16 @@ def graviton_pytorch_version():
318318
return "1.12.1"
319319

320320

321+
@pytest.fixture(scope="module")
322+
def graviton_xgboost_versions():
323+
return fw_utils.XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS
324+
325+
326+
@pytest.fixture(scope="module")
327+
def graviton_sklearn_versions():
328+
return fw_utils.SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS
329+
330+
321331
@pytest.fixture(scope="module")
322332
def huggingface_tensorflow_latest_training_py_version():
323333
return "py38"

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,31 @@ def test_graviton_pytorch(graviton_pytorch_version):
7373
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
7474

7575

76+
def test_graviton_xgboost(graviton_xgboost_versions):
77+
for xgboost_version in graviton_xgboost_versions:
78+
for instance_type in GRAVITON_INSTANCE_TYPES:
79+
uri = image_uris.retrieve(
80+
"xgboost", "us-west-2", version=xgboost_version, instance_type=instance_type
81+
)
82+
expected = (
83+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
84+
f"{xgboost_version}-arm64")
85+
assert expected == uri
86+
87+
88+
def test_graviton_sklearn(graviton_sklearn_versions):
89+
for sklearn_version in graviton_sklearn_versions:
90+
for instance_type in GRAVITON_INSTANCE_TYPES:
91+
uri = image_uris.retrieve(
92+
"sklearn", "us-west-2", version=sklearn_version, instance_type=instance_type
93+
)
94+
expected = (
95+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
96+
f"{sklearn_version}-arm64-cpu-py3"
97+
)
98+
assert expected == uri
99+
100+
76101
def _expected_graviton_framework_uri(framework, version, region):
77102
return expected_uris.graviton_framework_uri(
78103
"{}-inference-graviton".format(framework),

0 commit comments

Comments
 (0)