@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331
331
return estimator
332
332
333
333
def deploy (self , initial_instance_count , instance_type , accelerator_type = None , endpoint_name = None ,
334
- use_compiled_model = False , update_endpoint = False , wait = True , ** kwargs ):
334
+ use_compiled_model = False , update_endpoint = False , wait = True , model_name = None , ** kwargs ):
335
335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336
336
337
337
More information:
@@ -351,11 +351,13 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
351
351
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
352
352
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
353
353
corresponding to the previous EndpointConfig. Default: False
354
+ wait (bool): Whether the call should wait until the deployment of model completes (default: True).
355
+ model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
356
+ the training job is used.
354
357
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example:
355
358
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
356
359
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
357
360
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
358
- wait (bool): Whether the call should wait until the deployment of model completes (default: True).
359
361
360
362
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
361
363
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -367,6 +369,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
367
369
"""
368
370
self ._ensure_latest_training_job ()
369
371
endpoint_name = endpoint_name or self .latest_training_job .name
372
+ model_name = model_name or self .latest_training_job .name
370
373
self .deploy_instance_type = instance_type
371
374
if use_compiled_model :
372
375
family = '_' .join (instance_type .split ('.' )[:- 1 ])
@@ -376,6 +379,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
376
379
model = self ._compiled_models [family ]
377
380
else :
378
381
model = self .create_model (** kwargs )
382
+ model .name = model_name
379
383
return model .deploy (
380
384
instance_type = instance_type ,
381
385
initial_instance_count = initial_instance_count ,
0 commit comments