Skip to content

Commit 38b4916

Browse files
authored
fix: enable model.register without 'inference' & 'transform' instances (#3228)
* fix: enable model.register without 'inference_instances' & 'transform_instances' * remove failing integ test * fix: make instance_type optional for model registry model package and mandatory for marketplace model package * fix: black-check and flake8 errors
1 parent 1d2b364 commit 38b4916

File tree

4 files changed

+206
-17
lines changed

4 files changed

+206
-17
lines changed

src/sagemaker/session.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -4499,9 +4499,32 @@ def get_create_model_package_request(
44994499
"Containers": containers,
45004500
"SupportedContentTypes": content_types,
45014501
"SupportedResponseMIMETypes": response_types,
4502-
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4503-
"SupportedTransformInstanceTypes": transform_instances,
45044502
}
4503+
if model_package_group_name is not None:
4504+
if inference_instances is not None:
4505+
inference_specification.update(
4506+
{
4507+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4508+
}
4509+
)
4510+
if transform_instances is not None:
4511+
inference_specification.update(
4512+
{
4513+
"SupportedTransformInstanceTypes": transform_instances,
4514+
}
4515+
)
4516+
else:
4517+
if not all([inference_instances, transform_instances]):
4518+
raise ValueError(
4519+
"inference_instances and transform_instances "
4520+
"must be provided if model_package_group_name is not present."
4521+
)
4522+
inference_specification.update(
4523+
{
4524+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
4525+
"SupportedTransformInstanceTypes": transform_instances,
4526+
}
4527+
)
45054528
request_dict["InferenceSpecification"] = inference_specification
45064529
request_dict["CertifyForMarketplace"] = marketplace_cert
45074530
request_dict["ModelApprovalStatus"] = approval_status

src/sagemaker/workflow/_utils.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -341,16 +341,11 @@ def __init__(
341341
super(_RegisterModelStep, self).__init__(
342342
name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies
343343
)
344-
deprecated_args_missing = (
345-
content_types is None
346-
or response_types is None
347-
or inference_instances is None
348-
or transform_instances is None
349-
)
344+
deprecated_args_missing = content_types is None or response_types is None
350345
if not (step_args is None) ^ deprecated_args_missing:
351346
raise ValueError(
352347
"step_args and the set of (content_types, response_types, "
353-
"inference_instances, transform_instances) are mutually exclusive. "
348+
") are mutually exclusive. "
354349
"Either of them should be provided."
355350
)
356351

tests/unit/sagemaker/workflow/test_pipeline_session.py

+93-2
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
224224
],
225225
"SupportedContentTypes": ["text/csv"],
226226
"SupportedResponseMIMETypes": ["text/csv"],
227-
"SupportedRealtimeInferenceInstanceTypes": None,
228-
"SupportedTransformInstanceTypes": None,
229227
},
230228
"CertifyForMarketplace": False,
231229
"ModelApprovalStatus": "PendingManualApproval",
@@ -234,3 +232,96 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
234232
}
235233

