Skip to content

Commit bad2cef

Browse files
committed
feature: allow custom model name during deploy
1 parent 3ab6411 commit bad2cef

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

src/sagemaker/estimator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331331
return estimator
332332

333333
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
334-
use_compiled_model=False, update_endpoint=False, **kwargs):
334+
model_name=None, use_compiled_model=False, update_endpoint=False, **kwargs):
335335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336336
337337
More information:
@@ -347,6 +347,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
347347
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
348348
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
349349
the training job is used.
350+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
351+
the training job is used.
350352
use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. Default: False.
351353
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
352354
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
@@ -366,6 +368,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
366368
"""
367369
self._ensure_latest_training_job()
368370
endpoint_name = endpoint_name or self.latest_training_job.name
371+
model_name = model_name or self.latest_training_job.name
369372
self.deploy_instance_type = instance_type
370373
if use_compiled_model:
371374
family = '_'.join(instance_type.split('.')[:-1])
@@ -375,6 +378,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
375378
model = self._compiled_models[family]
376379
else:
377380
model = self.create_model(**kwargs)
381+
model.name = model_name
378382
return model.deploy(
379383
instance_type=instance_type,
380384
initial_instance_count=initial_instance_count,

tests/integ/test_tf_script_mode.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,21 @@ def test_mnist_async(sagemaker_session):
140140
training_job_name = estimator.latest_training_job.name
141141
time.sleep(20)
142142
endpoint_name = training_job_name
143+
model_name = 'model-name-1'
143144
_assert_training_job_tags_match(sagemaker_session.sagemaker_client,
144145
estimator.latest_training_job.name, TAGS)
145146
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
146147
estimator = TensorFlow.attach(training_job_name=training_job_name,
147148
sagemaker_session=sagemaker_session)
148149
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge',
149-
endpoint_name=endpoint_name)
150+
endpoint_name=endpoint_name, model_name=model_name)
150151

151152
result = predictor.predict(np.zeros(784))
152153
print('predict result: {}'.format(result))
153154
_assert_endpoint_tags_match(sagemaker_session.sagemaker_client, predictor.endpoint, TAGS)
154155
_assert_model_tags_match(sagemaker_session.sagemaker_client,
155156
estimator.latest_training_job.name, TAGS)
157+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
156158

157159

158160
@pytest.mark.skipif(tests.integ.PYTHON_VERSION != 'py3',
@@ -214,3 +216,8 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
214216
training_job_description = sagemaker_client.describe_training_job(
215217
TrainingJobName=training_job_name)
216218
_assert_tags_match(sagemaker_client, training_job_description['TrainingJobArn'], tags)
219+
220+
221+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
    """Assert that an endpoint config's first production variant uses ``model_name``.

    Args:
        sagemaker_client: Client object exposing ``describe_endpoint_config``
            (presumably a boto3 SageMaker client -- confirm against callers).
        endpoint_config_name (str): Name of the endpoint config to describe.
        model_name (str): Model name the first production variant is expected
            to reference.

    Raises:
        AssertionError: If the first variant's ``ModelName`` differs from
            ``model_name``.
    """
    # The endpoint *config* description is what accepts ``EndpointConfigName``
    # and records each variant's ``ModelName``; ``describe_endpoint`` is keyed
    # by ``EndpointName`` and does not report the model name.
    endpoint_config_description = sagemaker_client.describe_endpoint_config(
        EndpointConfigName=endpoint_config_name)
    assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/unit/test_estimator.py

+25
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,31 @@ def test_deploy_with_update_endpoint(sagemaker_session):
11391139
sagemaker_session.create_endpoint.assert_not_called()
11401140

11411141

1142+
def test_deploy_with_model_name(sagemaker_session):
    """Check that ``deploy(model_name=...)`` forwards the name to ``create_model``.

    After fitting and deploying, the session's ``create_model`` must have been
    called exactly once with the custom model name as its first positional
    argument.
    """
    expected_name = 'model-name'

    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session
    )
    estimator.set_hyperparameters(**HYPERPARAMS)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=expected_name)

    create_model = sagemaker_session.create_model
    create_model.assert_called_once()
    positional_args = create_model.call_args[0]
    assert positional_args[0] == expected_name
1154+
1155+
def test_deploy_with_no_model_name(sagemaker_session):
    """Check the model name passed to ``create_model`` when none is supplied.

    After fitting and deploying without ``model_name``, the session's
    ``create_model`` must have been called exactly once, and its first
    positional argument must start with ``IMAGE_NAME``.
    """
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session
    )
    estimator.set_hyperparameters(**HYPERPARAMS)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)

    create_model = sagemaker_session.create_model
    create_model.assert_called_once()
    created_name = create_model.call_args[0][0]
    assert created_name.startswith(IMAGE_NAME)
1165+
1166+
11421167
@patch('sagemaker.estimator.LocalSession')
11431168
@patch('sagemaker.estimator.Session')
11441169
def test_local_mode(session_class, local_session_class):

tests/unit/test_tuner.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def test_deploy_default(tuner):
566566

567567
tuner.estimator.sagemaker_session.create_model.assert_called_once()
568568
args = tuner.estimator.sagemaker_session.create_model.call_args[0]
569-
assert args[0].startswith(IMAGE_NAME)
569+
570+
assert args[0] == 'neo'
570571
assert args[1] == ROLE
571572
assert args[2]['Image'] == IMAGE_NAME
572573
assert args[2]['ModelDataUrl'] == MODEL_DATA

0 commit comments

Comments
 (0)