Skip to content

Commit 44336fc

Browse files
committed
feature: allow custom model name during deploy
1 parent 686569e commit 44336fc

File tree

6 files changed

+55
-8
lines changed

6 files changed

+55
-8
lines changed

src/sagemaker/estimator.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331331
return estimator
332332

333333
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
334-
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
334+
use_compiled_model=False, update_endpoint=False, wait=True, model_name=None, **kwargs):
335335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336336
337337
More information:
@@ -351,11 +351,13 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
351351
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
352352
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
353353
corresponding to the previous EndpointConfig. Default: False
354+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
355+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
356+
the training job is used.
354357
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example:
355358
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
356359
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
357360
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
358-
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
359361
360362
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
361363
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -367,6 +369,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
367369
"""
368370
self._ensure_latest_training_job()
369371
endpoint_name = endpoint_name or self.latest_training_job.name
372+
model_name = model_name or self.latest_training_job.name
370373
self.deploy_instance_type = instance_type
371374
if use_compiled_model:
372375
family = '_'.join(instance_type.split('.')[:-1])
@@ -376,6 +379,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
376379
model = self._compiled_models[family]
377380
else:
378381
model = self.create_model(**kwargs)
382+
model.name = model_name
379383
return model.deploy(
380384
instance_type=instance_type,
381385
initial_instance_count=initial_instance_count,

src/sagemaker/tuner.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
327327
return tuner
328328

329329
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, wait=True,
330-
**kwargs):
330+
model_name=None, **kwargs):
331331
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
332332
``sagemaker.RealTimePredictor`` object.
333333
@@ -344,6 +344,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
344344
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
345345
the name of the training job is used.
346346
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
347+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
348+
the training job is used.
347349
**kwargs: Other arguments needed for deployment. Please refer to the ``create_model()`` method of
348350
the associated estimator to see what other arguments are needed.
349351
@@ -356,7 +358,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
356358
sagemaker_session=self.estimator.sagemaker_session)
357359
return best_estimator.deploy(initial_instance_count, instance_type,
358360
accelerator_type=accelerator_type,
359-
endpoint_name=endpoint_name, wait=wait, **kwargs)
361+
endpoint_name=endpoint_name, wait=wait, model_name=model_name, **kwargs)
360362

361363
def stop_tuning_job(self):
362364
"""Stop latest running hyperparameter tuning job.

tests/integ/test_tf_script_mode.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,21 @@ def test_mnist_async(sagemaker_session):
136136
training_job_name = estimator.latest_training_job.name
137137
time.sleep(20)
138138
endpoint_name = training_job_name
139+
model_name = 'model-name-1'
139140
_assert_training_job_tags_match(sagemaker_session.sagemaker_client,
140141
estimator.latest_training_job.name, TAGS)
141142
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
142143
estimator = TensorFlow.attach(training_job_name=training_job_name,
143144
sagemaker_session=sagemaker_session)
144145
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge',
145-
endpoint_name=endpoint_name)
146+
endpoint_name=endpoint_name, model_name=model_name)
146147

147148
result = predictor.predict(np.zeros(784))
148149
print('predict result: {}'.format(result))
149150
_assert_endpoint_tags_match(sagemaker_session.sagemaker_client, predictor.endpoint, TAGS)
150151
_assert_model_tags_match(sagemaker_session.sagemaker_client,
151152
estimator.latest_training_job.name, TAGS)
153+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
152154

153155

154156
def test_deploy_with_input_handlers(sagemaker_session, instance_type):
@@ -208,3 +210,8 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
208210
training_job_description = sagemaker_client.describe_training_job(
209211
TrainingJobName=training_job_name)
210212
_assert_tags_match(sagemaker_client, training_job_description['TrainingJobArn'], tags)
213+
214+
215+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
216+
endpoint_config_description = sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
217+
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/integ/test_tuner.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -660,13 +660,15 @@ def test_attach_tuning_pytorch(sagemaker_session):
660660
time.sleep(15)
661661
tuner.wait()
662662

663+
endpoint_name = tuning_job_name
664+
model_name = 'model-name-1'
663665
attached_tuner = HyperparameterTuner.attach(tuning_job_name,
664666
sagemaker_session=sagemaker_session)
665667
assert attached_tuner.early_stopping_type == 'Auto'
666668

667669
best_training_job = tuner.best_training_job()
668-
with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
669-
predictor = attached_tuner.deploy(1, 'ml.c4.xlarge')
670+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
671+
predictor = attached_tuner.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name, model_name=model_name)
670672
data = np.zeros(shape=(1, 1, 28, 28), dtype=np.float32)
671673
predictor.predict(data)
672674

@@ -675,6 +677,7 @@ def test_attach_tuning_pytorch(sagemaker_session):
675677
output = predictor.predict(data)
676678

677679
assert output.shape == (batch_size, 10)
680+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
678681

679682

680683
@pytest.mark.canary_quick
@@ -749,3 +752,8 @@ def _fm_serializer(data):
749752
for row in data:
750753
js['instances'].append({'features': row.tolist()})
751754
return json.dumps(js)
755+
756+
757+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
758+
endpoint_config_description = sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
759+
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/unit/test_estimator.py

+25
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,31 @@ def test_deploy_with_update_endpoint(sagemaker_session):
11401140
sagemaker_session.create_endpoint.assert_not_called()
11411141

11421142

1143+
def test_deploy_with_model_name(sagemaker_session):
1144+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1145+
sagemaker_session=sagemaker_session)
1146+
estimator.set_hyperparameters(**HYPERPARAMS)
1147+
estimator.fit({'train': 's3://bucket/training-prefix'})
1148+
model_name = 'model-name'
1149+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=model_name)
1150+
1151+
sagemaker_session.create_model.assert_called_once()
1152+
args, kwargs = sagemaker_session.create_model.call_args
1153+
assert args[0] == model_name
1154+
1155+
1156+
def test_deploy_with_no_model_name(sagemaker_session):
1157+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1158+
sagemaker_session=sagemaker_session)
1159+
estimator.set_hyperparameters(**HYPERPARAMS)
1160+
estimator.fit({'train': 's3://bucket/training-prefix'})
1161+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
1162+
1163+
sagemaker_session.create_model.assert_called_once()
1164+
args, kwargs = sagemaker_session.create_model.call_args
1165+
assert args[0].startswith(IMAGE_NAME)
1166+
1167+
11431168
@patch('sagemaker.estimator.LocalSession')
11441169
@patch('sagemaker.estimator.Session')
11451170
def test_local_mode(session_class, local_session_class):

tests/unit/test_tuner.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def test_deploy_default(tuner):
566566

567567
tuner.estimator.sagemaker_session.create_model.assert_called_once()
568568
args = tuner.estimator.sagemaker_session.create_model.call_args[0]
569-
assert args[0].startswith(IMAGE_NAME)
569+
570+
assert args[0] == 'neo'
570571
assert args[1] == ROLE
571572
assert args[2]['Image'] == IMAGE_NAME
572573
assert args[2]['ModelDataUrl'] == MODEL_DATA

0 commit comments

Comments
 (0)