14
14
from __future__ import print_function , absolute_import
15
15
16
16
import abc
17
- from typing import Any , Optional , Tuple , Union
17
+ from typing import Any , Dict , Optional , Tuple , Union
18
+ import logging
18
19
20
+ from sagemaker .enums import EndpointType
19
21
from sagemaker .deprecations import (
20
22
deprecated_class ,
21
23
deprecated_deserialize ,
55
57
from sagemaker .model_monitor .model_monitoring import DEFAULT_REPOSITORY_NAME
56
58
57
59
from sagemaker .lineage .context import EndpointContext
60
+ from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
61
+
62
+ LOGGER = logging .getLogger ("sagemaker" )
58
63
59
64
60
65
class PredictorBase (abc .ABC ):
@@ -92,6 +97,7 @@ def __init__(
92
97
sagemaker_session = None ,
93
98
serializer = IdentitySerializer (),
94
99
deserializer = BytesDeserializer (),
100
+ component_name = None ,
95
101
** kwargs ,
96
102
):
97
103
"""Initialize a ``Predictor``.
@@ -115,11 +121,14 @@ def __init__(
115
121
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
116
122
deserializer object, used to decode data from an inference
117
123
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
124
+ component_name (str): Name of the Amazon SageMaker inference component
125
+ corresponding to the predictor .
118
126
"""
119
127
removed_kwargs ("content_type" , kwargs )
120
128
removed_kwargs ("accept" , kwargs )
121
129
endpoint_name = renamed_kwargs ("endpoint" , "endpoint_name" , endpoint_name , kwargs )
122
130
self .endpoint_name = endpoint_name
131
+ self .component_name = component_name
123
132
self .sagemaker_session = sagemaker_session or Session ()
124
133
self .serializer = serializer
125
134
self .deserializer = deserializer
@@ -137,6 +146,7 @@ def predict(
137
146
target_variant = None ,
138
147
inference_id = None ,
139
148
custom_attributes = None ,
149
+ component_name : Optional [str ] = None ,
140
150
):
141
151
"""Return the inference from the specified endpoint.
142
152
@@ -169,22 +179,29 @@ def predict(
169
179
value is returned. For example, if a custom attribute represents the trace ID, your
170
180
model can prepend the custom attribute with Trace ID: in your post-processing
171
181
function (Default: None).
182
+ component_name (str): Optional. Name of the Amazon SageMaker inference component
183
+ corresponding to the predictor .
172
184
173
185
Returns:
174
186
object: Inference for the given input. If a deserializer was specified when creating
175
187
the Predictor, the result of the deserializer is
176
188
returned. Otherwise the response returns the sequence of bytes
177
189
as is.
178
190
"""
179
-
191
+ # [TODO]: clean up component_name in _create_request_args
180
192
request_args = self ._create_request_args (
181
- data ,
182
- initial_args ,
183
- target_model ,
184
- target_variant ,
185
- inference_id ,
186
- custom_attributes ,
193
+ data = data ,
194
+ initial_args = initial_args ,
195
+ target_model = target_model ,
196
+ target_variant = target_variant ,
197
+ inference_id = inference_id ,
198
+ custom_attributes = custom_attributes ,
187
199
)
200
+
201
+ inference_component_name = component_name or self ._get_component_name ()
202
+ if inference_component_name :
203
+ request_args ["InferenceComponentName" ] = inference_component_name
204
+
188
205
response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
189
206
return self ._handle_response (response )
190
207
@@ -260,6 +277,8 @@ def _create_request_args(
260
277
if isinstance (data , JumpStartSerializablePayload ) and jumpstart_serialized_data
261
278
else self .serializer .serialize (data )
262
279
)
280
+ if self ._get_component_name ():
281
+ args ["InferenceComponentName" ] = self .component_name
263
282
264
283
args ["Body" ] = data
265
284
return args
@@ -273,6 +292,8 @@ def update_endpoint(
273
292
tags = None ,
274
293
kms_key = None ,
275
294
data_capture_config_dict = None ,
295
+ max_instance_count = None ,
296
+ min_instance_count = None ,
276
297
wait = True ,
277
298
):
278
299
"""Update the existing endpoint with the provided attributes.
@@ -310,6 +331,8 @@ def update_endpoint(
310
331
data_capture_config_dict (dict): The endpoint data capture configuration
311
332
for use with Amazon SageMaker Model Monitoring. If not specified,
312
333
the data capture configuration of the existing endpoint configuration is used.
334
+ max_instance_count (int): The maximum instance count used for scaling instance.
335
+ min_instance_count (int): The minimum instance count used for scaling instance.
313
336
314
337
Raises:
315
338
ValueError: If there is not enough information to create a new ``ProductionVariant``:
@@ -348,23 +371,45 @@ def update_endpoint(
348
371
else :
349
372
self ._model_names = [model_name ]
350
373
351
- production_variant_config = production_variant (
352
- model_name ,
353
- instance_type ,
354
- initial_instance_count = initial_instance_count ,
355
- accelerator_type = accelerator_type ,
356
- )
374
+ managed_instance_scaling = {}
375
+ if max_instance_count :
376
+ managed_instance_scaling ["MaxInstanceCount" ] = max_instance_count
377
+ if min_instance_count :
378
+ managed_instance_scaling ["MinInstanceCount" ] = min_instance_count
379
+
380
+ if managed_instance_scaling and len (managed_instance_scaling ) > 0 :
381
+ production_variant_config = production_variant (
382
+ model_name ,
383
+ instance_type ,
384
+ initial_instance_count = initial_instance_count ,
385
+ accelerator_type = accelerator_type ,
386
+ managed_instance_scaling = managed_instance_scaling ,
387
+ )
388
+ else :
389
+ production_variant_config = production_variant (
390
+ model_name ,
391
+ instance_type ,
392
+ initial_instance_count = initial_instance_count ,
393
+ accelerator_type = accelerator_type ,
394
+ )
357
395
production_variants = [production_variant_config ]
358
396
359
397
current_endpoint_config_name = self ._get_endpoint_config_name ()
360
398
new_endpoint_config_name = name_from_base (current_endpoint_config_name )
399
+
400
+ if self ._get_component_name ():
401
+ endpoint_type = EndpointType .GOLDFINCH
402
+ else :
403
+ endpoint_type = EndpointType .OTHERS
404
+
361
405
self .sagemaker_session .create_endpoint_config_from_existing (
362
406
current_endpoint_config_name ,
363
407
new_endpoint_config_name ,
364
408
new_tags = tags ,
365
409
new_kms_key = kms_key ,
366
410
new_data_capture_config_dict = data_capture_config_dict ,
367
411
new_production_variants = production_variants ,
412
+ endpoint_type = endpoint_type ,
368
413
)
369
414
self .sagemaker_session .update_endpoint (
370
415
self .endpoint_name , new_endpoint_config_name , wait = wait
@@ -393,10 +438,123 @@ def delete_endpoint(self, delete_endpoint_config=True):
393
438
394
439
self .sagemaker_session .delete_endpoint (self .endpoint_name )
395
440
396
- delete_predictor = delete_endpoint
441
def delete_predictor(self) -> None:
    """Delete the resource backing this predictor.

    When the predictor is backed by an inference component (Goldfinch
    endpoint), only that inference component is deleted; otherwise the
    whole endpoint hosting this predictor is deleted.
    """
    # [TODO]: wait and describe inference component until not found to ensure
    # it gets deleted successfully. Throw appropriate exception/error type.
    if not self.component_name:
        # No inference component attached — tear down the endpoint itself.
        self.delete_endpoint()
        return
    self.sagemaker_session.delete_inference_component(self.component_name)
+
454
+ def update_predictor (
455
+ self ,
456
+ image_uri : Optional [str ] = None ,
457
+ model_data : Optional [Union [str , dict ]] = None ,
458
+ env : Optional [Dict [str , str ]] = None ,
459
+ model_data_download_timeout : Optional [int ] = None ,
460
+ container_startup_health_check_timeout : Optional [int ] = None ,
461
+ resources : Optional [ResourceRequirements ] = None ,
462
+ ) -> str :
463
+ """Updates the predictor to deploy a new Model specification and apply new configurations.
464
+
465
+ This is done by updating the SageMaker InferenceComponent.
466
+
467
+ Args:
468
+ image_uri (Optional[str]): A Docker image URI. (Default: None).
469
+ model_data (Optional[Union[str, dict]]): Location
470
+ of SageMaker model data. (Default: None).
471
+ env (Optional[dict[str, str]]): Environment variables
472
+ to run with ``image_uri`` when hosted in SageMaker. (Default: None).
473
+ model_data_download_timeout (Optional[int]): The timeout value, in seconds, to download
474
+ and extract model data from Amazon S3 to the individual inference instance
475
+ associated with this production variant. (Default: None).
476
+ container_startup_health_check_timeout (Optional[int]): The timeout value, in seconds,
477
+ for your inference container to pass health check by SageMaker Hosting. For more
478
+ information about health check see:
479
+ https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
480
+ (Default: None).
481
+ resources (Optional[ResourceRequirements]): The compute resource requirements
482
+ for a model to be deployed to an endpoint. Only EndpointType.Goldfinch supports
483
+ this feature. (Default: None).
484
+
485
+ Returns:
486
+ String: The updated Amazon SageMaker Inference Component name
487
+ """
488
+ if self .component_name is None :
489
+ raise ValueError (
490
+ "No existing Inference Component; "
491
+ "Please ensure you deployed Inference Component first."
492
+ )
493
+ # [TODO]: Move to a module
494
+ request = {
495
+ "InferenceComponentName" : self .component_name ,
496
+ "Specification" : {},
497
+ }
498
+
499
+ if resources :
500
+ request ["Specification" ][
501
+ "ComputeResourceRequirements"
502
+ ] = resources .get_compute_resource_requirements ()
503
+
504
+ if image_uri :
505
+ request ["Specification" ]["Container" ]["Image" ] = image_uri
506
+
507
+ if env :
508
+ request ["Specification" ]["Container" ]["Environment" ] = env
509
+
510
+ if model_data :
511
+ request ["Specification" ]["Container" ]["ArtifactUrl" ] = model_data
512
+
513
+ if resources .copy_count :
514
+ request ["RuntimeConfig" ] = {"CopyCount" : resources .copy_count }
515
+
516
+ if model_data_download_timeout :
517
+ request ["Specification" ]["StartupParameters" ][
518
+ "ModelDataDownloadTimeoutInSeconds"
519
+ ] = model_data_download_timeout
520
+
521
+ if container_startup_health_check_timeout :
522
+ request ["Specification" ]["StartupParameters" ][
523
+ "ContainerStartupHealthCheckTimeoutInSeconds"
524
+ ] = container_startup_health_check_timeout
525
+
526
+ empty_keys = []
527
+ for key , value in request ["Specification" ].items ():
528
+ if not value :
529
+ empty_keys .append (key )
530
+ for key in empty_keys :
531
+ del request ["Specification" ][key ]
532
+
533
+ self .sagemaker_session .update_inference_component (** request )
534
+ return self .component_name
535
+
536
# [TODO]: Check with doc writer for colocated vs collocated
def list_colocated_models(self):
    """List the deployed models co-located with this predictor.

    Calls SageMaker:ListInferenceComponents on the endpoint associated with the predictor.

    Returns:
        Dict[str, list]: A list of Amazon SageMaker Inference Component objects.
    """

    response = self.sagemaker_session.list_inference_components(
        filters={"EndpointNameEquals": self.endpoint_name}
    )

    # Use .get so a response that is non-empty but lacks the key (previously a
    # KeyError) — or one whose component list is empty — is treated uniformly
    # as "no deployed models".
    inference_components = response.get("InferenceComponents", [])
    if not inference_components:
        LOGGER.info("No deployed models found for endpoint %s.", self.endpoint_name)
        return []

    return inference_components
397
555
398
556
def delete_model (self ):
399
- """Deletes the Amazon SageMaker models backing this predictor."""
557
+ """Delete the Amazon SageMaker model backing this predictor."""
400
558
request_failed = False
401
559
failed_models = []
402
560
current_model_names = self ._get_model_names ()
@@ -594,9 +752,16 @@ def _get_model_names(self):
594
752
EndpointConfigName = current_endpoint_config_name
595
753
)
596
754
production_variants = endpoint_config ["ProductionVariants" ]
597
- self ._model_names = [d ["ModelName" ] for d in production_variants ]
755
+ self ._model_names = []
756
+ for d in production_variants :
757
+ if "ModelName" in d :
758
+ self ._model_names .append (d ["ModelName" ])
598
759
return self ._model_names
599
760
761
+ def _get_component_name (self ) -> Optional [str ]:
762
+ """Get the inference component name field if it exists in the Predictor object."""
763
+ return getattr (self , "component_name" , None )
764
+
600
765
@property
601
766
def content_type (self ):
602
767
"""The MIME type of the data sent to the inference endpoint."""
0 commit comments