Skip to content

Commit ced2e8f

Browse files
author
Mark Bunday
committed
feature: Graviton support for XGB and SKLearn frameworks
1 parent 5adc2d3 commit ced2e8f

File tree

4 files changed

+109
-5
lines changed

4 files changed

+109
-5
lines changed

src/sagemaker/fw_utils.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,10 @@
134134
"1.12.0",
135135
]
136136

137-
138137
TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
139138

140-
141139
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
142140

143-
144141
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
145142

146143

@@ -169,6 +166,43 @@ def validate_source_dir(script, directory):
169166
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])
170167

171168

169+
GRAVITON_ALLOWED_REGIONS = {
170+
"af-south-1",
171+
"ap-east-1",
172+
"ap-northeast-1",
173+
"ap-northeast-2",
174+
"ap-northeast-3",
175+
"ap-south-1",
176+
"ap-southeast-1",
177+
"ap-southeast-2",
178+
"ap-southeast-3",
179+
"ca-central-1",
180+
"cn-north-1",
181+
"cn-northwest-1",
182+
"eu-central-1",
183+
"eu-north-1",
184+
"eu-south-1",
185+
"eu-west-1",
186+
"eu-west-2",
187+
"eu-west-3",
188+
"me-south-1",
189+
"sa-east-1",
190+
"us-east-1",
191+
"us-east-2",
192+
"us-gov-west-1",
193+
"us-iso-east-1",
194+
"us-west-1",
195+
"us-west-2",
196+
}
197+
198+
199+
# TODO: Consider combining with GRAVITON_ALLOWED_FRAMEWORKS into a dictionary
200+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.5-1", "1.3-1"]
201+
202+
203+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS = ["1.0-1"]
204+
205+
172206
def validate_source_code_input_against_pipeline_variables(
173207
entry_point: Optional[Union[str, PipelineVariable]] = None,
174208
source_dir: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/image_uris.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,20 @@
2525
from sagemaker.jumpstart import artifacts
2626
from sagemaker.workflow import is_pipeline_variable
2727
from sagemaker.workflow.utilities import override_pipeline_parameter_var
28-
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
28+
from sagemaker.fw_utils import (
29+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
30+
GRAVITON_ALLOWED_FRAMEWORKS,
31+
GRAVITON_ALLOWED_REGIONS,
32+
SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS,
33+
XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS,
34+
)
2935

3036
logger = logging.getLogger(__name__)
3137

3238
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3339
HUGGING_FACE_FRAMEWORK = "huggingface"
40+
XGBOOST_FRAMEWORK = "xgboost"
41+
SKLEARN_FRAMEWORK = "sklearn"
3442

3543

3644
@override_pipeline_parameter_var
@@ -244,6 +252,35 @@ def retrieve(
244252
if key in container_versions:
245253
tag = "-".join([tag, container_versions[key]])
246254

255+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
256+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
257+
if match and match[1] in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
258+
if region not in GRAVITON_ALLOWED_REGIONS:
259+
raise ValueError(
260+
f"Unsupported Graviton region: {region}. "
261+
"You may need to upgrade your SDK version (pip install -U sagemaker) "
262+
"for newer regions. Graviton supported region(s): "
263+
f"{', '.join(set(GRAVITON_ALLOWED_REGIONS))}."
264+
)
265+
if framework == XGBOOST_FRAMEWORK:
266+
if version not in XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS:
267+
raise ValueError(
268+
f"Unsupported xgboost version for graviton instances: {version}. "
269+
"You may need to upgrade your SDK version (pip install -U sagemaker) "
270+
"for newer versions. Supported version(s): "
271+
f"{', '.join(XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS)}."
272+
)
273+
tag = f"{version}-arm64"
274+
else:
275+
if version not in SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS:
276+
raise ValueError(
277+
f"Unsupported sklearn version for graviton instances: {version}. "
278+
"You may need to upgrade your SDK version (pip install -U sagemaker) "
279+
"for newer versions. Supported version(s): "
280+
f"{', '.join(SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS)}."
281+
)
282+
tag = f"{version}-arm64-cpu-py3"
283+
247284
if tag:
248285
repo += ":{}".format(tag)
249286

tests/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botocore.config import Config
2323
from packaging.version import Version
2424

25-
from sagemaker import Session, image_uris, utils
25+
from sagemaker import Session, fw_utils, image_uris, utils
2626
from sagemaker.local import LocalSession
2727
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession
2828

@@ -318,6 +318,16 @@ def graviton_pytorch_version():
318318
return "1.12.1"
319319

320320

321+
@pytest.fixture(scope="module")
322+
def graviton_xgboost_versions():
323+
return fw_utils.XGBOOST_GRAVITON_INFERENCE_ENABLED_VERSIONS
324+
325+
326+
@pytest.fixture(scope="module")
327+
def graviton_sklearn_versions():
328+
return fw_utils.SKLEARN_GRAVITON_INFERENCE_ENABLED_VERSIONS
329+
330+
321331
@pytest.fixture(scope="module")
322332
def huggingface_tensorflow_latest_training_py_version():
323333
return "py38"

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,29 @@ def test_graviton_pytorch(graviton_pytorch_version):
7373
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
7474

7575

76+
def test_graviton_xgboost(graviton_xgboost_versions):
77+
for xgboost_version in graviton_xgboost_versions:
78+
for instance_type in GRAVITON_INSTANCE_TYPES:
79+
uri = image_uris.retrieve(
80+
"xgboost", "us-west-2", version=xgboost_version, instance_type=instance_type
81+
)
82+
expected = f"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{xgboost_version}-arm64"
83+
assert expected == uri
84+
85+
86+
def test_graviton_sklearn(graviton_sklearn_versions):
87+
for sklearn_version in graviton_sklearn_versions:
88+
for instance_type in GRAVITON_INSTANCE_TYPES:
89+
uri = image_uris.retrieve(
90+
"sklearn", "us-west-2", version=sklearn_version, instance_type=instance_type
91+
)
92+
expected = (
93+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
94+
f"{sklearn_version}-arm64-cpu-py3"
95+
)
96+
assert expected == uri
97+
98+
7699
def _expected_graviton_framework_uri(framework, version, region):
77100
return expected_uris.graviton_framework_uri(
78101
"{}-inference-graviton".format(framework),

0 commit comments

Comments
 (0)