diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 247b04cf79..38006ea136 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -274,9 +274,18 @@ def _compilation_image_uri(self, region, target_instance_type, framework, framew framework (str): The framework name. framework_version (str): The framework version. """ - framework_prefix = "inferentia-" if target_instance_type.startswith("ml_inf") else "neo-" + framework_prefix = "" + framework_suffix = "" + + if framework == "xgboost": + framework_suffix = "-neo" + elif target_instance_type.startswith("ml_inf"): + framework_prefix = "inferentia-" + else: + framework_prefix = "neo-" + return image_uris.retrieve( - "{}{}".format(framework_prefix, framework), + "{}{}{}".format(framework_prefix, framework, framework_suffix), region, instance_type=target_instance_type, version=framework_version,