Skip to content

Commit 818f43a

Browse files
authored
infra: move unit tests for updating an endpoint to test_deploy.py (#1426)
1 parent cca0d40 commit 818f43a

File tree

2 files changed

+71
-74
lines changed

2 files changed

+71
-74
lines changed

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TIMESTAMP = "2017-10-10-14-14-15"
2626
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
2727

28+
ACCELERATOR_TYPE = "ml.eia.medium"
2829
INSTANCE_COUNT = 2
2930
INSTANCE_TYPE = "ml.c4.4xlarge"
3031
ROLE = "some-role"
@@ -83,21 +84,19 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
8384
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
8485
)
8586

86-
accelerator_type = "ml.eia.medium"
87-
8887
production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
89-
production_variant_result["AcceleratorType"] = accelerator_type
88+
production_variant_result["AcceleratorType"] = ACCELERATOR_TYPE
9089
production_variant.return_value = production_variant_result
9190

9291
model.deploy(
9392
instance_type=INSTANCE_TYPE,
9493
initial_instance_count=INSTANCE_COUNT,
95-
accelerator_type=accelerator_type,
94+
accelerator_type=ACCELERATOR_TYPE,
9695
)
9796

98-
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, accelerator_type, None)
97+
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None)
9998
production_variant.assert_called_with(
100-
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=accelerator_type
99+
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=ACCELERATOR_TYPE
101100
)
102101

103102
sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -267,3 +266,69 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
267266
assert isinstance(predictor, sagemaker.predictor.RealTimePredictor)
268267
assert predictor.endpoint == endpoint_name
269268
assert predictor.sagemaker_session == sagemaker_session
269+
270+
271+
def test_deploy_update_endpoint(sagemaker_session):
272+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
273+
model.deploy(
274+
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, update_endpoint=True
275+
)
276+
sagemaker_session.create_endpoint_config.assert_called_with(
277+
name=model.name,
278+
model_name=model.name,
279+
initial_instance_count=INSTANCE_COUNT,
280+
instance_type=INSTANCE_TYPE,
281+
accelerator_type=None,
282+
tags=None,
283+
kms_key=None,
284+
data_capture_config_dict=None,
285+
)
286+
config_name = sagemaker_session.create_endpoint_config(
287+
name=model.name,
288+
model_name=model.name,
289+
initial_instance_count=INSTANCE_COUNT,
290+
instance_type=INSTANCE_TYPE,
291+
accelerator_type=ACCELERATOR_TYPE,
292+
)
293+
sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True)
294+
sagemaker_session.create_endpoint.assert_not_called()
295+
296+
297+
def test_deploy_update_endpoint_optional_args(sagemaker_session):
298+
endpoint_name = "endpoint-name"
299+
tags = [{"Key": "Value"}]
300+
kms_key = "foo"
301+
data_capture_config = Mock()
302+
303+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
304+
model.deploy(
305+
instance_type=INSTANCE_TYPE,
306+
initial_instance_count=INSTANCE_COUNT,
307+
update_endpoint=True,
308+
endpoint_name=endpoint_name,
309+
accelerator_type=ACCELERATOR_TYPE,
310+
tags=tags,
311+
kms_key=kms_key,
312+
wait=False,
313+
data_capture_config=data_capture_config,
314+
)
315+
sagemaker_session.create_endpoint_config.assert_called_with(
316+
name=model.name,
317+
model_name=model.name,
318+
initial_instance_count=INSTANCE_COUNT,
319+
instance_type=INSTANCE_TYPE,
320+
accelerator_type=ACCELERATOR_TYPE,
321+
tags=tags,
322+
kms_key=kms_key,
323+
data_capture_config_dict=data_capture_config._to_request_dict(),
324+
)
325+
config_name = sagemaker_session.create_endpoint_config(
326+
name=model.name,
327+
model_name=model.name,
328+
initial_instance_count=INSTANCE_COUNT,
329+
instance_type=INSTANCE_TYPE,
330+
accelerator_type=ACCELERATOR_TYPE,
331+
wait=False,
332+
)
333+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False)
334+
sagemaker_session.create_endpoint.assert_not_called()

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
BUCKET_NAME = "mybucket"
3434
INSTANCE_COUNT = 1
3535
INSTANCE_TYPE = "c4.4xlarge"
36-
ACCELERATOR_TYPE = "ml.eia.medium"
37-
IMAGE_NAME = "fakeimage"
3836
REGION = "us-west-2"
3937
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
4038
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
@@ -165,72 +163,6 @@ def test_prepare_container_def_no_model_defaults(sagemaker_session, tmpdir):
165163
}
166164

167165

168-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
169-
def test_deploy_update_endpoint(sagemaker_session, tmpdir):
170-
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
171-
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, update_endpoint=True)
172-
sagemaker_session.create_endpoint_config.assert_called_with(
173-
name=model.name,
174-
model_name=model.name,
175-
initial_instance_count=INSTANCE_COUNT,
176-
instance_type=INSTANCE_TYPE,
177-
accelerator_type=None,
178-
tags=None,
179-
kms_key=None,
180-
data_capture_config_dict=None,
181-
)
182-
config_name = sagemaker_session.create_endpoint_config(
183-
name=model.name,
184-
model_name=model.name,
185-
initial_instance_count=INSTANCE_COUNT,
186-
instance_type=INSTANCE_TYPE,
187-
accelerator_type=ACCELERATOR_TYPE,
188-
)
189-
sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True)
190-
sagemaker_session.create_endpoint.assert_not_called()
191-
192-
193-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
194-
def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir):
195-
endpoint_name = "endpoint-name"
196-
tags = [{"Key": "Value"}]
197-
kms_key = "foo"
198-
data_capture_config = MagicMock()
199-
200-
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
201-
model.deploy(
202-
instance_type=INSTANCE_TYPE,
203-
initial_instance_count=1,
204-
update_endpoint=True,
205-
endpoint_name=endpoint_name,
206-
accelerator_type=ACCELERATOR_TYPE,
207-
tags=tags,
208-
kms_key=kms_key,
209-
wait=False,
210-
data_capture_config=data_capture_config,
211-
)
212-
sagemaker_session.create_endpoint_config.assert_called_with(
213-
name=model.name,
214-
model_name=model.name,
215-
initial_instance_count=INSTANCE_COUNT,
216-
instance_type=INSTANCE_TYPE,
217-
accelerator_type=ACCELERATOR_TYPE,
218-
tags=tags,
219-
kms_key=kms_key,
220-
data_capture_config_dict=data_capture_config._to_request_dict(),
221-
)
222-
config_name = sagemaker_session.create_endpoint_config(
223-
name=model.name,
224-
model_name=model.name,
225-
initial_instance_count=INSTANCE_COUNT,
226-
instance_type=INSTANCE_TYPE,
227-
accelerator_type=ACCELERATOR_TYPE,
228-
wait=False,
229-
)
230-
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False)
231-
sagemaker_session.create_endpoint.assert_not_called()
232-
233-
234166
@patch("sagemaker.git_utils.git_clone_repo")
235167
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
236168
def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session):

0 commit comments

Comments
 (0)