Skip to content

Commit 2416254

Browse files
authored
change: make instance_type optional for prepare_container_def (#1567)
This argument was used only for determining default framework image tags, and is not needed in many cases.
1 parent 8a0d640 commit 2416254

File tree

15 files changed

+126
-64
lines changed

15 files changed

+126
-64
lines changed

src/sagemaker/chainer/model.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import fw_utils
19-
2018
import sagemaker
2119
from sagemaker.fw_utils import (
2220
create_image_uri,
@@ -126,7 +124,7 @@ def __init__(
126124
self.framework_version = framework_version or defaults.CHAINER_VERSION
127125
self.model_server_workers = model_server_workers
128126

129-
def prepare_container_def(self, instance_type, accelerator_type=None):
127+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
130128
"""Return a container definition with framework configuration set in
131129
model environment variables.
132130
@@ -143,14 +141,14 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
143141
"""
144142
deploy_image = self.image
145143
if not deploy_image:
144+
if instance_type is None:
145+
raise ValueError(
146+
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
147+
)
148+
146149
region_name = self.sagemaker_session.boto_session.region_name
147-
deploy_image = create_image_uri(
148-
region_name,
149-
self.__framework_name__,
150-
instance_type,
151-
self.framework_version,
152-
self.py_version,
153-
accelerator_type=accelerator_type,
150+
deploy_image = self.serving_image_uri(
151+
region_name, instance_type, accelerator_type=accelerator_type
154152
)
155153

156154
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -162,7 +160,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
162160
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
163161
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
164162

165-
def serving_image_uri(self, region_name, instance_type):
163+
def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
166164
"""Create a URI for the serving image.
167165
168166
Args:
@@ -174,10 +172,11 @@ def serving_image_uri(self, region_name, instance_type):
174172
str: The appropriate image URI based on the given parameters.
175173
176174
"""
177-
return fw_utils.create_image_uri(
175+
return create_image_uri(
178176
region_name,
179177
self.__framework_name__,
180178
instance_type,
181179
self.framework_version,
182180
self.py_version,
181+
accelerator_type=accelerator_type,
183182
)

src/sagemaker/model.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type):
138138
self.sagemaker_session = session.Session()
139139

140140
def prepare_container_def(
141-
self, instance_type, accelerator_type=None
141+
self, instance_type=None, accelerator_type=None
142142
): # pylint: disable=unused-argument
143143
"""Return a dict created by ``sagemaker.container_def()`` for deploying
144144
this model to a specified instance type.
@@ -166,7 +166,7 @@ def enable_network_isolation(self):
166166
"""
167167
return self._enable_network_isolation
168168

169-
def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None):
169+
def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tags=None):
170170
"""Create a SageMaker Model Entity
171171
172172
Args:
@@ -807,9 +807,7 @@ def __init__(
807807
self.uploaded_code = None
808808
self.repacked_model_data = None
809809

810-
def prepare_container_def(
811-
self, instance_type, accelerator_type=None
812-
): # pylint disable=unused-argument
810+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
813811
"""Return a container definition with framework configuration set in
814812
model environment variables.
815813

src/sagemaker/multidatamodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
**kwargs
112112
)
113113

114-
def prepare_container_def(self, instance_type, accelerator_type=None):
114+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
115115
"""Return a container definition set with MultiModel mode,
116116
model data and other parameters from the model (if available).
117117

src/sagemaker/mxnet/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
self.framework_version = framework_version or defaults.MXNET_VERSION
127127
self.model_server_workers = model_server_workers
128128

129-
def prepare_container_def(self, instance_type, accelerator_type=None):
129+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
130130
"""Return a container definition with framework configuration set in
131131
model environment variables.
132132
@@ -143,6 +143,11 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
143143
"""
144144
deploy_image = self.image
145145
if not deploy_image:
146+
if instance_type is None:
147+
raise ValueError(
148+
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
149+
)
150+
146151
region_name = self.sagemaker_session.boto_session.region_name
147152
deploy_image = self.serving_image_uri(
148153
region_name, instance_type, accelerator_type=accelerator_type

src/sagemaker/pytorch/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
self.framework_version = framework_version or defaults.PYTORCH_VERSION
128128
self.model_server_workers = model_server_workers
129129

130-
def prepare_container_def(self, instance_type, accelerator_type=None):
130+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
131131
"""Return a container definition with framework configuration set in
132132
model environment variables.
133133
@@ -144,6 +144,11 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
144144
"""
145145
deploy_image = self.image
146146
if not deploy_image:
147+
if instance_type is None:
148+
raise ValueError(
149+
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
150+
)
151+
147152
region_name = self.sagemaker_session.boto_session.region_name
148153
deploy_image = self.serving_image_uri(
149154
region_name, instance_type, accelerator_type=accelerator_type

src/sagemaker/sklearn/model.py

+10-18
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import fw_utils
19-
2018
import sagemaker
2119
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
2220
from sagemaker.fw_registry import default_framework_uri
@@ -118,16 +116,16 @@ def __init__(
118116
self.framework_version = framework_version
119117
self.model_server_workers = model_server_workers
120118

121-
def prepare_container_def(self, instance_type, accelerator_type=None):
119+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
122120
"""Return a container definition with framework configuration set in
123121
model environment variables.
124122
125123
Args:
126124
instance_type (str): The EC2 instance type to deploy this Model to.
127-
For example, 'ml.p2.xlarge'.
125+
This parameter is unused because Scikit-learn supports only CPU.
128126
accelerator_type (str): The Elastic Inference accelerator type to
129127
deploy to the instance for loading and making inferences to the
130-
model. For example, 'ml.eia1.medium'. Note: accelerator types
128+
model. This parameter is unused because accelerator types
131129
are not supported by SKLearnModel.
132130
133131
Returns:
@@ -139,9 +137,8 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
139137

140138
deploy_image = self.image
141139
if not deploy_image:
142-
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
143-
deploy_image = default_framework_uri(
144-
self.__framework_name__, self.sagemaker_session.boto_region_name, image_tag
140+
deploy_image = self.serving_image_uri(
141+
self.sagemaker_session.boto_region_name, instance_type
145142
)
146143

147144
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -156,22 +153,17 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
156153
)
157154
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
158155

159-
def serving_image_uri(self, region_name, instance_type):
156+
def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
160157
"""Create a URI for the serving image.
161158
162159
Args:
163160
region_name (str): AWS region where the image is uploaded.
164-
instance_type (str): SageMaker instance type. Used to determine device type
165-
(cpu/gpu/family-specific optimized).
161+
instance_type (str): SageMaker instance type. This parameter is unused because
162+
Scikit-learn supports only CPU.
166163
167164
Returns:
168165
str: The appropriate image URI based on the given parameters.
169166
170167
"""
171-
return fw_utils.create_image_uri(
172-
region_name,
173-
self.__framework_name__,
174-
instance_type,
175-
self.framework_version,
176-
self.py_version,
177-
)
168+
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
169+
return default_framework_uri(self.__framework_name__, region_name, image_tag)

src/sagemaker/tensorflow/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
self.framework_version = framework_version or defaults.TF_VERSION
125125
self.model_server_workers = model_server_workers
126126

127-
def prepare_container_def(self, instance_type, accelerator_type=None):
127+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
128128
"""Return a container definition with framework configuration set in
129129
model environment variables.
130130
@@ -143,6 +143,11 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
143143
"""
144144
deploy_image = self.image
145145
if not deploy_image:
146+
if instance_type is None:
147+
raise ValueError(
148+
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
149+
)
150+
146151
region_name = self.sagemaker_session.boto_region_name
147152
deploy_image = self.serving_image_uri(
148153
region_name, instance_type, accelerator_type=accelerator_type

src/sagemaker/tensorflow/serving.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,17 @@ def _eia_supported(self):
215215
"""Return true if TF version is EIA enabled"""
216216
return [int(s) for s in self._framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
217217

218-
def prepare_container_def(self, instance_type, accelerator_type=None):
218+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
219219
"""
220220
Args:
221221
instance_type:
222222
accelerator_type:
223223
"""
224+
if self.image is None and instance_type is None:
225+
raise ValueError(
226+
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
227+
)
228+
224229
image = self._get_image_uri(instance_type, accelerator_type)
225230
env = self._get_container_env()
226231

src/sagemaker/xgboost/model.py

+12-21
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import fw_utils
19-
2018
import sagemaker
2119
from sagemaker.fw_utils import model_code_key_prefix
2220
from sagemaker.fw_registry import default_framework_uri
@@ -107,26 +105,24 @@ def __init__(
107105
self.framework_version = framework_version
108106
self.model_server_workers = model_server_workers
109107

110-
def prepare_container_def(self, instance_type, accelerator_type=None):
108+
def prepare_container_def(self, instance_type=None, accelerator_type=None):
111109
"""Return a container definition with framework configuration
112110
set in model environment variables.
113111
114112
Args:
115-
instance_type (str): The EC2 instance type to deploy this Model to. For example,
116-
'ml.m5.xlarge'.
113+
instance_type (str): The EC2 instance type to deploy this Model to.
114+
This parameter is unused because XGBoost supports only CPU.
117115
accelerator_type (str): The Elastic Inference accelerator type to deploy to the
118-
instance for loading and making inferences to the model. For example,
119-
'ml.eia1.medium'.
120-
Note: accelerator types are not supported by XGBoostModel.
116+
instance for loading and making inferences to the model. This parameter is
117+
unused because accelerator types are not supported by XGBoostModel.
121118
122119
Returns:
123120
dict[str, str]: A container definition object usable with the CreateModel API.
124121
"""
125122
deploy_image = self.image
126123
if not deploy_image:
127-
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
128-
deploy_image = default_framework_uri(
129-
self.__framework_name__, self.sagemaker_session.boto_region_name, image_tag
124+
deploy_image = self.serving_image_uri(
125+
self.sagemaker_session.boto_region_name, instance_type
130126
)
131127

132128
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -138,22 +134,17 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
138134
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
139135
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
140136

141-
def serving_image_uri(self, region_name, instance_type):
137+
def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
142138
"""Create a URI for the serving image.
143139
144140
Args:
145141
region_name (str): AWS region where the image is uploaded.
146-
instance_type (str): SageMaker instance type. Used to determine device type
147-
(cpu/gpu/family-specific optimized).
142+
instance_type (str): SageMaker instance type. This parameter is unused because
143+
XGBoost supports only CPU.
148144
149145
Returns:
150146
str: The appropriate image URI based on the given parameters.
151147
152148
"""
153-
return fw_utils.create_image_uri(
154-
region_name,
155-
self.__framework_name__,
156-
instance_type,
157-
self.framework_version,
158-
self.py_version,
159-
)
149+
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
150+
return default_framework_uri(self.__framework_name__, region_name, image_tag)

tests/unit/sagemaker/model/test_model.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ def test_prepare_container_def():
3737
env = {"FOO": "BAR"}
3838
model = Model(MODEL_DATA, MODEL_IMAGE, env=env)
3939

40+
expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": MODEL_DATA}
41+
4042
container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")
43+
assert expected == container_def
4144

42-
expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": MODEL_DATA}
45+
container_def = model.prepare_container_def()
4346
assert expected == container_def
4447

4548

@@ -60,16 +63,25 @@ def test_create_sagemaker_model(name_from_image, prepare_container_def, sagemake
6063
prepare_container_def.return_value = container_def
6164

6265
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
63-
model._create_sagemaker_model(INSTANCE_TYPE)
66+
model._create_sagemaker_model()
6467

65-
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
68+
prepare_container_def.assert_called_with(None, accelerator_type=None)
6669
name_from_image.assert_called_with(MODEL_IMAGE)
6770

6871
sagemaker_session.create_model.assert_called_with(
6972
MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None
7073
)
7174

7275

76+
@patch("sagemaker.utils.name_from_image", Mock())
77+
@patch("sagemaker.model.Model.prepare_container_def")
78+
def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_session):
79+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
80+
model._create_sagemaker_model(INSTANCE_TYPE)
81+
82+
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
83+
84+
7385
@patch("sagemaker.utils.name_from_image", Mock())
7486
@patch("sagemaker.model.Model.prepare_container_def")
7587
def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session):

tests/unit/test_chainer.py

+10
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ def test_model_prepare_container_def_accelerator_error(sagemaker_session):
434434
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
435435

436436

437+
def test_model_prepare_container_def_no_instance_type_or_image():
438+
model = ChainerModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH)
439+
440+
with pytest.raises(ValueError) as e:
441+
model.prepare_container_def()
442+
443+
expected_msg = "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
444+
assert expected_msg in str(e)
445+
446+
437447
def test_train_image_default(sagemaker_session):
438448
chainer = Chainer(
439449
entry_point=SCRIPT_PATH,

tests/unit/test_mxnet.py

+10
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,16 @@ def test_model_image_accelerator_mms_version(sagemaker_session):
491491
)
492492

493493

494+
def test_model_prepare_container_def_no_instance_type_or_image():
495+
model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH)
496+
497+
with pytest.raises(ValueError) as e:
498+
model.prepare_container_def()
499+
500+
expected_msg = "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
501+
assert expected_msg in str(e)
502+
503+
494504
def test_train_image_default(sagemaker_session):
495505
mx = MXNet(
496506
entry_point=SCRIPT_PATH,

0 commit comments

Comments
 (0)