236234
assert register_step_args.create_model_package_request == expected_output
235+
236+
237+
def test_pipeline_session_context_for_model_step_with_one_instance_types(
238+
pipeline_session_mock,
239+
):
240+
model = Model(
241+
name="MyModel",
242+
image_uri="fakeimage",
243+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
244+
sagemaker_session=pipeline_session_mock,
245+
entry_point=f"{DATA_DIR}/dummy_script.py",
246+
source_dir=f"{DATA_DIR}",
247+
role=_ROLE,
248+
)
249+
register_step_args = model.register(
250+
content_types=["text/csv"],
251+
response_types=["text/csv"],
252+
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
253+
model_package_group_name="MyModelPackageGroup",
254+
task="IMAGE_CLASSIFICATION",
255+
sample_payload_url="s3://test-bucket/model",
256+
framework="TENSORFLOW",
257+
framework_version="2.9",
258+
nearest_model_name="resnet50",
259+
data_input_configuration='{"input_1":[1,224,224,3]}',
260+
)
261+
262+
expected_output = {
263+
"ModelPackageGroupName": "MyModelPackageGroup",
264+
"InferenceSpecification": {
265+
"Containers": [
266+
{
267+
"Image": "fakeimage",
268+
"Environment": {
269+
"SAGEMAKER_PROGRAM": "dummy_script.py",
270+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
271+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
272+
"SAGEMAKER_REGION": "us-west-2",
273+
},
274+
"ModelDataUrl": ParameterString(
275+
name="ModelData",
276+
default_value="s3://my-bucket/file",
277+
),
278+
"Framework": "TENSORFLOW",
279+
"FrameworkVersion": "2.9",
280+
"NearestModelName": "resnet50",
281+
"ModelInput": {
282+
"DataInputConfig": '{"input_1":[1,224,224,3]}',
283+
},
284+
}
285+
],
286+
"SupportedContentTypes": ["text/csv"],
287+
"SupportedResponseMIMETypes": ["text/csv"],
288+
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.m5.xlarge"],
289+
},
290+
"CertifyForMarketplace": False,
291+
"ModelApprovalStatus": "PendingManualApproval",
292+
"SamplePayloadUrl": "s3://test-bucket/model",
293+
"Task": "IMAGE_CLASSIFICATION",
294+
}
295+
296+
assert register_step_args.create_model_package_request == expected_output
297+
298+
299+
def test_pipeline_session_context_for_model_step_without_model_package_group_name(
300+
pipeline_session_mock,
301+
):
302+
model = Model(
303+
name="MyModel",
304+
image_uri="fakeimage",
305+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
306+
sagemaker_session=pipeline_session_mock,
307+
entry_point=f"{DATA_DIR}/dummy_script.py",
308+
source_dir=f"{DATA_DIR}",
309+
role=_ROLE,
310+
)
311+
with pytest.raises(ValueError) as error:
312+
model.register(
313+
content_types=["text/csv"],
314+
response_types=["text/csv"],
315+
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
316+
model_package_name="MyModelPackageName",
317+
task="IMAGE_CLASSIFICATION",
318+
sample_payload_url="s3://test-bucket/model",
319+
framework="TENSORFLOW",
320+
framework_version="2.9",
321+
nearest_model_name="resnet50",
322+
data_input_configuration='{"input_1":[1,224,224,3]}',
323+
)
324+
assert (
325+
"inference_inferences and transform_instances "
326+
"must be provided if model_package_group_name is not present." == str(error)
327+
)

tests/unit/test_session.py

+86-6
Original file line numberDiff line numberDiff line change
@@ -2355,11 +2355,29 @@ def test_create_model_package_from_containers_incomplete_args(sagemaker_session)
23552355
containers=containers,
23562356
)
23572357
assert (
2358-
"content_types, response_types, inference_inferences and transform_instances "
2358+
"content_types and response_types "
23592359
"must be provided if containers is present." == str(error)
23602360
)
23612361

23622362

2363+
def test_create_model_package_from_containers_without_model_package_group_name(sagemaker_session):
2364+
model_package_name = "sagemaker-model-package"
2365+
containers = ["dummy-container"]
2366+
content_types = ["application/json"]
2367+
response_types = ["application/json"]
2368+
with pytest.raises(ValueError) as error:
2369+
sagemaker_session.create_model_package_from_containers(
2370+
model_package_name=model_package_name,
2371+
containers=containers,
2372+
content_types=content_types,
2373+
response_types=response_types,
2374+
)
2375+
assert (
2376+
"inference_inferences and transform_instances "
2377+
"must be provided if model_package_group_name is not present." == str(error)
2378+
)
2379+
2380+
23632381
def test_create_model_package_from_containers_all_args(sagemaker_session):
23642382
model_package_name = "sagemaker-model-package"
23652383
containers = ["dummy-container"]
@@ -2437,7 +2455,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24372455

24382456

