Skip to content

Commit 9ead9c8

Browse files
rsareddy0329 (Roja Reddy Sareddy)
and
Roja Reddy Sareddy
authored
feature: Enabled update_endpoint through model_builder (aws#5085)
* feature: Enabled update_endpoint through model_builder * fix: fix unit test, black-check, pylint errors * fix: fix black-check, pylint errors --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent 65482fa commit 9ead9c8

File tree

8 files changed

+330
-18
lines changed

8 files changed

+330
-18
lines changed

src/sagemaker/huggingface/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def deploy(
218218
container_startup_health_check_timeout=None,
219219
inference_recommendation_id=None,
220220
explainer_config=None,
221+
update_endpoint: Optional[bool] = False,
221222
**kwargs,
222223
):
223224
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -296,6 +297,11 @@ def deploy(
296297
would like to deploy the model and endpoint with recommended parameters.
297298
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
298299
configuration for use with Amazon SageMaker Clarify. (default: None)
300+
update_endpoint (Optional[bool]):
301+
Flag to update the model in an existing Amazon SageMaker endpoint.
302+
If True, this will deploy a new EndpointConfig to an already existing endpoint
303+
and delete resources corresponding to the previous EndpointConfig. Default: False
304+
Note: Currently this is supported for single model endpoints
299305
Raises:
300306
ValueError: If arguments combination check failed in these circumstances:
301307
- If no role is specified or
@@ -335,6 +341,7 @@ def deploy(
335341
container_startup_health_check_timeout=container_startup_health_check_timeout,
336342
inference_recommendation_id=inference_recommendation_id,
337343
explainer_config=explainer_config,
344+
update_endpoint=update_endpoint,
338345
**kwargs,
339346
)
340347

src/sagemaker/model.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
5454
from sagemaker.session import Session
5555
from sagemaker.model_metrics import ModelMetrics
56-
from sagemaker.deprecations import removed_kwargs
5756
from sagemaker.drift_check_baselines import DriftCheckBaselines
5857
from sagemaker.explainer import ExplainerConfig
5958
from sagemaker.metadata_properties import MetadataProperties
@@ -1386,6 +1385,7 @@ def deploy(
13861385
routing_config: Optional[Dict[str, Any]] = None,
13871386
model_reference_arn: Optional[str] = None,
13881387
inference_ami_version: Optional[str] = None,
1388+
update_endpoint: Optional[bool] = False,
13891389
**kwargs,
13901390
):
13911391
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1497,6 +1497,11 @@ def deploy(
14971497
inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured
14981498
Amazon Machine Image (AMI) images. For a full list of options, see:
14991499
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
1500+
update_endpoint (Optional[bool]):
1501+
Flag to update the model in an existing Amazon SageMaker endpoint.
1502+
If True, this will deploy a new EndpointConfig to an already existing endpoint
1503+
and delete resources corresponding to the previous EndpointConfig. Default: False
1504+
Note: Currently this is supported for single model endpoints
15001505
Raises:
15011506
ValueError: If arguments combination check failed in these circumstances:
15021507
- If no role is specified or
@@ -1512,8 +1517,6 @@ def deploy(
15121517
"""
15131518
self.accept_eula = accept_eula
15141519

1515-
removed_kwargs("update_endpoint", kwargs)
1516-
15171520
self._init_sagemaker_session_if_does_not_exist(instance_type)
15181521
# Depending on the instance type, a local session (or) a session is initialized.
15191522
self.role = resolve_value_from_config(
@@ -1628,6 +1631,10 @@ def deploy(
16281631

16291632
# Support multiple models on same endpoint
16301633
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
1634+
if update_endpoint:
1635+
raise ValueError(
1636+
"Currently update_endpoint is supported for single model endpoints"
1637+
)
16311638
if endpoint_name:
16321639
self.endpoint_name = endpoint_name
16331640
else:
@@ -1783,17 +1790,38 @@ def deploy(
17831790
if is_explainer_enabled:
17841791
explainer_config_dict = explainer_config._to_request_dict()
17851792

1786-
self.sagemaker_session.endpoint_from_production_variants(
1787-
name=self.endpoint_name,
1788-
production_variants=[production_variant],
1789-
tags=tags,
1790-
kms_key=kms_key,
1791-
wait=wait,
1792-
data_capture_config_dict=data_capture_config_dict,
1793-
explainer_config_dict=explainer_config_dict,
1794-
async_inference_config_dict=async_inference_config_dict,
1795-
live_logging=endpoint_logging,
1796-
)
1793+
if update_endpoint:
1794+
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
1795+
name=self.name,
1796+
model_name=self.name,
1797+
initial_instance_count=initial_instance_count,
1798+
instance_type=instance_type,
1799+
accelerator_type=accelerator_type,
1800+
tags=tags,
1801+
kms_key=kms_key,
1802+
data_capture_config_dict=data_capture_config_dict,
1803+
volume_size=volume_size,
1804+
model_data_download_timeout=model_data_download_timeout,
1805+
container_startup_health_check_timeout=container_startup_health_check_timeout,
1806+
explainer_config_dict=explainer_config_dict,
1807+
async_inference_config_dict=async_inference_config_dict,
1808+
serverless_inference_config=serverless_inference_config_dict,
1809+
routing_config=routing_config,
1810+
inference_ami_version=inference_ami_version,
1811+
)
1812+
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
1813+
else:
1814+
self.sagemaker_session.endpoint_from_production_variants(
1815+
name=self.endpoint_name,
1816+
production_variants=[production_variant],
1817+
tags=tags,
1818+
kms_key=kms_key,
1819+
wait=wait,
1820+
data_capture_config_dict=data_capture_config_dict,
1821+
explainer_config_dict=explainer_config_dict,
1822+
async_inference_config_dict=async_inference_config_dict,
1823+
live_logging=endpoint_logging,
1824+
)
17971825

17981826
if self.predictor_cls:
17991827
predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session)

src/sagemaker/serve/builder/model_builder.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,7 @@ def deploy(
16021602
ResourceRequirements,
16031603
]
16041604
] = None,
1605+
update_endpoint: Optional[bool] = False,
16051606
) -> Union[Predictor, Transformer]:
16061607
"""Deploys the built Model.
16071608
@@ -1615,24 +1616,33 @@ def deploy(
16151616
AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) :
16161617
Additional Config for different deployment types such as
16171618
serverless, async, batch and multi-model/container
1619+
update_endpoint (Optional[bool]):
1620+
Flag to update the model in an existing Amazon SageMaker endpoint.
1621+
If True, this will deploy a new EndpointConfig to an already existing endpoint
1622+
and delete resources corresponding to the previous EndpointConfig. Default: False
1623+
Note: Currently this is supported for single model endpoints
16181624
Returns:
16191625
Transformer for Batch Deployments
16201626
Predictors for all others
16211627
"""
16221628
if not hasattr(self, "built_model"):
16231629
raise ValueError("Model Needs to be built before deploying")
1624-
endpoint_name = unique_name_from_base(endpoint_name)
1630+
if not update_endpoint:
1631+
endpoint_name = unique_name_from_base(endpoint_name)
1632+
16251633
if not inference_config: # Real-time Deployment
16261634
return self.built_model.deploy(
16271635
instance_type=self.instance_type,
16281636
initial_instance_count=initial_instance_count,
16291637
endpoint_name=endpoint_name,
1638+
update_endpoint=update_endpoint,
16301639
)
16311640

16321641
if isinstance(inference_config, ServerlessInferenceConfig):
16331642
return self.built_model.deploy(
16341643
serverless_inference_config=inference_config,
16351644
endpoint_name=endpoint_name,
1645+
update_endpoint=update_endpoint,
16361646
)
16371647

16381648
if isinstance(inference_config, AsyncInferenceConfig):
@@ -1641,6 +1651,7 @@ def deploy(
16411651
initial_instance_count=initial_instance_count,
16421652
async_inference_config=inference_config,
16431653
endpoint_name=endpoint_name,
1654+
update_endpoint=update_endpoint,
16441655
)
16451656

16461657
if isinstance(inference_config, BatchTransformInferenceConfig):
@@ -1652,6 +1663,10 @@ def deploy(
16521663
return transformer
16531664

16541665
if isinstance(inference_config, ResourceRequirements):
1666+
if update_endpoint:
1667+
raise ValueError(
1668+
"Currently update_endpoint is supported for single model endpoints"
1669+
)
16551670
# Multi Model and MultiContainer endpoints with Inference Component
16561671
return self.built_model.deploy(
16571672
instance_type=self.instance_type,
@@ -1660,6 +1675,7 @@ def deploy(
16601675
resources=inference_config,
16611676
initial_instance_count=initial_instance_count,
16621677
role=self.role_arn,
1678+
update_endpoint=update_endpoint,
16631679
)
16641680

16651681
raise ValueError("Deployment Options not supported")

src/sagemaker/session.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4488,6 +4488,10 @@ def create_endpoint_config(
44884488
model_data_download_timeout=None,
44894489
container_startup_health_check_timeout=None,
44904490
explainer_config_dict=None,
4491+
async_inference_config_dict=None,
4492+
serverless_inference_config_dict=None,
4493+
routing_config: Optional[Dict[str, Any]] = None,
4494+
inference_ami_version: Optional[str] = None,
44914495
):
44924496
"""Create an Amazon SageMaker endpoint configuration.
44934497
@@ -4525,6 +4529,30 @@ def create_endpoint_config(
45254529
-inference-algo-ping-requests
45264530
explainer_config_dict (dict): Specifies configuration to enable explainers.
45274531
Default: None.
4532+
async_inference_config_dict (dict): Specifies
4533+
configuration related to async endpoint. Use this configuration when trying
4534+
to create async endpoint and make async inference. If empty config object
4535+
passed through, will use default config to deploy async endpoint. Deploy a
4536+
real-time endpoint if it's None. (default: None).
4537+
serverless_inference_config_dict (dict):
4538+
Specifies configuration related to serverless endpoint. Use this configuration
4539+
when trying to create serverless endpoint and make serverless inference. If
4540+
empty object passed through, will use pre-defined values in
4541+
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
4542+
instance based endpoint if it's None. (default: None).
4543+
routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes
4544+
incoming traffic to the instances that the endpoint hosts.
4545+
Currently, support dictionary key ``RoutingStrategy``.
4546+
4547+
.. code:: python
4548+
4549+
{
4550+
"RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
4551+
}
4552+
inference_ami_version (Optional [str]):
4553+
Specifies an option from a collection of preconfigured
4554+
Amazon Machine Image (AMI) images. For a full list of options, see:
4555+
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
45284556
45294557
Example:
45304558
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -4544,9 +4572,12 @@ def create_endpoint_config(
45444572
instance_type,
45454573
initial_instance_count,
45464574
accelerator_type=accelerator_type,
4575+
serverless_inference_config=serverless_inference_config_dict,
45474576
volume_size=volume_size,
45484577
model_data_download_timeout=model_data_download_timeout,
45494578
container_startup_health_check_timeout=container_startup_health_check_timeout,
4579+
routing_config=routing_config,
4580+
inference_ami_version=inference_ami_version,
45504581
)
45514582
production_variants = [provided_production_variant]
45524583
# Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant.
@@ -4586,6 +4617,14 @@ def create_endpoint_config(
45864617
)
45874618
request["DataCaptureConfig"] = inferred_data_capture_config_dict
45884619

4620+
if async_inference_config_dict is not None:
4621+
inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
4622+
async_inference_config_dict,
4623+
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
4624+
sagemaker_session=self,
4625+
)
4626+
request["AsyncInferenceConfig"] = inferred_async_inference_config_dict
4627+
45894628
if explainer_config_dict is not None:
45904629
request["ExplainerConfig"] = explainer_config_dict
45914630

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def deploy(
358358
container_startup_health_check_timeout=None,
359359
inference_recommendation_id=None,
360360
explainer_config=None,
361+
update_endpoint: Optional[bool] = False,
361362
**kwargs,
362363
):
363364
"""Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
@@ -383,6 +384,7 @@ def deploy(
383384
container_startup_health_check_timeout=container_startup_health_check_timeout,
384385
inference_recommendation_id=inference_recommendation_id,
385386
explainer_config=explainer_config,
387+
update_endpoint=update_endpoint,
386388
**kwargs,
387389
)
388390

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
794794
and reach out to JumpStart team."""
795795

796796
init_args_to_skip: Set[str] = set(["model_reference_arn"])
797-
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"])
797+
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn", "update_endpoint"])
798798
deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"])
799799

800800
parent_class_init = Model.__init__

0 commit comments

Comments (0)