
Commit 0201f2a

chrstfubenieric authored and committed
feat: AMI support for BRM (#1589)
* feat: AMI support for BRM
1 parent 9a7b6de · commit 0201f2a

7 files changed: +22 -1 lines changed

src/sagemaker/jumpstart/factory/model.py (+2)

@@ -662,6 +662,7 @@ def get_deploy_kwargs(
     config_name: Optional[str] = None,
     routing_config: Optional[Dict[str, Any]] = None,
     model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
+    inference_ami_version: Optional[str] = None,
 ) -> JumpStartModelDeployKwargs:
     """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

@@ -699,6 +700,7 @@ def get_deploy_kwargs(
         config_name=config_name,
         routing_config=routing_config,
         model_access_configs=model_access_configs,
+        inference_ami_version=inference_ami_version,
     )
     deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
     deploy_kwargs.specs = verify_model_region_and_return_specs(
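The hunks above only thread the new optional keyword from the factory signature into the kwargs object. A minimal, self-contained sketch of that forwarding pattern (the class and function below are stand-ins, not the real SageMaker types, and the AMI version string is only an example value):

from typing import Optional


class FakeDeployKwargs:
    """Stand-in for JumpStartModelDeployKwargs, modeling only the new field."""

    def __init__(self, inference_ami_version: Optional[str] = None) -> None:
        self.inference_ami_version = inference_ami_version


def get_deploy_kwargs(inference_ami_version: Optional[str] = None) -> FakeDeployKwargs:
    # The real factory collects many more fields; only the new one is forwarded here.
    return FakeDeployKwargs(inference_ami_version=inference_ami_version)


kwargs = get_deploy_kwargs(inference_ami_version="al2-ami-sagemaker-inference-gpu-2")
print(kwargs.inference_ami_version)  # al2-ami-sagemaker-inference-gpu-2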

src/sagemaker/jumpstart/hub/interfaces.py (+3)

@@ -471,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
         "hosting_use_script_uri",
         "hosting_eula_uri",
         "hosting_model_package_arn",
+        "inference_ami_version",
         "model_subscription_link",
         "inference_configs",
         "inference_config_components",

@@ -593,6 +594,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
         self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
         self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

+        self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")
+
         self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

         self.inference_config_rankings = self._get_config_rankings(json_obj)
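The `from_json` addition copies the hub document key `InferenceAmiVersion` onto the snake_case attribute `inference_ami_version`. A hedged sketch of that mapping (the class below is a toy stand-in and the payload is made up, not a real hub document):

from typing import Any, Dict, Optional


class MiniHubModelDocument:
    """Toy stand-in for HubModelDocument, parsing only the new field."""

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        # .get() mirrors the diff: a missing key simply leaves the attribute as None.
        self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")


doc = MiniHubModelDocument()
doc.from_json({"InferenceAmiVersion": "al2-ami-sagemaker-inference-gpu-2"})
print(doc.inference_ami_version)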

src/sagemaker/jumpstart/hub/parsers.py (+5 -1)

@@ -72,7 +72,11 @@ def get_model_spec_arg_keys(
     """
     arg_keys: List[str] = []
     if arg_type == ModelSpecKwargType.DEPLOY:
-        arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"]
+        arg_keys = [
+            "ModelDataDownloadTimeout",
+            "ContainerStartupHealthCheckTimeout",
+            "InferenceAmiVersion",
+        ]
     elif arg_type == ModelSpecKwargType.ESTIMATOR:
         arg_keys = [
             "EncryptInterContainerTraffic",

src/sagemaker/jumpstart/model.py (+2)

@@ -666,6 +666,7 @@ def deploy(
         endpoint_type: EndpointType = EndpointType.MODEL_BASED,
         routing_config: Optional[Dict[str, Any]] = None,
         model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
+        inference_ami_version: Optional[str] = None,
     ) -> PredictorBase:
         """Creates endpoint by calling base ``Model`` class `deploy` method.

@@ -808,6 +809,7 @@ def deploy(
             config_name=self.config_name,
             routing_config=routing_config,
             model_access_configs=model_access_configs,
+            inference_ami_version=inference_ami_version,
         )
         if (
             self.model_type == JumpStartModelType.PROPRIETARY
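With the two hunks above, the option surfaces on the public `JumpStartModel.deploy` call. A hedged usage sketch, assuming an AWS session with deployment permissions is already configured; the model id, instance type, and AMI version are example values to replace with ones valid for your account and region:

from sagemaker.jumpstart.model import JumpStartModel

# Example model id; any JumpStart model that supports the chosen instance type works.
model = JumpStartModel(model_id="meta-textgeneration-llama-3-8b")

predictor = model.deploy(
    instance_type="ml.g5.2xlarge",
    inference_ami_version="al2-ami-sagemaker-inference-gpu-2",  # example AMI version
)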

src/sagemaker/jumpstart/types.py (+4)

@@ -1389,6 +1389,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
         self.hosting_model_package_arns: Optional[Dict] = (
             model_package_arns if model_package_arns is not None else {}
         )
+
         self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)

         self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (

@@ -2245,6 +2246,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
         "routing_config",
         "specs",
         "model_access_configs",
+        "inference_ami_version",
     ]

     SERIALIZATION_EXCLUSION_SET = {

@@ -2298,6 +2300,7 @@ def __init__(
         config_name: Optional[str] = None,
         routing_config: Optional[Dict[str, Any]] = None,
         model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
+        inference_ami_version: Optional[str] = None,
     ) -> None:
         """Instantiates JumpStartModelDeployKwargs object."""

@@ -2336,6 +2339,7 @@ def __init__(
         self.config_name = config_name
         self.routing_config = routing_config
         self.model_access_configs = model_access_configs
+        self.inference_ami_version = inference_ami_version


 class JumpStartEstimatorInitKwargs(JumpStartKwargs):
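Beyond storing the attribute, the field name also has to appear in the kwargs object's serializable-field list or it would be dropped when the object is converted back into `deploy` arguments. A toy sketch of that list-driven pattern (illustrative only; the real base class is `JumpStartKwargs`):

from typing import Any, Dict, Optional


class MiniDeployKwargs:
    """Toy kwargs holder: only names listed in SERIALIZABLE_FIELDS are emitted."""

    SERIALIZABLE_FIELDS = ["routing_config", "model_access_configs", "inference_ami_version"]

    def __init__(
        self,
        routing_config: Optional[Dict[str, Any]] = None,
        model_access_configs: Optional[Dict[str, Any]] = None,
        inference_ami_version: Optional[str] = None,
    ) -> None:
        self.routing_config = routing_config
        self.model_access_configs = model_access_configs
        self.inference_ami_version = inference_ami_version

    def to_kwargs_dict(self) -> Dict[str, Any]:
        # Fields not in the list never reach the downstream deploy call.
        return {name: getattr(self, name) for name in self.SERIALIZABLE_FIELDS}


print(MiniDeployKwargs(inference_ami_version="al2-ami-sagemaker-inference-gpu-2").to_kwargs_dict())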

src/sagemaker/model.py (+2)

@@ -1383,6 +1383,7 @@ def deploy(
         inference_component_name=None,
         routing_config: Optional[Dict[str, Any]] = None,
         model_reference_arn: Optional[str] = None,
+        inference_ami_version: Optional[str] = None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

@@ -1652,6 +1653,7 @@ def deploy(
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             managed_instance_scaling=managed_instance_scaling_config,
             routing_config=routing_config,
+            inference_ami_version=inference_ami_version,
         )

         self.sagemaker_session.endpoint_from_production_variants(
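The same knob is exposed on the generic `sagemaker.model.Model.deploy` path, which forwards it into the production variant built below. A hedged usage sketch; the image URI, model artifact, role, and AMI version are placeholders to fill in for your own account:

from sagemaker.model import Model

model = Model(
    image_uri="<ecr-inference-image-uri>",      # placeholder
    model_data="s3://<bucket>/<model>.tar.gz",  # placeholder
    role="<sagemaker-execution-role-arn>",      # placeholder
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    inference_ami_version="al2-ami-sagemaker-inference-gpu-2",  # example AMI version
)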

src/sagemaker/session.py (+4)

@@ -7735,6 +7735,7 @@ def production_variant(
     container_startup_health_check_timeout=None,
     managed_instance_scaling=None,
     routing_config=None,
+    inference_ami_version=None,
 ):
     """Create a production variant description suitable for use in a ``ProductionVariant`` list.

@@ -7799,6 +7800,9 @@ def production_variant(
         RoutingConfig=routing_config,
     )

+    if inference_ami_version:
+        production_variant_configuration["InferenceAmiVersion"] = inference_ami_version
+
     return production_variant_configuration
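`production_variant` is where the value finally lands in the request: when set, it is emitted as the `InferenceAmiVersion` key of the variant dict that is later passed to endpoint-config creation. Since the function only builds a dict, this small check runs locally (the model name and AMI version are example values):

from sagemaker.session import production_variant

variant = production_variant(
    model_name="my-model",          # example name
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    inference_ami_version="al2-ami-sagemaker-inference-gpu-2",  # example AMI version
)
print(variant["InferenceAmiVersion"])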
