diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f426724b6c..145bf41cbe 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4499,9 +4499,32 @@ def get_create_model_package_request( "Containers": containers, "SupportedContentTypes": content_types, "SupportedResponseMIMETypes": response_types, - "SupportedRealtimeInferenceInstanceTypes": inference_instances, - "SupportedTransformInstanceTypes": transform_instances, } + if model_package_group_name is not None: + if inference_instances is not None: + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + } + ) + if transform_instances is not None: + inference_specification.update( + { + "SupportedTransformInstanceTypes": transform_instances, + } + ) + else: + if not all([inference_instances, transform_instances]): + raise ValueError( + "inference_instances and transform_instances " + "must be provided if model_package_group_name is not present." + ) + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + } + ) request_dict["InferenceSpecification"] = inference_specification request_dict["CertifyForMarketplace"] = marketplace_cert request_dict["ModelApprovalStatus"] = approval_status diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 7a0a399299..f8a99996a5 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -341,16 +341,11 @@ def __init__( super(_RegisterModelStep, self).__init__( name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies ) - deprecated_args_missing = ( - content_types is None - or response_types is None - or inference_instances is None - or transform_instances is None - ) + deprecated_args_missing = content_types is None or response_types is None if not (step_args is None) ^ deprecated_args_missing: raise 
ValueError( - "step_args and the set of (content_types, response_types, " - "inference_instances, transform_instances) are mutually exclusive. " + "step_args and the set of (content_types, response_types) " + "are mutually exclusive. " "Either of them should be provided." ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index 13af00cf6a..eca3892390 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -224,8 +224,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( ], "SupportedContentTypes": ["text/csv"], "SupportedResponseMIMETypes": ["text/csv"], - "SupportedRealtimeInferenceInstanceTypes": None, - "SupportedTransformInstanceTypes": None, }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", @@ -234,3 +232,96 @@ def test_pipeline_session_context_for_model_step_without_instance_types( } assert register_step_args.create_model_package_request == expected_output + + +def test_pipeline_session_context_for_model_step_with_one_instance_types( + pipeline_session_mock, +): + model = Model( + name="MyModel", + image_uri="fakeimage", + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session_mock, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=_ROLE, + ) + register_step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', + ) + + expected_output = { + "ModelPackageGroupName": "MyModelPackageGroup", + "InferenceSpecification": { + "Containers": [ + { + "Image": "fakeimage",
"Environment": { + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + }, + "ModelDataUrl": ParameterString( + name="ModelData", + default_value="s3://my-bucket/file", + ), + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, + } + ], + "SupportedContentTypes": ["text/csv"], + "SupportedResponseMIMETypes": ["text/csv"], + "SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.m5.xlarge"], + }, + "CertifyForMarketplace": False, + "ModelApprovalStatus": "PendingManualApproval", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", + } + + assert register_step_args.create_model_package_request == expected_output + + +def test_pipeline_session_context_for_model_step_without_model_package_group_name( + pipeline_session_mock, +): + model = Model( + name="MyModel", + image_uri="fakeimage", + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session_mock, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=_ROLE, + ) + with pytest.raises(ValueError) as error: + model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_name="MyModelPackageName", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', + ) + assert ( + "inference_instances and transform_instances " + "must be provided if model_package_group_name is not present."
== str(error) + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1fd58ea531..78df274b71 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2355,11 +2355,29 @@ def test_create_model_package_from_containers_incomplete_args(sagemaker_session) containers=containers, ) assert ( - "content_types, response_types, inference_inferences and transform_instances " + "content_types and response_types " "must be provided if containers is present." == str(error) ) +def test_create_model_package_from_containers_without_model_package_group_name(sagemaker_session): + model_package_name = "sagemaker-model-package" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + with pytest.raises(ValueError) as error: + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + containers=containers, + content_types=content_types, + response_types=response_types, + ) + assert ( + "inference_instances and transform_instances " + "must be provided if model_package_group_name is not present."
== str(error) + ) + + def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] @@ -2437,7 +2455,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): def test_create_model_package_from_containers_without_instance_types(sagemaker_session): - model_package_name = "sagemaker-model-package" + model_package_group_name = "sagemaker-model-package-group-name-1.0" containers = ["dummy-container"] content_types = ["application/json"] response_types = ["application/json"] @@ -2470,7 +2488,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s containers=containers, content_types=content_types, response_types=response_types, - model_package_name=model_package_name, + model_package_group_name=model_package_group_name, model_metrics=model_metrics, metadata_properties=metadata_properties, marketplace_cert=marketplace_cert, @@ -2480,13 +2498,75 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s customer_metadata_properties=customer_metadata_properties, ) expected_args = { - "ModelPackageName": model_package_name, + "ModelPackageGroupName": model_package_group_name, "InferenceSpecification": { "Containers": containers, "SupportedContentTypes": content_types, "SupportedResponseMIMETypes": response_types, - "SupportedRealtimeInferenceInstanceTypes": None, - "SupportedTransformInstanceTypes": None, + }, + "ModelPackageDescription": description, + "ModelMetrics": model_metrics, + "MetadataProperties": metadata_properties, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) + + +def 
test_create_model_package_from_containers_with_one_instance_types(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group-name-1.0" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + transform_instances = ["ml.m5.xlarge"] + model_metrics = { + "Bias": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + drift_check_baselines = { + "Bias": { + "ConfigFile": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + } + + metadata_properties = { + "CommitId": "test-commit-id", + "Repository": "test-repository", + "GeneratedBy": "sagemaker-python-sdk", + "ProjectId": "unit-test", + } + marketplace_cert = (True,) + approval_status = ("Approved",) + description = "description" + customer_metadata_properties = {"key1": "value1"} + sagemaker_session.create_model_package_from_containers( + containers=containers, + content_types=content_types, + response_types=response_types, + transform_instances=transform_instances, + model_package_group_name=model_package_group_name, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + ) + expected_args = { + "ModelPackageGroupName": model_package_group_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedTransformInstanceTypes": transform_instances, }, "ModelPackageDescription": description, "ModelMetrics": model_metrics,