Skip to content

Commit 9fb828a

Browse files
committed
fix: update_endpoint for PipelineModel
1 parent c53472b commit 9fb828a

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

src/sagemaker/pipeline.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ def pipeline_container_def(self, instance_type):
8383
return sagemaker.pipeline_container_def(self.models, instance_type)
8484

8585
def deploy(
86-
self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True
86+
self,
87+
initial_instance_count,
88+
instance_type,
89+
endpoint_name=None,
90+
tags=None,
91+
wait=True,
92+
update_endpoint=False,
8793
):
8894
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a
8995
``Predictor``.
@@ -110,6 +116,11 @@ def deploy(
110116
specific endpoint.
111117
wait (bool): Whether the call should wait until the deployment of
112118
model completes (default: True).
119+
update_endpoint (bool): Flag to update the model in an existing
120+
Amazon SageMaker endpoint. If True, this will deploy a new
121+
EndpointConfig to an already existing endpoint and delete
122+
resources corresponding to the previous EndpointConfig. If
123+
False, a new endpoint will be created. Default: False
113124
114125
Returns:
115126
callable[string, sagemaker.session.Session] or None: Invocation of
@@ -130,9 +141,21 @@ def deploy(
130141
self.name, instance_type, initial_instance_count
131142
)
132143
self.endpoint_name = endpoint_name or self.name
133-
self.sagemaker_session.endpoint_from_production_variants(
134-
self.endpoint_name, [production_variant], tags, wait=wait
135-
)
144+
145+
if update_endpoint:
146+
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
147+
name=self.name,
148+
model_name=self.name,
149+
initial_instance_count=initial_instance_count,
150+
instance_type=instance_type,
151+
tags=tags,
152+
)
153+
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
154+
else:
155+
self.sagemaker_session.endpoint_from_production_variants(
156+
self.endpoint_name, [production_variant], tags, wait=wait
157+
)
158+
136159
if self.predictor_cls:
137160
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
138161
return None

0 commit comments

Comments
 (0)