24392457
def test_create_model_package_from_containers_without_instance_types(sagemaker_session):
2440-
model_package_name = "sagemaker-model-package"
2458+
model_package_group_name = "sagemaker-model-package-group-name-1.0"
24412459
containers = ["dummy-container"]
24422460
content_types = ["application/json"]
24432461
response_types = ["application/json"]
@@ -2470,7 +2488,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
24702488
containers=containers,
24712489
content_types=content_types,
24722490
response_types=response_types,
2473-
model_package_name=model_package_name,
2491+
model_package_group_name=model_package_group_name,
24742492
model_metrics=model_metrics,
24752493
metadata_properties=metadata_properties,
24762494
marketplace_cert=marketplace_cert,
@@ -2480,13 +2498,75 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
24802498
customer_metadata_properties=customer_metadata_properties,
24812499
)
24822500
expected_args = {
2483-
"ModelPackageName": model_package_name,
2501+
"ModelPackageGroupName": model_package_group_name,
24842502
"InferenceSpecification": {
24852503
"Containers": containers,
24862504
"SupportedContentTypes": content_types,
24872505
"SupportedResponseMIMETypes": response_types,
2488-
"SupportedRealtimeInferenceInstanceTypes": None,
2489-
"SupportedTransformInstanceTypes": None,
2506+
},
2507+
"ModelPackageDescription": description,
2508+
"ModelMetrics": model_metrics,
2509+
"MetadataProperties": metadata_properties,
2510+
"CertifyForMarketplace": marketplace_cert,
2511+
"ModelApprovalStatus": approval_status,
2512+
"DriftCheckBaselines": drift_check_baselines,
2513+
"CustomerMetadataProperties": customer_metadata_properties,
2514+
}
2515+
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)
2516+
2517+
2518+
def test_create_model_package_from_containers_with_one_instance_types(sagemaker_session):
2519+
model_package_group_name = "sagemaker-model-package-group-name-1.0"
2520+
containers = ["dummy-container"]
2521+
content_types = ["application/json"]
2522+
response_types = ["application/json"]
2523+
transform_instances = ["ml.m5.xlarge"]
2524+
model_metrics = {
2525+
"Bias": {
2526+
"ContentType": "content-type",
2527+
"S3Uri": "s3://...",
2528+
}
2529+
}
2530+
drift_check_baselines = {
2531+
"Bias": {
2532+
"ConfigFile": {
2533+
"ContentType": "content-type",
2534+
"S3Uri": "s3://...",
2535+
}
2536+
}
2537+
}
2538+
2539+
metadata_properties = {
2540+
"CommitId": "test-commit-id",
2541+
"Repository": "test-repository",
2542+
"GeneratedBy": "sagemaker-python-sdk",
2543+
"ProjectId": "unit-test",
2544+
}
2545+
marketplace_cert = (True,)
2546+
approval_status = ("Approved",)
2547+
description = "description"
2548+
customer_metadata_properties = {"key1": "value1"}
2549+
sagemaker_session.create_model_package_from_containers(
2550+
containers=containers,
2551+
content_types=content_types,
2552+
response_types=response_types,
2553+
transform_instances=transform_instances,
2554+
model_package_group_name=model_package_group_name,
2555+
model_metrics=model_metrics,
2556+
metadata_properties=metadata_properties,
2557+
marketplace_cert=marketplace_cert,
2558+
approval_status=approval_status,
2559+
description=description,
2560+
drift_check_baselines=drift_check_baselines,
2561+
customer_metadata_properties=customer_metadata_properties,
2562+
)
2563+
expected_args = {
2564+
"ModelPackageGroupName": model_package_group_name,
2565+
"InferenceSpecification": {
2566+
"Containers": containers,
2567+
"SupportedContentTypes": content_types,
2568+
"SupportedResponseMIMETypes": response_types,
2569+
"SupportedTransformInstanceTypes": transform_instances,
24902570
},
24912571
"ModelPackageDescription": description,
24922572
"ModelMetrics": model_metrics,

0 commit comments

Comments
 (0)