Skip to content

Commit 4d27ff4

Browse files
rahven14jerrypeng7773
authored andcommitted
fix: make instance type fields as optional (aws#3135)
1 parent a1b7e2a commit 4d27ff4

File tree

13 files changed

+219
-56
lines changed

13 files changed

+219
-56
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,8 +1286,8 @@ def register(
12861286
self,
12871287
content_types,
12881288
response_types,
1289-
inference_instances,
1290-
transform_instances,
1289+
inference_instances=None,
1290+
transform_instances=None,
12911291
image_uri=None,
12921292
model_package_name=None,
12931293
model_package_group_name=None,
@@ -1309,9 +1309,9 @@ def register(
13091309
content_types (list): The supported MIME types for the input data.
13101310
response_types (list): The supported MIME types for the output data.
13111311
inference_instances (list): A list of the instance types that are used to
1312-
generate inferences in real-time.
1312+
generate inferences in real-time (default: None).
13131313
transform_instances (list): A list of the instance types on which a transformation
1314-
job can be run or on which an endpoint can be deployed.
1314+
job can be run or on which an endpoint can be deployed (default: None).
13151315
image_uri (str): The container image uri for Model Package, if not specified,
13161316
Estimator's training container image will be used (default: None).
13171317
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,

src/sagemaker/huggingface/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ def register(
293293
self,
294294
content_types,
295295
response_types,
296-
inference_instances,
297-
transform_instances,
296+
inference_instances=None,
297+
transform_instances=None,
298298
model_package_name=None,
299299
model_package_group_name=None,
300300
image_uri=None,
@@ -313,9 +313,9 @@ def register(
313313
content_types (list): The supported MIME types for the input data.
314314
response_types (list): The supported MIME types for the output data.
315315
inference_instances (list): A list of the instance types that are used to
316-
generate inferences in real-time.
316+
generate inferences in real-time (default: None).
317317
transform_instances (list): A list of the instance types on which a transformation
318-
job can be run or on which an endpoint can be deployed.
318+
job can be run or on which an endpoint can be deployed (default: None).
319319
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320320
using `model_package_name` makes the Model Package un-versioned.
321321
Defaults to ``None``.
@@ -341,7 +341,7 @@ def register(
341341
Returns:
342342
A `sagemaker.model.ModelPackage` instance.
343343
"""
344-
instance_type = inference_instances[0]
344+
instance_type = inference_instances[0] if inference_instances else None
345345
self._init_sagemaker_session_if_does_not_exist(instance_type)
346346

347347
if image_uri:

src/sagemaker/model.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ def register(
296296
self,
297297
content_types,
298298
response_types,
299-
inference_instances,
300-
transform_instances,
299+
inference_instances=None,
300+
transform_instances=None,
301301
model_package_name=None,
302302
model_package_group_name=None,
303303
image_uri=None,
@@ -317,9 +317,9 @@ def register(
317317
content_types (list): The supported MIME types for the input data.
318318
response_types (list): The supported MIME types for the output data.
319319
inference_instances (list): A list of the instance types that are used to
320-
generate inferences in real-time.
320+
generate inferences in real-time (default: None).
321321
transform_instances (list): A list of the instance types on which a transformation
322-
job can be run or on which an endpoint can be deployed.
322+
job can be run or on which an endpoint can be deployed (default: None).
323323
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
324324
using `model_package_name` makes the Model Package un-versioned (default: None).
325325
model_package_group_name (str): Model Package Group name, exclusive to
@@ -351,12 +351,11 @@ def register(
351351
container_def = self.prepare_container_def()
352352
else:
353353
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
354-
355354
model_pkg_args = sagemaker.get_model_package_args(
356355
content_types,
357356
response_types,
358-
inference_instances,
359-
transform_instances,
357+
inference_instances=inference_instances,
358+
transform_instances=transform_instances,
360359
model_package_name=model_package_name,
361360
model_package_group_name=model_package_group_name,
362361
model_metrics=model_metrics,

src/sagemaker/mxnet/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def register(
146146
self,
147147
content_types,
148148
response_types,
149-
inference_instances,
150-
transform_instances,
149+
inference_instances=None,
150+
transform_instances=None,
151151
model_package_name=None,
152152
model_package_group_name=None,
153153
image_uri=None,
@@ -166,9 +166,9 @@ def register(
166166
content_types (list): The supported MIME types for the input data.
167167
response_types (list): The supported MIME types for the output data.
168168
inference_instances (list): A list of the instance types that are used to
169-
generate inferences in real-time.
169+
generate inferences in real-time (default: None).
170170
transform_instances (list): A list of the instance types on which a transformation
171-
job can be run or on which an endpoint can be deployed.
171+
job can be run or on which an endpoint can be deployed (default: None).
172172
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
173173
using `model_package_name` makes the Model Package un-versioned (default: None).
174174
model_package_group_name (str): Model Package Group name, exclusive to
@@ -192,7 +192,7 @@ def register(
192192
Returns:
193193
A `sagemaker.model.ModelPackage` instance.
194194
"""
195-
instance_type = inference_instances[0]
195+
instance_type = inference_instances[0] if inference_instances else None
196196
self._init_sagemaker_session_if_does_not_exist(instance_type)
197197

198198
if image_uri:

src/sagemaker/pipeline.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
self.enable_network_isolation = enable_network_isolation
8585
self.endpoint_name = None
8686

87-
def pipeline_container_def(self, instance_type):
87+
def pipeline_container_def(self, instance_type=None):
8888
"""The pipeline definition for deploying this model.
8989
9090
This is the dict created by ``sagemaker.pipeline_container_def()``.
@@ -266,8 +266,8 @@ def register(
266266
self,
267267
content_types: list,
268268
response_types: list,
269-
inference_instances: list,
270-
transform_instances: list,
269+
inference_instances: Optional[list] = None,
270+
transform_instances: Optional[list] = None,
271271
model_package_name: Optional[str] = None,
272272
model_package_group_name: Optional[str] = None,
273273
image_uri: Optional[str] = None,
@@ -286,9 +286,9 @@ def register(
286286
content_types (list): The supported MIME types for the input data.
287287
response_types (list): The supported MIME types for the output data.
288288
inference_instances (list): A list of the instance types that are used to
289-
generate inferences in real-time.
289+
generate inferences in real-time (default: None).
290290
transform_instances (list): A list of the instance types on which a transformation
291-
job can be run or on which an endpoint can be deployed.
291+
job can be run or on which an endpoint can be deployed (default: None).
292292
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
293293
using `model_package_name` makes the Model Package un-versioned (default: None).
294294
model_package_group_name (str): Model Package Group name, exclusive to
@@ -316,18 +316,23 @@ def register(
316316
if model.model_data is None:
317317
raise ValueError("SageMaker Model Package cannot be created without model data.")
318318
if model_package_group_name is not None:
319-
container_def = self.pipeline_container_def(inference_instances[0])
319+
container_def = self.pipeline_container_def(
320+
inference_instances[0] if inference_instances else None
321+
)
320322
else:
321323
container_def = [
322-
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
324+
{
325+
"Image": image_uri or model.image_uri,
326+
"ModelDataUrl": model.model_data,
327+
}
323328
for model in self.models
324329
]
325330

326331
model_pkg_args = sagemaker.get_model_package_args(
327332
content_types,
328333
response_types,
329-
inference_instances,
330-
transform_instances,
334+
inference_instances=inference_instances,
335+
transform_instances=transform_instances,
331336
model_package_name=model_package_name,
332337
model_package_group_name=model_package_group_name,
333338
model_metrics=model_metrics,

src/sagemaker/pytorch/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def register(
147147
self,
148148
content_types,
149149
response_types,
150-
inference_instances,
151-
transform_instances,
150+
inference_instances=None,
151+
transform_instances=None,
152152
model_package_name=None,
153153
model_package_group_name=None,
154154
image_uri=None,
@@ -167,9 +167,9 @@ def register(
167167
content_types (list): The supported MIME types for the input data.
168168
response_types (list): The supported MIME types for the output data.
169169
inference_instances (list): A list of the instance types that are used to
170-
generate inferences in real-time.
170+
generate inferences in real-time (default: None).
171171
transform_instances (list): A list of the instance types on which a transformation
172-
job can be run or on which an endpoint can be deployed.
172+
job can be run or on which an endpoint can be deployed (default: None).
173173
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
174174
using `model_package_name` makes the Model Package un-versioned (default: None).
175175
model_package_group_name (str): Model Package Group name, exclusive to
@@ -193,7 +193,7 @@ def register(
193193
Returns:
194194
A `sagemaker.model.ModelPackage` instance.
195195
"""
196-
instance_type = inference_instances[0]
196+
instance_type = inference_instances[0] if inference_instances else None
197197
self._init_sagemaker_session_if_does_not_exist(instance_type)
198198

199199
if image_uri:

src/sagemaker/session.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4206,8 +4206,8 @@ def _intercept_create_request(
42064206
def get_model_package_args(
42074207
content_types,
42084208
response_types,
4209-
inference_instances,
4210-
transform_instances,
4209+
inference_instances=None,
4210+
transform_instances=None,
42114211
model_package_name=None,
42124212
model_package_group_name=None,
42134213
model_data=None,
@@ -4230,9 +4230,9 @@ def get_model_package_args(
42304230
content_types (list): The supported MIME types for the input data.
42314231
response_types (list): The supported MIME types for the output data.
42324232
inference_instances (list): A list of the instance types that are used to
4233-
generate inferences in real-time.
4233+
generate inferences in real-time (default: None).
42344234
transform_instances (list): A list of the instance types on which a transformation
4235-
job can be run or on which an endpoint can be deployed.
4235+
job can be run or on which an endpoint can be deployed (default: None).
42364236
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
42374237
using `model_package_name` makes the Model Package un-versioned (default: None).
42384238
model_package_group_name (str): Model Package Group name, exclusive to
@@ -4377,10 +4377,9 @@ def get_create_model_package_request(
43774377
if domain is not None:
43784378
request_dict["Domain"] = domain
43794379
if containers is not None:
4380-
if not all([content_types, response_types, inference_instances, transform_instances]):
4380+
if not all([content_types, response_types]):
43814381
raise ValueError(
4382-
"content_types, response_types, inference_inferences and transform_instances "
4383-
"must be provided if containers is present."
4382+
"content_types and response_types " "must be provided if containers is present."
43844383
)
43854384
inference_specification = {
43864385
"Containers": containers,

src/sagemaker/sklearn/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def register(
141141
self,
142142
content_types,
143143
response_types,
144-
inference_instances,
145-
transform_instances,
144+
inference_instances=None,
145+
transform_instances=None,
146146
model_package_name=None,
147147
model_package_group_name=None,
148148
image_uri=None,
@@ -161,9 +161,9 @@ def register(
161161
content_types (list): The supported MIME types for the input data.
162162
response_types (list): The supported MIME types for the output data.
163163
inference_instances (list): A list of the instance types that are used to
164-
generate inferences in real-time.
164+
generate inferences in real-time (default: None).
165165
transform_instances (list): A list of the instance types on which a transformation
166-
job can be run or on which an endpoint can be deployed.
166+
job can be run or on which an endpoint can be deployed (default: None).
167167
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
168168
using `model_package_name` makes the Model Package un-versioned (default: None).
169169
model_package_group_name (str): Model Package Group name, exclusive to
@@ -187,7 +187,7 @@ def register(
187187
Returns:
188188
A `sagemaker.model.ModelPackage` instance.
189189
"""
190-
instance_type = inference_instances[0]
190+
instance_type = inference_instances[0] if inference_instances else None
191191
self._init_sagemaker_session_if_does_not_exist(instance_type)
192192

193193
if image_uri:

src/sagemaker/tensorflow/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def register(
193193
self,
194194
content_types,
195195
response_types,
196-
inference_instances,
197-
transform_instances,
196+
inference_instances=None,
197+
transform_instances=None,
198198
model_package_name=None,
199199
model_package_group_name=None,
200200
image_uri=None,
@@ -213,9 +213,9 @@ def register(
213213
content_types (list): The supported MIME types for the input data.
214214
response_types (list): The supported MIME types for the output data.
215215
inference_instances (list): A list of the instance types that are used to
216-
generate inferences in real-time.
216+
generate inferences in real-time (default: None).
217217
transform_instances (list): A list of the instance types on which a transformation
218-
job can be run or on which an endpoint can be deployed.
218+
job can be run or on which an endpoint can be deployed (default: None).
219219
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
220220
using `model_package_name` makes the Model Package un-versioned (default: None).
221221
model_package_group_name (str): Model Package Group name, exclusive to
@@ -239,7 +239,7 @@ def register(
239239
Returns:
240240
A `sagemaker.model.ModelPackage` instance.
241241
"""
242-
instance_type = inference_instances[0]
242+
instance_type = inference_instances[0] if inference_instances else None
243243
self._init_sagemaker_session_if_does_not_exist(instance_type)
244244

245245
if image_uri:

src/sagemaker/workflow/step_collections.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def __init__(
6363
name: str,
6464
content_types,
6565
response_types,
66-
inference_instances,
67-
transform_instances,
66+
inference_instances=None,
67+
transform_instances=None,
6868
estimator: EstimatorBase = None,
6969
model_data=None,
7070
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
@@ -220,9 +220,15 @@ def __init__(
220220
kwargs.pop("output_kms_key", None)
221221

222222
if isinstance(model, PipelineModel):
223-
self.container_def_list = model.pipeline_container_def(inference_instances[0])
223+
self.container_def_list = model.pipeline_container_def(
224+
inference_instances[0] if inference_instances else None
225+
)
224226
elif isinstance(model, Model):
225-
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
227+
self.container_def_list = [
228+
model.prepare_container_def(
229+
inference_instances[0] if inference_instances else None
230+
)
231+
]
226232

227233
register_model_step = _RegisterModelStep(
228234
name=name,

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,52 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
122122
assert not register_step_args.create_model_request
123123
assert register_step_args.create_model_package_request
124124
assert len(register_step_args.need_runtime_repack) == 0
125+
126+
127+
def test_pipeline_session_context_for_model_step_without_instance_types(
128+
pipeline_session_mock,
129+
):
130+
model = Model(
131+
name="MyModel",
132+
image_uri="fakeimage",
133+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
134+
sagemaker_session=pipeline_session_mock,
135+
entry_point=f"{DATA_DIR}/dummy_script.py",
136+
source_dir=f"{DATA_DIR}",
137+
role=_ROLE,
138+
)
139+
140+
register_step_args = model.register(
141+
content_types=["text/csv"],
142+
response_types=["text/csv"],
143+
model_package_group_name="MyModelPackageGroup",
144+
)
145+
146+
expected_output = {
147+
"ModelPackageGroupName": "MyModelPackageGroup",
148+
"InferenceSpecification": {
149+
"Containers": [
150+
{
151+
"Image": "fakeimage",
152+
"Environment": {
153+
"SAGEMAKER_PROGRAM": "dummy_script.py",
154+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
155+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
156+
"SAGEMAKER_REGION": "us-west-2",
157+
},
158+
"ModelDataUrl": ParameterString(
159+
name="ModelData",
160+
default_value="s3://my-bucket/file",
161+
),
162+
}
163+
],
164+
"SupportedContentTypes": ["text/csv"],
165+
"SupportedResponseMIMETypes": ["text/csv"],
166+
"SupportedRealtimeInferenceInstanceTypes": None,
167+
"SupportedTransformInstanceTypes": None,
168+
},
169+
"CertifyForMarketplace": False,
170+
"ModelApprovalStatus": "PendingManualApproval",
171+
}
172+
173+
assert register_step_args.create_model_package_request == expected_output

0 commit comments

Comments
 (0)