Skip to content

Commit 865e5e0

Browse files
committed
feature: allow custom model name during deploy
1 parent 4e7e0db commit 865e5e0

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

src/sagemaker/estimator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
328328
return estimator
329329

330330
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
331-
use_compiled_model=False, update_endpoint=False, **kwargs):
331+
model_name=None, use_compiled_model=False, update_endpoint=False, **kwargs):
332332
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
333333
334334
More information:
@@ -344,6 +344,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
344344
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
345345
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
346346
the training job is used.
347+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
348+
the training job is used.
347349
use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. Default: False.
348350
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
349351
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
@@ -363,6 +365,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
363365
"""
364366
self._ensure_latest_training_job()
365367
endpoint_name = endpoint_name or self.latest_training_job.name
368+
model_name = model_name or self.latest_training_job.name
366369
self.deploy_instance_type = instance_type
367370
if use_compiled_model:
368371
family = '_'.join(instance_type.split('.')[:-1])
@@ -372,6 +375,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
372375
model = self._compiled_models[family]
373376
else:
374377
model = self.create_model(**kwargs)
378+
model.name = model_name
375379
return model.deploy(
376380
instance_type=instance_type,
377381
initial_instance_count=initial_instance_count,

tests/unit/test_estimator.py

+25
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,31 @@ def test_deploy_with_update_endpoint(sagemaker_session):
11311131
sagemaker_session.create_endpoint.assert_not_called()
11321132

11331133

1134+
def test_deploy_with_model_name(sagemaker_session):
1135+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1136+
sagemaker_session=sagemaker_session)
1137+
estimator.set_hyperparameters(**HYPERPARAMS)
1138+
estimator.fit({'train': 's3://bucket/training-prefix'})
1139+
model_name = 'model-name'
1140+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=model_name)
1141+
1142+
sagemaker_session.create_model.assert_called_once()
1143+
args, kwargs = sagemaker_session.create_model.call_args
1144+
assert args[0] == model_name
1145+
1146+
1147+
def test_deploy_with_no_model_name(sagemaker_session):
1148+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1149+
sagemaker_session=sagemaker_session)
1150+
estimator.set_hyperparameters(**HYPERPARAMS)
1151+
estimator.fit({'train': 's3://bucket/training-prefix'})
1152+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
1153+
1154+
sagemaker_session.create_model.assert_called_once()
1155+
args, kwargs = sagemaker_session.create_model.call_args
1156+
assert args[0].startswith(IMAGE_NAME)
1157+
1158+
11341159
@patch('sagemaker.estimator.LocalSession')
11351160
@patch('sagemaker.estimator.Session')
11361161
def test_local_mode(session_class, local_session_class):

0 commit comments

Comments
 (0)