Skip to content

Commit 78bdd92

Browse files
committed
feature: allow custom model name during deploy
1 parent 64aa936 commit 78bdd92

File tree

6 files changed

+62
-6
lines changed

6 files changed

+62
-6
lines changed

src/sagemaker/estimator.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def deploy(
392392
use_compiled_model=False,
393393
update_endpoint=False,
394394
wait=True,
395+
model_name=None,
395396
**kwargs
396397
):
397398
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
@@ -413,11 +414,13 @@ def deploy(
413414
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
414415
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
415416
corresponding to the previous EndpointConfig. Default: False
417+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
418+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
419+
the training job is used.
416420
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example:
417421
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
418422
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
419423
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
420-
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
421424
422425
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
423426
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -429,6 +432,7 @@ def deploy(
429432
"""
430433
self._ensure_latest_training_job()
431434
endpoint_name = endpoint_name or self.latest_training_job.name
435+
model_name = model_name or self.latest_training_job.name
432436
self.deploy_instance_type = instance_type
433437
if use_compiled_model:
434438
family = "_".join(instance_type.split(".")[:-1])
@@ -440,6 +444,7 @@ def deploy(
440444
model = self._compiled_models[family]
441445
else:
442446
model = self.create_model(**kwargs)
447+
model.name = model_name
443448
return model.deploy(
444449
instance_type=instance_type,
445450
initial_instance_count=initial_instance_count,

src/sagemaker/tuner.py

+4
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def deploy(
375375
accelerator_type=None,
376376
endpoint_name=None,
377377
wait=True,
378+
model_name=None,
378379
**kwargs
379380
):
380381
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
@@ -393,6 +394,8 @@ def deploy(
393394
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
394395
the name of the training job is used.
395396
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
397+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
398+
the training job is used.
396399
**kwargs: Other arguments needed for deployment. Please refer to the ``create_model()`` method of
397400
the associated estimator to see what other arguments are needed.
398401
@@ -410,6 +413,7 @@ def deploy(
410413
accelerator_type=accelerator_type,
411414
endpoint_name=endpoint_name,
412415
wait=wait,
416+
model_name=model_name,
413417
**kwargs
414418
)
415419

tests/integ/test_tf_script_mode.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def test_mnist_async(sagemaker_session):
159159
training_job_name = estimator.latest_training_job.name
160160
time.sleep(20)
161161
endpoint_name = training_job_name
162+
model_name = 'model-name-1'
162163
_assert_training_job_tags_match(
163164
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
164165
)
@@ -167,7 +168,8 @@ def test_mnist_async(sagemaker_session):
167168
training_job_name=training_job_name, sagemaker_session=sagemaker_session
168169
)
169170
predictor = estimator.deploy(
170-
initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name
171+
initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name,
172+
model_name=model_name
171173
)
172174

173175
result = predictor.predict(np.zeros(784))
@@ -176,6 +178,9 @@ def test_mnist_async(sagemaker_session):
176178
_assert_model_tags_match(
177179
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
178180
)
181+
_assert_model_name_match(
182+
sagemaker_session.sagemaker_client, endpoint_name, model_name
183+
)
179184

180185

181186
def test_deploy_with_input_handlers(sagemaker_session, instance_type):
@@ -241,3 +246,10 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
241246
TrainingJobName=training_job_name
242247
)
243248
_assert_tags_match(sagemaker_client, training_job_description["TrainingJobArn"], tags)
249+
250+
251+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
252+
endpoint_config_description = sagemaker_client.describe_endpoint_config(
253+
EndpointConfigName=endpoint_config_name
254+
)
255+
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/integ/test_tuner.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -843,14 +843,17 @@ def test_attach_tuning_pytorch(sagemaker_session):
843843
time.sleep(15)
844844
tuner.wait()
845845

846+
endpoint_name = tuning_job_name
847+
model_name = 'model-name-1'
846848
attached_tuner = HyperparameterTuner.attach(
847849
tuning_job_name, sagemaker_session=sagemaker_session
848850
)
849851
assert attached_tuner.early_stopping_type == "Auto"
850852

851-
best_training_job = tuner.best_training_job()
852-
with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
853-
predictor = attached_tuner.deploy(1, "ml.c4.xlarge")
853+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
854+
predictor = attached_tuner.deploy(
855+
1, "ml.c4.xlarge", endpoint_name=endpoint_name, model_name=model_name
856+
)
854857
data = np.zeros(shape=(1, 1, 28, 28), dtype=np.float32)
855858
predictor.predict(data)
856859

@@ -859,6 +862,7 @@ def test_attach_tuning_pytorch(sagemaker_session):
859862
output = predictor.predict(data)
860863

861864
assert output.shape == (batch_size, 10)
865+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
862866

863867

864868
@pytest.mark.canary_quick
@@ -941,3 +945,8 @@ def _fm_serializer(data):
941945
for row in data:
942946
js["instances"].append({"features": row.tolist()})
943947
return json.dumps(js)
948+
949+
950+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
951+
endpoint_config_description = sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
952+
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/unit/test_estimator.py

+25
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,31 @@ def test_deploy_with_update_endpoint(sagemaker_session):
17051705
sagemaker_session.create_endpoint.assert_not_called()
17061706

17071707

1708+
def test_deploy_with_model_name(sagemaker_session):
1709+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1710+
sagemaker_session=sagemaker_session)
1711+
estimator.set_hyperparameters(**HYPERPARAMS)
1712+
estimator.fit({"train": "s3://bucket/training-prefix"})
1713+
model_name = "model-name"
1714+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=model_name)
1715+
1716+
sagemaker_session.create_model.assert_called_once()
1717+
args, kwargs = sagemaker_session.create_model.call_args
1718+
assert args[0] == model_name
1719+
1720+
1721+
def test_deploy_with_no_model_name(sagemaker_session):
1722+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1723+
sagemaker_session=sagemaker_session)
1724+
estimator.set_hyperparameters(**HYPERPARAMS)
1725+
estimator.fit({'train': 's3://bucket/training-prefix'})
1726+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
1727+
1728+
sagemaker_session.create_model.assert_called_once()
1729+
args, kwargs = sagemaker_session.create_model.call_args
1730+
assert args[0].startswith(IMAGE_NAME)
1731+
1732+
17081733
@patch("sagemaker.estimator.LocalSession")
17091734
@patch("sagemaker.estimator.Session")
17101735
def test_local_mode(session_class, local_session_class):

tests/unit/test_tuner.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,8 @@ def test_deploy_default(tuner):
645645

646646
tuner.estimator.sagemaker_session.create_model.assert_called_once()
647647
args = tuner.estimator.sagemaker_session.create_model.call_args[0]
648-
assert args[0].startswith(IMAGE_NAME)
648+
649+
assert args[0] == 'neo'
649650
assert args[1] == ROLE
650651
assert args[2]["Image"] == IMAGE_NAME
651652
assert args[2]["ModelDataUrl"] == MODEL_DATA

0 commit comments

Comments
 (0)