Skip to content

Commit ecfa7b5

Browse files
knikure, JohnaAtAWS, verayu43, bhaoz
committed
feat: Goldfinch InferenceComponent integration (#1311)
Co-authored-by: JohnaAtAWS <[email protected]> Co-authored-by: Vera Yu <[email protected]> Co-authored-by: bhaoz <[email protected]>
1 parent 055191c commit ecfa7b5

28 files changed

+1842
-112
lines changed

src/sagemaker/base_predictor.py

+182-17
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
from __future__ import print_function, absolute_import
1515

1616
import abc
17-
from typing import Any, Optional, Tuple, Union
17+
from typing import Any, Dict, Optional, Tuple, Union
18+
import logging
1819

20+
from sagemaker.enums import EndpointType
1921
from sagemaker.deprecations import (
2022
deprecated_class,
2123
deprecated_deserialize,
@@ -55,6 +57,9 @@
5557
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5658

5759
from sagemaker.lineage.context import EndpointContext
60+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
61+
62+
LOGGER = logging.getLogger("sagemaker")
5863

5964

6065
class PredictorBase(abc.ABC):
@@ -92,6 +97,7 @@ def __init__(
9297
sagemaker_session=None,
9398
serializer=IdentitySerializer(),
9499
deserializer=BytesDeserializer(),
100+
component_name=None,
95101
**kwargs,
96102
):
97103
"""Initialize a ``Predictor``.
@@ -115,11 +121,14 @@ def __init__(
115121
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
116122
deserializer object, used to decode data from an inference
117123
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
124+
component_name (str): Name of the Amazon SageMaker inference component
125+
corresponding to the predictor.
118126
"""
119127
removed_kwargs("content_type", kwargs)
120128
removed_kwargs("accept", kwargs)
121129
endpoint_name = renamed_kwargs("endpoint", "endpoint_name", endpoint_name, kwargs)
122130
self.endpoint_name = endpoint_name
131+
self.component_name = component_name
123132
self.sagemaker_session = sagemaker_session or Session()
124133
self.serializer = serializer
125134
self.deserializer = deserializer
@@ -137,6 +146,7 @@ def predict(
137146
target_variant=None,
138147
inference_id=None,
139148
custom_attributes=None,
149+
component_name: Optional[str] = None,
140150
):
141151
"""Return the inference from the specified endpoint.
142152
@@ -169,22 +179,29 @@ def predict(
169179
value is returned. For example, if a custom attribute represents the trace ID, your
170180
model can prepend the custom attribute with Trace ID: in your post-processing
171181
function (Default: None).
182+
component_name (str): Optional. Name of the Amazon SageMaker inference component
183+
corresponding to the predictor.
172184
173185
Returns:
174186
object: Inference for the given input. If a deserializer was specified when creating
175187
the Predictor, the result of the deserializer is
176188
returned. Otherwise the response returns the sequence of bytes
177189
as is.
178190
"""
179-
191+
# [TODO]: clean up component_name in _create_request_args
180192
request_args = self._create_request_args(
181-
data,
182-
initial_args,
183-
target_model,
184-
target_variant,
185-
inference_id,
186-
custom_attributes,
193+
data=data,
194+
initial_args=initial_args,
195+
target_model=target_model,
196+
target_variant=target_variant,
197+
inference_id=inference_id,
198+
custom_attributes=custom_attributes,
187199
)
200+
201+
inference_component_name = component_name or self._get_component_name()
202+
if inference_component_name:
203+
request_args["InferenceComponentName"] = inference_component_name
204+
188205
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
189206
return self._handle_response(response)
190207

@@ -260,6 +277,8 @@ def _create_request_args(
260277
if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data
261278
else self.serializer.serialize(data)
262279
)
280+
if self._get_component_name():
281+
args["InferenceComponentName"] = self.component_name
263282

264283
args["Body"] = data
265284
return args
@@ -273,6 +292,8 @@ def update_endpoint(
273292
tags=None,
274293
kms_key=None,
275294
data_capture_config_dict=None,
295+
max_instance_count=None,
296+
min_instance_count=None,
276297
wait=True,
277298
):
278299
"""Update the existing endpoint with the provided attributes.
@@ -310,6 +331,8 @@ def update_endpoint(
310331
data_capture_config_dict (dict): The endpoint data capture configuration
311332
for use with Amazon SageMaker Model Monitoring. If not specified,
312333
the data capture configuration of the existing endpoint configuration is used.
334+
max_instance_count (int): The maximum instance count used for scaling instance.
335+
min_instance_count (int): The minimum instance count used for scaling instance.
313336
314337
Raises:
315338
ValueError: If there is not enough information to create a new ``ProductionVariant``:
@@ -348,23 +371,45 @@ def update_endpoint(
348371
else:
349372
self._model_names = [model_name]
350373

351-
production_variant_config = production_variant(
352-
model_name,
353-
instance_type,
354-
initial_instance_count=initial_instance_count,
355-
accelerator_type=accelerator_type,
356-
)
374+
managed_instance_scaling = {}
375+
if max_instance_count:
376+
managed_instance_scaling["MaxInstanceCount"] = max_instance_count
377+
if min_instance_count:
378+
managed_instance_scaling["MinInstanceCount"] = min_instance_count
379+
380+
if managed_instance_scaling and len(managed_instance_scaling) > 0:
381+
production_variant_config = production_variant(
382+
model_name,
383+
instance_type,
384+
initial_instance_count=initial_instance_count,
385+
accelerator_type=accelerator_type,
386+
managed_instance_scaling=managed_instance_scaling,
387+
)
388+
else:
389+
production_variant_config = production_variant(
390+
model_name,
391+
instance_type,
392+
initial_instance_count=initial_instance_count,
393+
accelerator_type=accelerator_type,
394+
)
357395
production_variants = [production_variant_config]
358396

359397
current_endpoint_config_name = self._get_endpoint_config_name()
360398
new_endpoint_config_name = name_from_base(current_endpoint_config_name)
399+
400+
if self._get_component_name():
401+
endpoint_type = EndpointType.GOLDFINCH
402+
else:
403+
endpoint_type = EndpointType.OTHERS
404+
361405
self.sagemaker_session.create_endpoint_config_from_existing(
362406
current_endpoint_config_name,
363407
new_endpoint_config_name,
364408
new_tags=tags,
365409
new_kms_key=kms_key,
366410
new_data_capture_config_dict=data_capture_config_dict,
367411
new_production_variants=production_variants,
412+
endpoint_type=endpoint_type,
368413
)
369414
self.sagemaker_session.update_endpoint(
370415
self.endpoint_name, new_endpoint_config_name, wait=wait
@@ -393,10 +438,123 @@ def delete_endpoint(self, delete_endpoint_config=True):
393438

394439
self.sagemaker_session.delete_endpoint(self.endpoint_name)
395440

396-
delete_predictor = delete_endpoint
441+
def delete_predictor(self) -> None:
442+
"""Delete the Amazon SageMaker inference component or endpoint backing this predictor.
443+
444+
Delete the corresponding inference component if the endpoint is Goldfinch.
445+
Otherwise delete the endpoint where this predictor is on.
446+
"""
447+
# [TODO]: wait and describe inference component until not found to ensure
448+
# it gets deleted successfully. Throw appropriate exception/error type.
449+
if self.component_name:
450+
self.sagemaker_session.delete_inference_component(self.component_name)
451+
else:
452+
self.delete_endpoint()
453+
454+
def update_predictor(
455+
self,
456+
image_uri: Optional[str] = None,
457+
model_data: Optional[Union[str, dict]] = None,
458+
env: Optional[Dict[str, str]] = None,
459+
model_data_download_timeout: Optional[int] = None,
460+
container_startup_health_check_timeout: Optional[int] = None,
461+
resources: Optional[ResourceRequirements] = None,
462+
) -> str:
463+
"""Updates the predictor to deploy a new Model specification and apply new configurations.
464+
465+
This is done by updating the SageMaker InferenceComponent.
466+
467+
Args:
468+
image_uri (Optional[str]): A Docker image URI. (Default: None).
469+
model_data (Optional[Union[str, dict]]): Location
470+
of SageMaker model data. (Default: None).
471+
env (Optional[dict[str, str]]): Environment variables
472+
to run with ``image_uri`` when hosted in SageMaker. (Default: None).
473+
model_data_download_timeout (Optional[int]): The timeout value, in seconds, to download
474+
and extract model data from Amazon S3 to the individual inference instance
475+
associated with this production variant. (Default: None).
476+
container_startup_health_check_timeout (Optional[int]): The timeout value, in seconds,
477+
for your inference container to pass health check by SageMaker Hosting. For more
478+
information about health check see:
479+
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
480+
(Default: None).
481+
resources (Optional[ResourceRequirements]): The compute resource requirements
482+
for a model to be deployed to an endpoint. Only EndpointType.GOLDFINCH supports
483+
this feature. (Default: None).
484+
485+
Returns:
486+
String: The updated Amazon SageMaker Inference Component name
487+
"""
488+
if self.component_name is None:
489+
raise ValueError(
490+
"No existing Inference Component; "
491+
"Please ensure you deployed Inference Component first."
492+
)
493+
# [TODO]: Move to a module
494+
request = {
495+
"InferenceComponentName": self.component_name,
496+
"Specification": {},
497+
}
498+
499+
if resources:
500+
request["Specification"][
501+
"ComputeResourceRequirements"
502+
] = resources.get_compute_resource_requirements()
503+
504+
if image_uri:
505+
request["Specification"]["Container"]["Image"] = image_uri
506+
507+
if env:
508+
request["Specification"]["Container"]["Environment"] = env
509+
510+
if model_data:
511+
request["Specification"]["Container"]["ArtifactUrl"] = model_data
512+
513+
if resources.copy_count:
514+
request["RuntimeConfig"] = {"CopyCount": resources.copy_count}
515+
516+
if model_data_download_timeout:
517+
request["Specification"]["StartupParameters"][
518+
"ModelDataDownloadTimeoutInSeconds"
519+
] = model_data_download_timeout
520+
521+
if container_startup_health_check_timeout:
522+
request["Specification"]["StartupParameters"][
523+
"ContainerStartupHealthCheckTimeoutInSeconds"
524+
] = container_startup_health_check_timeout
525+
526+
empty_keys = []
527+
for key, value in request["Specification"].items():
528+
if not value:
529+
empty_keys.append(key)
530+
for key in empty_keys:
531+
del request["Specification"][key]
532+
533+
self.sagemaker_session.update_inference_component(**request)
534+
return self.component_name
535+
536+
# [TODO]: Check with doc writer for colocated vs collocated
537+
def list_colocated_models(self):
538+
"""List the deployed models co-located with this predictor.
539+
540+
Calls SageMaker:ListInferenceComponents on the endpoint associated with the predictor.
541+
542+
Returns:
543+
List[Dict[str, Any]]: A list of Amazon SageMaker Inference Component objects.
544+
"""
545+
546+
inference_component_dict = self.sagemaker_session.list_inference_components(
547+
filters={"EndpointNameEquals": self.endpoint_name}
548+
)
549+
550+
if len(inference_component_dict) == 0:
551+
LOGGER.info("No deployed models found for endpoint %s.", self.endpoint_name)
552+
return []
553+
554+
return inference_component_dict["InferenceComponents"]
397555

398556
def delete_model(self):
399-
"""Deletes the Amazon SageMaker models backing this predictor."""
557+
"""Delete the Amazon SageMaker model backing this predictor."""
400558
request_failed = False
401559
failed_models = []
402560
current_model_names = self._get_model_names()
@@ -594,9 +752,16 @@ def _get_model_names(self):
594752
EndpointConfigName=current_endpoint_config_name
595753
)
596754
production_variants = endpoint_config["ProductionVariants"]
597-
self._model_names = [d["ModelName"] for d in production_variants]
755+
self._model_names = []
756+
for d in production_variants:
757+
if "ModelName" in d:
758+
self._model_names.append(d["ModelName"])
598759
return self._model_names
599760

761+
def _get_component_name(self) -> Optional[str]:
762+
"""Get the inference component name field if it exists in the Predictor object."""
763+
return getattr(self, "component_name", None)
764+
600765
@property
601766
def content_type(self):
602767
"""The MIME type of the data sent to the inference endpoint."""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Compute Resource Requirements needed to deploy a model"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.compute_resource_requirements.resource_requirements import ( # noqa: F401
17+
ResourceRequirements,
18+
)

0 commit comments

Comments
 (0)