Skip to content

Commit a87f9bd

Browse files
committed
fix: enable model.register without 'inference_instances' & 'transform_instances'
1 parent a5464a2 commit a87f9bd

File tree

5 files changed

+387
-13
lines changed

5 files changed

+387
-13
lines changed

src/sagemaker/session.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -4499,9 +4499,19 @@ def get_create_model_package_request(
44994499
"Containers": containers,
45004500
"SupportedContentTypes": content_types,
45014501
"SupportedResponseMIMETypes": response_types,
4502-
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4503-
"SupportedTransformInstanceTypes": transform_instances,
45044502
}
4503+
if inference_instances is not None:
4504+
inference_specification.update(
4505+
{
4506+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4507+
}
4508+
)
4509+
if transform_instances is not None:
4510+
inference_specification.update(
4511+
{
4512+
"SupportedTransformInstanceTypes": transform_instances,
4513+
}
4514+
)
45054515
request_dict["InferenceSpecification"] = inference_specification
45064516
request_dict["CertifyForMarketplace"] = marketplace_cert
45074517
request_dict["ModelApprovalStatus"] = approval_status

src/sagemaker/workflow/_utils.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -341,16 +341,11 @@ def __init__(
341341
super(_RegisterModelStep, self).__init__(
342342
name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies
343343
)
344-
deprecated_args_missing = (
345-
content_types is None
346-
or response_types is None
347-
or inference_instances is None
348-
or transform_instances is None
349-
)
344+
deprecated_args_missing = content_types is None or response_types is None
350345
if not (step_args is None) ^ deprecated_args_missing:
351346
raise ValueError(
352347
"step_args and the set of (content_types, response_types, "
353-
"inference_instances, transform_instances) are mutually exclusive. "
348+
") are mutually exclusive. "
354349
"Either of them should be provided."
355350
)
356351

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

+247
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,253 @@ def test_conditional_pytorch_training_model_registration(
199199
pass
200200

201201

202+
def test_conditional_pytorch_training_model_registration_without_instance_types(
203+
sagemaker_session,
204+
role,
205+
cpu_instance_type,
206+
pipeline_name,
207+
region_name,
208+
):
209+
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
210+
entry_point = os.path.join(base_dir, "mnist.py")
211+
input_path = sagemaker_session.upload_data(
212+
path=os.path.join(base_dir, "training"),
213+
key_prefix="integ-test-data/pytorch_mnist/training",
214+
)
215+
inputs = TrainingInput(s3_data=input_path)
216+
217+
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
218+
instance_type = "ml.m5.xlarge"
219+
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
220+
in_condition_input = ParameterString(name="Foo", default_value="Foo")
221+
222+
task = "IMAGE_CLASSIFICATION"
223+
sample_payload_url = "s3://test-bucket/model"
224+
framework = "TENSORFLOW"
225+
framework_version = "2.9"
226+
nearest_model_name = "resnet50"
227+
data_input_configuration = '{"input_1":[1,224,224,3]}'
228+
229+
# If image_uri is not provided, the instance_type should not be a pipeline variable
230+
# since instance_type is used to retrieve image_uri in compile time (PySDK)
231+
pytorch_estimator = PyTorch(
232+
entry_point=entry_point,
233+
role=role,
234+
framework_version="1.5.0",
235+
py_version="py3",
236+
instance_count=instance_count,
237+
instance_type=instance_type,
238+
sagemaker_session=sagemaker_session,
239+
)
240+
step_train = TrainingStep(
241+
name="pytorch-train",
242+
estimator=pytorch_estimator,
243+
inputs=inputs,
244+
)
245+
246+
step_register = RegisterModel(
247+
name="pytorch-register-model",
248+
estimator=pytorch_estimator,
249+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
250+
content_types=["*"],
251+
response_types=["*"],
252+
description="test-description",
253+
sample_payload_url=sample_payload_url,
254+
task=task,
255+
framework=framework,
256+
framework_version=framework_version,
257+
nearest_model_name=nearest_model_name,
258+
data_input_configuration=data_input_configuration,
259+
)
260+
261+
model = Model(
262+
image_uri=pytorch_estimator.training_image_uri(),
263+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
264+
sagemaker_session=sagemaker_session,
265+
role=role,
266+
)
267+
model_inputs = CreateModelInput(
268+
instance_type="ml.m5.large",
269+
accelerator_type="ml.eia1.medium",
270+
)
271+
step_model = CreateModelStep(
272+
name="pytorch-model",
273+
model=model,
274+
inputs=model_inputs,
275+
)
276+
277+
step_cond = ConditionStep(
278+
name="cond-good-enough",
279+
conditions=[
280+
ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
281+
ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
282+
],
283+
if_steps=[step_register],
284+
else_steps=[step_model],
285+
depends_on=[step_train],
286+
)
287+
288+
pipeline = Pipeline(
289+
name=pipeline_name,
290+
parameters=[
291+
in_condition_input,
292+
good_enough_input,
293+
instance_count,
294+
],
295+
steps=[step_train, step_cond],
296+
sagemaker_session=sagemaker_session,
297+
)
298+
299+
try:
300+
response = pipeline.create(role)
301+
create_arn = response["PipelineArn"]
302+
assert re.match(
303+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
304+
create_arn,
305+
)
306+
307+
execution = pipeline.start(parameters={})
308+
assert re.match(
309+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
310+
execution.arn,
311+
)
312+
313+
execution = pipeline.start(parameters={"GoodEnoughInput": 0})
314+
assert re.match(
315+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
316+
execution.arn,
317+
)
318+
finally:
319+
try:
320+
pipeline.delete()
321+
except Exception:
322+
pass
323+
324+
325+
def test_conditional_pytorch_training_model_registration_with_one_instance_types(
326+
sagemaker_session,
327+
role,
328+
cpu_instance_type,
329+
pipeline_name,
330+
region_name,
331+
):
332+
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
333+
entry_point = os.path.join(base_dir, "mnist.py")
334+
input_path = sagemaker_session.upload_data(
335+
path=os.path.join(base_dir, "training"),
336+
key_prefix="integ-test-data/pytorch_mnist/training",
337+
)
338+
inputs = TrainingInput(s3_data=input_path)
339+
340+
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
341+
instance_type = "ml.m5.xlarge"
342+
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
343+
in_condition_input = ParameterString(name="Foo", default_value="Foo")
344+
345+
task = "IMAGE_CLASSIFICATION"
346+
sample_payload_url = "s3://test-bucket/model"
347+
framework = "TENSORFLOW"
348+
framework_version = "2.9"
349+
nearest_model_name = "resnet50"
350+
data_input_configuration = '{"input_1":[1,224,224,3]}'
351+
352+
# If image_uri is not provided, the instance_type should not be a pipeline variable
353+
# since instance_type is used to retrieve image_uri in compile time (PySDK)
354+
pytorch_estimator = PyTorch(
355+
entry_point=entry_point,
356+
role=role,
357+
framework_version="1.5.0",
358+
py_version="py3",
359+
instance_count=instance_count,
360+
instance_type=instance_type,
361+
sagemaker_session=sagemaker_session,
362+
)
363+
step_train = TrainingStep(
364+
name="pytorch-train",
365+
estimator=pytorch_estimator,
366+
inputs=inputs,
367+
)
368+
369+
step_register = RegisterModel(
370+
name="pytorch-register-model",
371+
estimator=pytorch_estimator,
372+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
373+
content_types=["*"],
374+
response_types=["*"],
375+
inference_instances=["*"],
376+
description="test-description",
377+
sample_payload_url=sample_payload_url,
378+
task=task,
379+
framework=framework,
380+
framework_version=framework_version,
381+
nearest_model_name=nearest_model_name,
382+
data_input_configuration=data_input_configuration,
383+
)
384+
385+
model = Model(
386+
image_uri=pytorch_estimator.training_image_uri(),
387+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
388+
sagemaker_session=sagemaker_session,
389+
role=role,
390+
)
391+
model_inputs = CreateModelInput(
392+
instance_type="ml.m5.large",
393+
accelerator_type="ml.eia1.medium",
394+
)
395+
step_model = CreateModelStep(
396+
name="pytorch-model",
397+
model=model,
398+
inputs=model_inputs,
399+
)
400+
401+
step_cond = ConditionStep(
402+
name="cond-good-enough",
403+
conditions=[
404+
ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
405+
ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
406+
],
407+
if_steps=[step_register],
408+
else_steps=[step_model],
409+
depends_on=[step_train],
410+
)
411+
412+
pipeline = Pipeline(
413+
name=pipeline_name,
414+
parameters=[
415+
in_condition_input,
416+
good_enough_input,
417+
instance_count,
418+
],
419+
steps=[step_train, step_cond],
420+
sagemaker_session=sagemaker_session,
421+
)
422+
423+
try:
424+
response = pipeline.create(role)
425+
create_arn = response["PipelineArn"]
426+
assert re.match(
427+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
428+
create_arn,
429+
)
430+
431+
execution = pipeline.start(parameters={})
432+
assert re.match(
433+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
434+
execution.arn,
435+
)
436+
437+
execution = pipeline.start(parameters={"GoodEnoughInput": 0})
438+
assert re.match(
439+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
440+
execution.arn,
441+
)
442+
finally:
443+
try:
444+
pipeline.delete()
445+
except Exception:
446+
pass
447+
448+
202449
def test_mxnet_model_registration(
203450
sagemaker_session,
204451
role,

tests/unit/sagemaker/workflow/test_pipeline_session.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,68 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
228228
],
229229
"SupportedContentTypes": ["text/csv"],
230230
"SupportedResponseMIMETypes": ["text/csv"],
231-
"SupportedRealtimeInferenceInstanceTypes": None,
232-
"SupportedTransformInstanceTypes": None,
231+
},
232+
"CertifyForMarketplace": False,
233+
"ModelApprovalStatus": "PendingManualApproval",
234+
"SamplePayloadUrl": "s3://test-bucket/model",
235+
"Task": "IMAGE_CLASSIFICATION",
236+
}
237+
238+
assert register_step_args.create_model_package_request == expected_output
239+
240+
241+
def test_pipeline_session_context_for_model_step_with_one_instance_types(
242+
pipeline_session_mock,
243+
):
244+
model = Model(
245+
name="MyModel",
246+
image_uri="fakeimage",
247+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
248+
sagemaker_session=pipeline_session_mock,
249+
entry_point=f"{DATA_DIR}/dummy_script.py",
250+
source_dir=f"{DATA_DIR}",
251+
role=_ROLE,
252+
)
253+
register_step_args = model.register(
254+
content_types=["text/csv"],
255+
response_types=["text/csv"],
256+
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
257+
model_package_group_name="MyModelPackageGroup",
258+
task="IMAGE_CLASSIFICATION",
259+
sample_payload_url="s3://test-bucket/model",
260+
framework="TENSORFLOW",
261+
framework_version="2.9",
262+
nearest_model_name="resnet50",
263+
data_input_configuration='{"input_1":[1,224,224,3]}',
264+
)
265+
266+
expected_output = {
267+
"ModelPackageGroupName": "MyModelPackageGroup",
268+
"InferenceSpecification": {
269+
"Containers": [
270+
{
271+
"Image": "fakeimage",
272+
"Environment": {
273+
"SAGEMAKER_PROGRAM": "dummy_script.py",
274+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
275+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
276+
"SAGEMAKER_REGION": "us-west-2",
277+
},
278+
"ModelDataUrl": ParameterString(
279+
name="ModelData",
280+
default_value="s3://my-bucket/file",
281+
),
282+
"Framework": "TENSORFLOW",
283+
"FrameworkVersion": "2.9",
284+
"NearestModelName": "resnet50",
285+
"ModelInput": {
286+
"DataInputConfig": '{"input_1":[1,224,224,3]}',
287+
},
288+
}
289+
],
290+
"SupportedContentTypes": ["text/csv"],
291+
"SupportedResponseMIMETypes": ["text/csv"],
292+
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.m5.xlarge"],
233293
},
234294
"CertifyForMarketplace": False,
235295
"ModelApprovalStatus": "PendingManualApproval",

0 commit comments

Comments (0)