Skip to content

Commit d753109

Browse files
authored
infra: move Model.deploy unit tests to separate file (#1425)
1 parent 98a8037 commit d753109

File tree

2 files changed

+269
-236
lines changed

2 files changed

+269
-236
lines changed
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import copy
16+
17+
import pytest
18+
from mock import Mock, patch
19+
20+
import sagemaker
21+
from sagemaker.model import Model
22+
23+
MODEL_DATA = "s3://bucket/model.tar.gz"
24+
MODEL_IMAGE = "mi"
25+
TIMESTAMP = "2017-10-10-14-14-15"
26+
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
27+
28+
INSTANCE_COUNT = 2
29+
INSTANCE_TYPE = "ml.c4.4xlarge"
30+
ROLE = "some-role"
31+
32+
BASE_PRODUCTION_VARIANT = {
33+
"ModelName": MODEL_NAME,
34+
"InstanceType": INSTANCE_TYPE,
35+
"InitialInstanceCount": INSTANCE_COUNT,
36+
"VariantName": "AllTraffic",
37+
"InitialVariantWeight": 1,
38+
}
39+
40+
41+
@pytest.fixture
42+
def sagemaker_session():
43+
return Mock()
44+
45+
46+
@patch("sagemaker.production_variant")
47+
@patch("sagemaker.model.Model.prepare_container_def")
48+
@patch("sagemaker.utils.name_from_image")
49+
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
50+
name_from_image.return_value = MODEL_NAME
51+
production_variant.return_value = BASE_PRODUCTION_VARIANT
52+
53+
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
54+
prepare_container_def.return_value = container_def
55+
56+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
57+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
58+
59+
name_from_image.assert_called_with(MODEL_IMAGE)
60+
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
61+
production_variant.assert_called_with(
62+
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None
63+
)
64+
65+
sagemaker_session.create_model.assert_called_with(
66+
MODEL_NAME, ROLE, container_def, vpc_config=None, enable_network_isolation=False, tags=None
67+
)
68+
69+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
70+
name=MODEL_NAME,
71+
production_variants=[BASE_PRODUCTION_VARIANT],
72+
tags=None,
73+
kms_key=None,
74+
wait=True,
75+
data_capture_config_dict=None,
76+
)
77+
78+
79+
@patch("sagemaker.model.Model._create_sagemaker_model")
80+
@patch("sagemaker.production_variant")
81+
def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session):
82+
model = Model(
83+
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
84+
)
85+
86+
accelerator_type = "ml.eia.medium"
87+
88+
production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
89+
production_variant_result["AcceleratorType"] = accelerator_type
90+
production_variant.return_value = production_variant_result
91+
92+
model.deploy(
93+
instance_type=INSTANCE_TYPE,
94+
initial_instance_count=INSTANCE_COUNT,
95+
accelerator_type=accelerator_type,
96+
)
97+
98+
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, accelerator_type, None)
99+
production_variant.assert_called_with(
100+
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=accelerator_type
101+
)
102+
103+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
104+
name=MODEL_NAME,
105+
production_variants=[production_variant_result],
106+
tags=None,
107+
kms_key=None,
108+
wait=True,
109+
data_capture_config_dict=None,
110+
)
111+
112+
113+
@patch("sagemaker.utils.name_from_image", Mock())
114+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
115+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
116+
def test_deploy_endpoint_name(sagemaker_session):
117+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
118+
119+
endpoint_name = "blah"
120+
model.deploy(
121+
endpoint_name=endpoint_name,
122+
instance_type=INSTANCE_TYPE,
123+
initial_instance_count=INSTANCE_COUNT,
124+
)
125+
126+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
127+
name=endpoint_name,
128+
production_variants=[BASE_PRODUCTION_VARIANT],
129+
tags=None,
130+
kms_key=None,
131+
wait=True,
132+
data_capture_config_dict=None,
133+
)
134+
135+
136+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
137+
@patch("sagemaker.model.Model._create_sagemaker_model")
138+
def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session):
139+
model = Model(
140+
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
141+
)
142+
143+
tags = [{"Key": "ModelName", "Value": "TestModel"}]
144+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, tags=tags)
145+
146+
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags)
147+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
148+
name=MODEL_NAME,
149+
production_variants=[BASE_PRODUCTION_VARIANT],
150+
tags=tags,
151+
kms_key=None,
152+
wait=True,
153+
data_capture_config_dict=None,
154+
)
155+
156+
157+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
158+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
159+
def test_deploy_kms_key(production_variant, sagemaker_session):
160+
model = Model(
161+
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
162+
)
163+
164+
key = "some-key-arn"
165+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, kms_key=key)
166+
167+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
168+
name=MODEL_NAME,
169+
production_variants=[BASE_PRODUCTION_VARIANT],
170+
tags=None,
171+
kms_key=key,
172+
wait=True,
173+
data_capture_config_dict=None,
174+
)
175+
176+
177+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
178+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
179+
def test_deploy_async(production_variant, sagemaker_session):
180+
model = Model(
181+
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
182+
)
183+
184+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False)
185+
186+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
187+
name=MODEL_NAME,
188+
production_variants=[BASE_PRODUCTION_VARIANT],
189+
tags=None,
190+
kms_key=None,
191+
wait=False,
192+
data_capture_config_dict=None,
193+
)
194+
195+
196+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
197+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
198+
def test_deploy_data_capture_config(production_variant, sagemaker_session):
199+
model = Model(
200+
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
201+
)
202+
203+
data_capture_config = Mock()
204+
data_capture_config_dict = {"EnableCapture": True}
205+
data_capture_config._to_request_dict.return_value = data_capture_config_dict
206+
model.deploy(
207+
instance_type=INSTANCE_TYPE,
208+
initial_instance_count=INSTANCE_COUNT,
209+
data_capture_config=data_capture_config,
210+
)
211+
212+
data_capture_config._to_request_dict.assert_called_with()
213+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
214+
name=MODEL_NAME,
215+
production_variants=[BASE_PRODUCTION_VARIANT],
216+
tags=None,
217+
kms_key=None,
218+
wait=True,
219+
data_capture_config_dict=data_capture_config_dict,
220+
)
221+
222+
223+
@patch("sagemaker.session.Session")
224+
@patch("sagemaker.local.LocalSession")
225+
def test_deploy_creates_correct_session(local_session, session):
226+
# We expect a LocalSession when deploying to instance_type = 'local'
227+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
228+
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
229+
assert model.sagemaker_session == local_session.return_value
230+
231+
# We expect a real Session when deploying to instance_type != local/local_gpu
232+
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
233+
model.deploy(
234+
endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2
235+
)
236+
assert model.sagemaker_session == session.return_value
237+
238+
239+
def test_deploy_no_role(sagemaker_session):
240+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
241+
242+
with pytest.raises(ValueError, match="Role can not be null for deploying a model"):
243+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
244+
245+
246+
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
247+
@patch("sagemaker.predictor.RealTimePredictor._get_endpoint_config_name", Mock())
248+
@patch("sagemaker.predictor.RealTimePredictor._get_model_names", Mock())
249+
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
250+
def test_deploy_predictor_cls(production_variant, sagemaker_session):
251+
model = Model(
252+
MODEL_DATA,
253+
MODEL_IMAGE,
254+
role=ROLE,
255+
name=MODEL_NAME,
256+
predictor_cls=sagemaker.predictor.RealTimePredictor,
257+
sagemaker_session=sagemaker_session,
258+
)
259+
260+
endpoint_name = "foo"
261+
predictor = model.deploy(
262+
instance_type=INSTANCE_TYPE,
263+
initial_instance_count=INSTANCE_COUNT,
264+
endpoint_name=endpoint_name,
265+
)
266+
267+
assert isinstance(predictor, sagemaker.predictor.RealTimePredictor)
268+
assert predictor.endpoint == endpoint_name
269+
assert predictor.sagemaker_session == sagemaker_session

0 commit comments

Comments
 (0)