Skip to content

Commit 1470afb

Browse files
committed
change: Make Instance Type Fields as Optional
1 parent 255a339 commit 1470afb

File tree

13 files changed

+194
-54
lines changed

13 files changed

+194
-54
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,8 +1266,8 @@ def register(
12661266
self,
12671267
content_types,
12681268
response_types,
1269-
inference_instances,
1270-
transform_instances,
1269+
inference_instances=None,
1270+
transform_instances=None,
12711271
image_uri=None,
12721272
model_package_name=None,
12731273
model_package_group_name=None,
@@ -1288,9 +1288,9 @@ def register(
12881288
content_types (list): The supported MIME types for the input data.
12891289
response_types (list): The supported MIME types for the output data.
12901290
inference_instances (list): A list of the instance types that are used to
1291-
generate inferences in real-time.
1291+
generate inferences in real-time (default: None).
12921292
transform_instances (list): A list of the instance types on which a transformation
1293-
job can be run or on which an endpoint can be deployed.
1293+
job can be run or on which an endpoint can be deployed (default: None).
12941294
image_uri (str): The container image uri for Model Package, if not specified,
12951295
Estimator's training container image will be used (default: None).
12961296
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,

src/sagemaker/huggingface/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ def register(
293293
self,
294294
content_types,
295295
response_types,
296-
inference_instances,
297-
transform_instances,
296+
inference_instances=None,
297+
transform_instances=None,
298298
model_package_name=None,
299299
model_package_group_name=None,
300300
image_uri=None,
@@ -311,9 +311,9 @@ def register(
311311
content_types (list): The supported MIME types for the input data.
312312
response_types (list): The supported MIME types for the output data.
313313
inference_instances (list): A list of the instance types that are used to
314-
generate inferences in real-time.
314+
generate inferences in real-time (default: None).
315315
transform_instances (list): A list of the instance types on which a transformation
316-
job can be run or on which an endpoint can be deployed.
316+
job can be run or on which an endpoint can be deployed (default: None).
317317
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
318318
using `model_package_name` makes the Model Package un-versioned.
319319
Defaults to ``None``.
@@ -335,7 +335,7 @@ def register(
335335
Returns:
336336
A `sagemaker.model.ModelPackage` instance.
337337
"""
338-
instance_type = inference_instances[0]
338+
instance_type = inference_instances[0] if inference_instances else None
339339
self._init_sagemaker_session_if_does_not_exist(instance_type)
340340

341341
if image_uri:

src/sagemaker/model.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ def register(
296296
self,
297297
content_types,
298298
response_types,
299-
inference_instances,
300-
transform_instances,
299+
inference_instances=None,
300+
transform_instances=None,
301301
model_package_name=None,
302302
model_package_group_name=None,
303303
image_uri=None,
@@ -311,14 +311,14 @@ def register(
311311
validation_specification=None,
312312
):
313313
"""Creates a model package for creating SageMaker models or listing on Marketplace.
314-
314+
315315
Args:
316316
content_types (list): The supported MIME types for the input data.
317317
response_types (list): The supported MIME types for the output data.
318318
inference_instances (list): A list of the instance types that are used to
319-
generate inferences in real-time.
319+
generate inferences in real-time (default: None).
320320
transform_instances (list): A list of the instance types on which a transformation
321-
job can be run or on which an endpoint can be deployed.
321+
job can be run or on which an endpoint can be deployed (default: None).
322322
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
323323
using `model_package_name` makes the Model Package un-versioned (default: None).
324324
model_package_group_name (str): Model Package Group name, exclusive to
@@ -348,12 +348,11 @@ def register(
348348
container_def = self.prepare_container_def()
349349
else:
350350
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
351-
352351
model_pkg_args = sagemaker.get_model_package_args(
353352
content_types,
354353
response_types,
355-
inference_instances,
356-
transform_instances,
354+
inference_instances=inference_instances,
355+
transform_instances=transform_instances,
357356
model_package_name=model_package_name,
358357
model_package_group_name=model_package_group_name,
359358
model_metrics=model_metrics,

src/sagemaker/mxnet/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def register(
146146
self,
147147
content_types,
148148
response_types,
149-
inference_instances,
150-
transform_instances,
149+
inference_instances=None,
150+
transform_instances=None,
151151
model_package_name=None,
152152
model_package_group_name=None,
153153
image_uri=None,
@@ -165,9 +165,9 @@ def register(
165165
content_types (list): The supported MIME types for the input data.
166166
response_types (list): The supported MIME types for the output data.
167167
inference_instances (list): A list of the instance types that are used to
168-
generate inferences in real-time.
168+
generate inferences in real-time (default: None).
169169
transform_instances (list): A list of the instance types on which a transformation
170-
job can be run or on which an endpoint can be deployed.
170+
job can be run or on which an endpoint can be deployed (default: None).
171171
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
172172
using `model_package_name` makes the Model Package un-versioned (default: None).
173173
model_package_group_name (str): Model Package Group name, exclusive to
@@ -189,7 +189,7 @@ def register(
189189
Returns:
190190
A `sagemaker.model.ModelPackage` instance.
191191
"""
192-
instance_type = inference_instances[0]
192+
instance_type = inference_instances[0] if inference_instances else None
193193
self._init_sagemaker_session_if_does_not_exist(instance_type)
194194

195195
if image_uri:

src/sagemaker/pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def register(
266266
self,
267267
content_types: list,
268268
response_types: list,
269-
inference_instances: list,
270-
transform_instances: list,
269+
inference_instances: Optional[list] = None,
270+
transform_instances: Optional[list] = None,
271271
model_package_name: Optional[str] = None,
272272
model_package_group_name: Optional[str] = None,
273273
image_uri: Optional[str] = None,
@@ -285,9 +285,9 @@ def register(
285285
content_types (list): The supported MIME types for the input data.
286286
response_types (list): The supported MIME types for the output data.
287287
inference_instances (list): A list of the instance types that are used to
288-
generate inferences in real-time.
288+
generate inferences in real-time (default: None).
289289
transform_instances (list): A list of the instance types on which a transformation
290-
job can be run or on which an endpoint can be deployed.
290+
job can be run or on which an endpoint can be deployed (default: None).
291291
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
292292
using `model_package_name` makes the Model Package un-versioned (default: None).
293293
model_package_group_name (str): Model Package Group name, exclusive to
@@ -313,7 +313,7 @@ def register(
313313
if model.model_data is None:
314314
raise ValueError("SageMaker Model Package cannot be created without model data.")
315315
if model_package_group_name is not None:
316-
container_def = self.pipeline_container_def(inference_instances[0])
316+
container_def = self.pipeline_container_def(inference_instances[0] if inference_instances else None)
317317
else:
318318
container_def = [
319319
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
@@ -323,8 +323,8 @@ def register(
323323
model_pkg_args = sagemaker.get_model_package_args(
324324
content_types,
325325
response_types,
326-
inference_instances,
327-
transform_instances,
326+
inference_instances=inference_instances,
327+
transform_instances=transform_instances,
328328
model_package_name=model_package_name,
329329
model_package_group_name=model_package_group_name,
330330
model_metrics=model_metrics,

src/sagemaker/pytorch/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def register(
147147
self,
148148
content_types,
149149
response_types,
150-
inference_instances,
151-
transform_instances,
150+
inference_instances=None,
151+
transform_instances=None,
152152
model_package_name=None,
153153
model_package_group_name=None,
154154
image_uri=None,
@@ -166,9 +166,9 @@ def register(
166166
content_types (list): The supported MIME types for the input data.
167167
response_types (list): The supported MIME types for the output data.
168168
inference_instances (list): A list of the instance types that are used to
169-
generate inferences in real-time.
169+
generate inferences in real-time (default: None).
170170
transform_instances (list): A list of the instance types on which a transformation
171-
job can be run or on which an endpoint can be deployed.
171+
job can be run or on which an endpoint can be deployed (default: None).
172172
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
173173
using `model_package_name` makes the Model Package un-versioned (default: None).
174174
model_package_group_name (str): Model Package Group name, exclusive to
@@ -190,7 +190,7 @@ def register(
190190
Returns:
191191
A `sagemaker.model.ModelPackage` instance.
192192
"""
193-
instance_type = inference_instances[0]
193+
instance_type = inference_instances[0] if inference_instances else None
194194
self._init_sagemaker_session_if_does_not_exist(instance_type)
195195

196196
if image_uri:

src/sagemaker/session.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4202,8 +4202,8 @@ def _intercept_create_request(
42024202
def get_model_package_args(
42034203
content_types,
42044204
response_types,
4205-
inference_instances,
4206-
transform_instances,
4205+
inference_instances=None,
4206+
transform_instances=None,
42074207
model_package_name=None,
42084208
model_package_group_name=None,
42094209
model_data=None,
@@ -4225,9 +4225,9 @@ def get_model_package_args(
42254225
content_types (list): The supported MIME types for the input data.
42264226
response_types (list): The supported MIME types for the output data.
42274227
inference_instances (list): A list of the instance types that are used to
4228-
generate inferences in real-time.
4228+
generate inferences in real-time (default: None).
42294229
transform_instances (list): A list of the instance types on which a transformation
4230-
job can be run or on which an endpoint can be deployed.
4230+
job can be run or on which an endpoint can be deployed (default: None).
42314231
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
42324232
using `model_package_name` makes the Model Package un-versioned (default: None).
42334233
model_package_group_name (str): Model Package Group name, exclusive to
@@ -4363,9 +4363,9 @@ def get_create_model_package_request(
43634363
if validation_specification:
43644364
request_dict["ValidationSpecification"] = validation_specification
43654365
if containers is not None:
4366-
if not all([content_types, response_types, inference_instances, transform_instances]):
4366+
if not all([content_types, response_types]):
43674367
raise ValueError(
4368-
"content_types, response_types, inference_inferences and transform_instances "
4368+
"content_types and response_types "
43694369
"must be provided if containers is present."
43704370
)
43714371
inference_specification = {

src/sagemaker/sklearn/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def register(
141141
self,
142142
content_types,
143143
response_types,
144-
inference_instances,
145-
transform_instances,
144+
inference_instances=None,
145+
transform_instances=None,
146146
model_package_name=None,
147147
model_package_group_name=None,
148148
image_uri=None,
@@ -158,9 +158,9 @@ def register(
158158
content_types (list): The supported MIME types for the input data.
159159
response_types (list): The supported MIME types for the output data.
160160
inference_instances (list): A list of the instance types that are used to
161-
generate inferences in real-time.
161+
generate inferences in real-time (default: None).
162162
transform_instances (list): A list of the instance types on which a transformation
163-
job can be run or on which an endpoint can be deployed.
163+
job can be run or on which an endpoint can be deployed (default: None).
164164
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
165165
using `model_package_name` makes the Model Package un-versioned (default: None).
166166
model_package_group_name (str): Model Package Group name, exclusive to
@@ -179,7 +179,7 @@ def register(
179179
Returns:
180180
A `sagemaker.model.ModelPackage` instance.
181181
"""
182-
instance_type = inference_instances[0]
182+
instance_type = inference_instances[0] if inference_instances else None
183183
self._init_sagemaker_session_if_does_not_exist(instance_type)
184184

185185
if image_uri:

src/sagemaker/tensorflow/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def register(
193193
self,
194194
content_types,
195195
response_types,
196-
inference_instances,
197-
transform_instances,
196+
inference_instances=None,
197+
transform_instances=None,
198198
model_package_name=None,
199199
model_package_group_name=None,
200200
image_uri=None,
@@ -212,9 +212,9 @@ def register(
212212
content_types (list): The supported MIME types for the input data.
213213
response_types (list): The supported MIME types for the output data.
214214
inference_instances (list): A list of the instance types that are used to
215-
generate inferences in real-time.
215+
generate inferences in real-time (default: None).
216216
transform_instances (list): A list of the instance types on which a transformation
217-
job can be run or on which an endpoint can be deployed.
217+
job can be run or on which an endpoint can be deployed (default: None).
218218
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
219219
using `model_package_name` makes the Model Package un-versioned (default: None).
220220
model_package_group_name (str): Model Package Group name, exclusive to
@@ -237,7 +237,7 @@ def register(
237237
Returns:
238238
A `sagemaker.model.ModelPackage` instance.
239239
"""
240-
instance_type = inference_instances[0]
240+
instance_type = inference_instances[0] if inference_instances else None
241241
self._init_sagemaker_session_if_does_not_exist(instance_type)
242242

243243
if image_uri:

src/sagemaker/workflow/step_collections.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def __init__(
6363
name: str,
6464
content_types,
6565
response_types,
66-
inference_instances,
67-
transform_instances,
66+
inference_instances=None,
67+
transform_instances=None,
6868
estimator: EstimatorBase = None,
6969
model_data=None,
7070
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
@@ -217,9 +217,9 @@ def __init__(
217217
kwargs.pop("output_kms_key", None)
218218

219219
if isinstance(model, PipelineModel):
220-
self.container_def_list = model.pipeline_container_def(inference_instances[0])
220+
self.container_def_list = model.pipeline_container_def(inference_instances[0] if inference_instances else None)
221221
elif isinstance(model, Model):
222-
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
222+
self.container_def_list = [model.prepare_container_def(inference_instances[0] if inference_instances else None)]
223223

224224
register_model_step = _RegisterModelStep(
225225
name=name,

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,39 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
122122
assert not register_step_args.create_model_request
123123
assert register_step_args.create_model_package_request
124124
assert len(register_step_args.need_runtime_repack) == 0
125+
126+
127+
def test_pipeline_session_context_for_model_step_without_instance_types(pipeline_session_mock):
128+
model = Model(
129+
name="MyModel",
130+
image_uri="fakeimage",
131+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
132+
sagemaker_session=pipeline_session_mock,
133+
entry_point=f"{DATA_DIR}/dummy_script.py",
134+
source_dir=f"{DATA_DIR}",
135+
role=_ROLE,
136+
)
137+
# CreateModelStep requires runtime repack
138+
create_step_args = model.create(
139+
instance_type="c4.4xlarge",
140+
accelerator_type="ml.eia1.medium",
141+
)
142+
# The context should be cleaned up before return
143+
assert pipeline_session_mock.context is None
144+
assert create_step_args.create_model_request
145+
assert not create_step_args.create_model_package_request
146+
assert len(create_step_args.need_runtime_repack) == 1
147+
148+
# _RegisterModelStep does not require runtime repack
149+
model.entry_point = None
150+
model.source_dir = None
151+
register_step_args = model.register(
152+
content_types=["text/csv"],
153+
response_types=["text/csv"],
154+
model_package_group_name="MyModelPackageGroup",
155+
)
156+
# The context should be cleaned up before return
157+
assert not pipeline_session_mock.context
158+
assert not register_step_args.create_model_request
159+
assert register_step_args.create_model_package_request
160+
assert len(register_step_args.need_runtime_repack) == 0

0 commit comments

Comments
 (0)