Skip to content

Commit 7342cc7

Browse files
authored
breaking: create new inference resources during model.deploy() and model.transformer() (#1666)
This also changes how model and endpoint names are generated for compiled models, and removes the unused private attribute model._model_name.
1 parent 2d0d549 commit 7342cc7

File tree

6 files changed

+255
-70
lines changed

6 files changed

+255
-70
lines changed

src/sagemaker/model.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def __init__(
116116
self.predictor_cls = predictor_cls
117117
self.env = env or {}
118118
self.name = name
119+
self._base_name = None
119120
self.vpc_config = vpc_config
120121
self.sagemaker_session = sagemaker_session
121-
self._model_name = None
122122
self.endpoint_name = None
123123
self._is_compiled_model = False
124124
self._enable_network_isolation = enable_network_isolation
@@ -184,7 +184,10 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
184184
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
185185
"""
186186
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
187-
self.name = self.name or utils.name_from_image(container_def["Image"])
187+
188+
self._ensure_base_name_if_needed(container_def["Image"])
189+
self._set_model_name_if_needed()
190+
188191
enable_network_isolation = self.enable_network_isolation()
189192

190193
self._init_sagemaker_session_if_does_not_exist(instance_type)
@@ -197,6 +200,16 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
197200
tags=tags,
198201
)
199202

203+
def _ensure_base_name_if_needed(self, image):
204+
"""Create a base name from the image if there is no model name provided."""
205+
if self.name is None:
206+
self._base_name = self._base_name or utils.base_name_from_image(image)
207+
208+
def _set_model_name_if_needed(self):
209+
"""Generate a new model name if ``self._base_name`` is present."""
210+
if self._base_name:
211+
self.name = utils.name_from_base(self._base_name)
212+
200213
def _framework(self):
201214
"""Placeholder docstring"""
202215
return getattr(self, "__framework_name__", None)
@@ -471,10 +484,9 @@ def deploy(
471484

472485
compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
473486
if self._is_compiled_model:
474-
name_prefix = self.name or utils.name_from_image(
475-
self.image, max_length=(62 - len(compiled_model_suffix))
476-
)
477-
self.name = "{}-{}".format(name_prefix, compiled_model_suffix)
487+
self._ensure_base_name_if_needed(self.image)
488+
if self._base_name is not None:
489+
self._base_name = "-".join((self._base_name, compiled_model_suffix))
478490

479491
self._create_sagemaker_model(instance_type, accelerator_type, tags)
480492
production_variant = sagemaker.production_variant(
@@ -483,9 +495,10 @@ def deploy(
483495
if endpoint_name:
484496
self.endpoint_name = endpoint_name
485497
else:
486-
self.endpoint_name = self.name
487-
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
488-
self.endpoint_name += compiled_model_suffix
498+
base_endpoint_name = self._base_name or utils.base_from_name(self.name)
499+
if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
500+
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
501+
self.endpoint_name = utils.name_from_base(base_endpoint_name)
489502

490503
data_capture_config_dict = None
491504
if data_capture_config is not None:
@@ -568,7 +581,7 @@ def transformer(
568581
max_payload=max_payload,
569582
env=env,
570583
tags=tags,
571-
base_transform_job_name=self.name,
584+
base_transform_job_name=self._base_name or self.name,
572585
volume_kms_key=volume_kms_key,
573586
sagemaker_session=self.sagemaker_session,
574587
)
@@ -994,13 +1007,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
9941007
if self.env != {}:
9951008
container_def["Environment"] = self.env
9961009

997-
model_package_short_name = model_package_name.split("/")[-1]
998-
enable_network_isolation = self.enable_network_isolation()
999-
self.name = self.name or utils.name_from_base(model_package_short_name)
1010+
self._ensure_base_name_if_needed(model_package_name.split("/")[-1])
1011+
self._set_model_name_if_needed()
1012+
10001013
self.sagemaker_session.create_model(
10011014
self.name,
10021015
self.role,
10031016
container_def,
10041017
vpc_config=self.vpc_config,
1005-
enable_network_isolation=enable_network_isolation,
1018+
enable_network_isolation=self.enable_network_isolation(),
10061019
)
1020+
1021+
def _ensure_base_name_if_needed(self, base_name):
1022+
"""Set the base name if there is no model name provided."""
1023+
if self.name is None:
1024+
self._base_name = base_name

src/sagemaker/pipeline.py

-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(
6060
self.name = name
6161
self.vpc_config = vpc_config
6262
self.sagemaker_session = sagemaker_session
63-
self._model_name = None
6463
self.endpoint_name = None
6564

6665
def pipeline_container_def(self, instance_type):

tests/unit/sagemaker/model/test_deploy.py

+72-16
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222

2323
MODEL_DATA = "s3://bucket/model.tar.gz"
2424
MODEL_IMAGE = "mi"
25-
TIMESTAMP = "2017-10-10-14-14-15"
25+
TIMESTAMP = "2020-07-02-20-10-30-288"
2626
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
27+
ENDPOINT_NAME = "endpoint-{}".format(TIMESTAMP)
2728

2829
ACCELERATOR_TYPE = "ml.eia.medium"
2930
INSTANCE_COUNT = 2
@@ -46,9 +47,8 @@ def sagemaker_session():
4647

4748
@patch("sagemaker.production_variant")
4849
@patch("sagemaker.model.Model.prepare_container_def")
49-
@patch("sagemaker.utils.name_from_image")
50-
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
51-
name_from_image.return_value = MODEL_NAME
50+
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
51+
def test_deploy(name_from_base, prepare_container_def, production_variant, sagemaker_session):
5252
production_variant.return_value = BASE_PRODUCTION_VARIANT
5353

5454
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
@@ -57,7 +57,9 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
5757
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
5858
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
5959

60-
name_from_image.assert_called_with(MODEL_IMAGE)
60+
name_from_base.assert_called_with(MODEL_IMAGE)
61+
assert 2 == name_from_base.call_count
62+
6163
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
6264
production_variant.assert_called_with(
6365
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None
@@ -77,9 +79,12 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
7779
)
7880

7981

82+
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
8083
@patch("sagemaker.model.Model._create_sagemaker_model")
8184
@patch("sagemaker.production_variant")
82-
def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session):
85+
def test_deploy_accelerator_type(
86+
production_variant, create_sagemaker_model, name_from_base, sagemaker_session
87+
):
8388
model = Model(
8489
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
8590
)
@@ -100,7 +105,7 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
100105
)
101106

102107
sagemaker_session.endpoint_from_production_variants.assert_called_with(
103-
name=MODEL_NAME,
108+
name=ENDPOINT_NAME,
104109
production_variants=[production_variant_result],
105110
tags=None,
106111
kms_key=None,
@@ -109,7 +114,6 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
109114
)
110115

111116

112-
@patch("sagemaker.utils.name_from_image", Mock())
113117
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
114118
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
115119
def test_deploy_endpoint_name(sagemaker_session):
@@ -122,6 +126,7 @@ def test_deploy_endpoint_name(sagemaker_session):
122126
initial_instance_count=INSTANCE_COUNT,
123127
)
124128

129+
assert endpoint_name == model.endpoint_name
125130
sagemaker_session.endpoint_from_production_variants.assert_called_with(
126131
name=endpoint_name,
127132
production_variants=[BASE_PRODUCTION_VARIANT],
@@ -132,9 +137,57 @@ def test_deploy_endpoint_name(sagemaker_session):
132137
)
133138

134139

140+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base")
@patch("sagemaker.utils.base_from_name")
@patch("sagemaker.production_variant")
def test_deploy_generates_endpoint_name_each_time_from_model_name(
    production_variant, base_from_name, name_from_base, sagemaker_session
):
    """Deploying twice with an explicit model name names the endpoint each time.

    ``_create_sagemaker_model`` is mocked out, so this test observes only
    the ``name_from_base`` calls that ``deploy()`` makes itself. The base
    passed to it is expected to come from ``base_from_name(MODEL_NAME)``.
    """
    model = Model(
        MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
    )

    # Two deployments of the same model object.
    model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
    model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    base_from_name.assert_called_with(MODEL_NAME)
    name_from_base.assert_called_with(base_from_name.return_value)
    assert 2 == name_from_base.call_count
161+
162+
163+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base")
@patch("sagemaker.utils.base_from_name")
@patch("sagemaker.production_variant")
def test_deploy_generates_endpoint_name_each_time_from_base_name(
    production_variant, base_from_name, name_from_base, sagemaker_session
):
    """Deploying twice with a preset ``_base_name`` names the endpoint each time.

    ``_create_sagemaker_model`` is mocked out, so this test observes only
    the ``name_from_base`` calls that ``deploy()`` makes itself. With
    ``_base_name`` already set, ``base_from_name`` must not be called.
    """
    model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)

    base_name = "foo"
    model._base_name = base_name

    # Two deployments of the same model object.
    model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
    model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    base_from_name.assert_not_called()
    name_from_base.assert_called_with(base_name)
    assert 2 == name_from_base.call_count
185+
186+
187+
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
135188
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
136189
@patch("sagemaker.model.Model._create_sagemaker_model")
137-
def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session):
190+
def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, sagemaker_session):
138191
model = Model(
139192
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
140193
)
@@ -144,7 +197,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi
144197

145198
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags)
146199
sagemaker_session.endpoint_from_production_variants.assert_called_with(
147-
name=MODEL_NAME,
200+
name=ENDPOINT_NAME,
148201
production_variants=[BASE_PRODUCTION_VARIANT],
149202
tags=tags,
150203
kms_key=None,
@@ -154,8 +207,9 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi
154207

155208

156209
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
210+
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
157211
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
158-
def test_deploy_kms_key(production_variant, sagemaker_session):
212+
def test_deploy_kms_key(production_variant, name_from_base, sagemaker_session):
159213
model = Model(
160214
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
161215
)
@@ -164,7 +218,7 @@ def test_deploy_kms_key(production_variant, sagemaker_session):
164218
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, kms_key=key)
165219

166220
sagemaker_session.endpoint_from_production_variants.assert_called_with(
167-
name=MODEL_NAME,
221+
name=ENDPOINT_NAME,
168222
production_variants=[BASE_PRODUCTION_VARIANT],
169223
tags=None,
170224
kms_key=key,
@@ -174,16 +228,17 @@ def test_deploy_kms_key(production_variant, sagemaker_session):
174228

175229

176230
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
231+
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
177232
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
178-
def test_deploy_async(production_variant, sagemaker_session):
233+
def test_deploy_async(production_variant, name_from_base, sagemaker_session):
179234
model = Model(
180235
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
181236
)
182237

183238
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False)
184239

185240
sagemaker_session.endpoint_from_production_variants.assert_called_with(
186-
name=MODEL_NAME,
241+
name=ENDPOINT_NAME,
187242
production_variants=[BASE_PRODUCTION_VARIANT],
188243
tags=None,
189244
kms_key=None,
@@ -193,8 +248,9 @@ def test_deploy_async(production_variant, sagemaker_session):
193248

194249

195250
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
251+
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
196252
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
197-
def test_deploy_data_capture_config(production_variant, sagemaker_session):
253+
def test_deploy_data_capture_config(production_variant, name_from_base, sagemaker_session):
198254
model = Model(
199255
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
200256
)
@@ -210,7 +266,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
210266

211267
data_capture_config._to_request_dict.assert_called_with()
212268
sagemaker_session.endpoint_from_production_variants.assert_called_with(
213-
name=MODEL_NAME,
269+
name=ENDPOINT_NAME,
214270
production_variants=[BASE_PRODUCTION_VARIANT],
215271
tags=None,
216272
kms_key=None,

0 commit comments

Comments
 (0)