Skip to content

Commit 8fff159

Browse files
nikhil-skNikhil Kulkarni
and
Nikhil Kulkarni
authored
fix: add a condition to retrieve correct image URI for xgboost (#1938)
Co-authored-by: Nikhil Kulkarni <[email protected]>
1 parent 7e2ee02 commit 8fff159

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/sagemaker/model.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,18 @@ def _compilation_image_uri(self, region, target_instance_type, framework, framew
274274
framework (str): The framework name.
275275
framework_version (str): The framework version.
276276
"""
277-
framework_prefix = "inferentia-" if target_instance_type.startswith("ml_inf") else "neo-"
277+
framework_prefix = ""
278+
framework_suffix = ""
279+
280+
if framework == "xgboost":
281+
framework_suffix = "-neo"
282+
elif target_instance_type.startswith("ml_inf"):
283+
framework_prefix = "inferentia-"
284+
else:
285+
framework_prefix = "neo-"
286+
278287
return image_uris.retrieve(
279-
"{}{}".format(framework_prefix, framework),
288+
"{}{}{}".format(framework_prefix, framework, framework_suffix),
280289
region,
281290
instance_type=target_instance_type,
282291
version=framework_version,

0 commit comments

Comments
 (0)