Skip to content

Commit bcb3425

Browse files
committed
fix: make instance_type optional for model registry model package and mandatory for marketplace model package
1 parent ff7e2bf commit bcb3425

File tree

3 files changed

+76
-19
lines changed

3 files changed

+76
-19
lines changed

src/sagemaker/session.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -4500,18 +4500,28 @@ def get_create_model_package_request(
45004500
"SupportedContentTypes": content_types,
45014501
"SupportedResponseMIMETypes": response_types,
45024502
}
4503-
if inference_instances is not None:
4504-
inference_specification.update(
4505-
{
4506-
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4507-
}
4508-
)
4509-
if transform_instances is not None:
4510-
inference_specification.update(
4511-
{
4512-
"SupportedTransformInstanceTypes": transform_instances,
4513-
}
4514-
)
4503+
if model_package_group_name is not None:
4504+
if inference_instances is not None:
4505+
inference_specification.update(
4506+
{
4507+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4508+
}
4509+
)
4510+
if transform_instances is not None:
4511+
inference_specification.update(
4512+
{
4513+
"SupportedTransformInstanceTypes": transform_instances,
4514+
}
4515+
)
4516+
else:
4517+
if not all([inference_instances, transform_instances]):
4518+
raise ValueError(
4519+
"inference_instances and transform_instances " "must be provided if model_package_group_name is not present."
4520+
)
4521+
inference_specification.update({
4522+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4523+
"SupportedTransformInstanceTypes": transform_instances,
4524+
})
45154525
request_dict["InferenceSpecification"] = inference_specification
45164526
request_dict["CertifyForMarketplace"] = marketplace_cert
45174527
request_dict["ModelApprovalStatus"] = approval_status

tests/unit/sagemaker/workflow/test_pipeline_session.py

+30
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,33 @@ def test_pipeline_session_context_for_model_step_with_one_instance_types(
298298
}
299299

300300
assert register_step_args.create_model_package_request == expected_output
301+
302+
def test_pipeline_session_context_for_model_step_without_model_package_group_name(
303+
pipeline_session_mock,
304+
):
305+
model = Model(
306+
name="MyModel",
307+
image_uri="fakeimage",
308+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
309+
sagemaker_session=pipeline_session_mock,
310+
entry_point=f"{DATA_DIR}/dummy_script.py",
311+
source_dir=f"{DATA_DIR}",
312+
role=_ROLE,
313+
)
314+
with pytest.raises(ValueError) as error:
315+
model.register(
316+
content_types=["text/csv"],
317+
response_types=["text/csv"],
318+
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
319+
model_package_name="MyModelPackageName",
320+
task="IMAGE_CLASSIFICATION",
321+
sample_payload_url="s3://test-bucket/model",
322+
framework="TENSORFLOW",
323+
framework_version="2.9",
324+
nearest_model_name="resnet50",
325+
data_input_configuration='{"input_1":[1,224,224,3]}',
326+
)
327+
assert (
328+
"inference_inferences and transform_instances "
329+
"must be provided if model_package_group_name is not present." == str(error)
330+
)

tests/unit/test_session.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -2355,10 +2355,27 @@ def test_create_model_package_from_containers_incomplete_args(sagemaker_session)
23552355
containers=containers,
23562356
)
23572357
assert (
2358-
"content_types, response_types, inference_inferences and transform_instances "
2358+
"content_types and response_types "
23592359
"must be provided if containers is present." == str(error)
23602360
)
23612361

2362+
def test_create_model_package_from_containers_without_model_package_group_name(sagemaker_session):
2363+
model_package_name = "sagemaker-model-package"
2364+
containers = ["dummy-container"]
2365+
content_types = ["application/json"]
2366+
response_types = ["application/json"]
2367+
with pytest.raises(ValueError) as error:
2368+
sagemaker_session.create_model_package_from_containers(
2369+
model_package_name=model_package_name,
2370+
containers=containers,
2371+
content_types=content_types,
2372+
response_types=response_types,
2373+
)
2374+
assert (
2375+
"inference_inferences and transform_instances "
2376+
"must be provided if model_package_group_name is not present." == str(error)
2377+
)
2378+
23622379

23632380
def test_create_model_package_from_containers_all_args(sagemaker_session):
23642381
model_package_name = "sagemaker-model-package"
@@ -2437,7 +2454,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24372454

24382455

24392456
def test_create_model_package_from_containers_without_instance_types(sagemaker_session):
2440-
model_package_name = "sagemaker-model-package"
2457+
model_package_group_name = "sagemaker-model-package-group-name-1.0"
24412458
containers = ["dummy-container"]
24422459
content_types = ["application/json"]
24432460
response_types = ["application/json"]
@@ -2470,7 +2487,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
24702487
containers=containers,
24712488
content_types=content_types,
24722489
response_types=response_types,
2473-
model_package_name=model_package_name,
2490+
model_package_group_name=model_package_group_name,
24742491
model_metrics=model_metrics,
24752492
metadata_properties=metadata_properties,
24762493
marketplace_cert=marketplace_cert,
@@ -2480,7 +2497,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
24802497
customer_metadata_properties=customer_metadata_properties,
24812498
)
24822499
expected_args = {
2483-
"ModelPackageName": model_package_name,
2500+
"ModelPackageGroupName": model_package_group_name,
24842501
"InferenceSpecification": {
24852502
"Containers": containers,
24862503
"SupportedContentTypes": content_types,
@@ -2498,7 +2515,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
24982515

24992516

25002517
def test_create_model_package_from_containers_with_one_instance_types(sagemaker_session):
2501-
model_package_name = "sagemaker-model-package"
2518+
model_package_group_name = "sagemaker-model-package-group-name-1.0"
25022519
containers = ["dummy-container"]
25032520
content_types = ["application/json"]
25042521
response_types = ["application/json"]
@@ -2533,7 +2550,7 @@ def test_create_model_package_from_containers_with_one_instance_types(sagemaker_
25332550
content_types=content_types,
25342551
response_types=response_types,
25352552
transform_instances=transform_instances,
2536-
model_package_name=model_package_name,
2553+
model_package_group_name=model_package_group_name,
25372554
model_metrics=model_metrics,
25382555
metadata_properties=metadata_properties,
25392556
marketplace_cert=marketplace_cert,
@@ -2543,7 +2560,7 @@ def test_create_model_package_from_containers_with_one_instance_types(sagemaker_
25432560
customer_metadata_properties=customer_metadata_properties,
25442561
)
25452562
expected_args = {
2546-
"ModelPackageName": model_package_name,
2563+
"ModelPackageGroupName": model_package_group_name,
25472564
"InferenceSpecification": {
25482565
"Containers": containers,
25492566
"SupportedContentTypes": content_types,

0 commit comments

Comments
 (0)