diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index b5a3cd4357..04fbc1cc93 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -26,6 +26,7 @@ ) from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model import ModelPackage from sagemaker.model_card import ( ModelCard, ModelPackageModelCard, @@ -470,7 +471,18 @@ def register( model_card=model_card, ) - self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) + model_package = self.sagemaker_session.create_model_package_from_containers( + **model_pkg_args + ) + + if model_package is not None and "ModelPackageArn" in model_package: + return ModelPackage( + role=self.role, + model_package_arn=model_package.get("ModelPackageArn"), + sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, + ) + return None def transformer( self, diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 9e6b41d753..6504932a7e 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type): assert "Could not find model" in str(exception.value) +@pytest.mark.release +def test_inference_pipeline_model_register(sagemaker_session): + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") + endpoint_name = unique_name_from_base("test-inference-pipeline-deploy") + sparkml_model_data = sagemaker_session.upload_data( + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), + key_prefix="integ-test-data/sparkml/model", + ) + + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + + model = PipelineModel( + models=[sparkml_model], + role="SageMakerRole", + sagemaker_session=sagemaker_session, + name=endpoint_name, + ) + model_package_group_name = unique_name_from_base("pipeline-model-package") + model_package = model.register(model_package_group_name=model_package_group_name) + assert model_package.model_package_arn is not None + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + + @pytest.mark.slow_test @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_inference_pipeline_model_deploy_and_update_endpoint( diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index b546d4e9e8..07d419779f 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -420,3 +420,27 @@ def test_network_isolation(tfo, time, sagemaker_session): vpc_config=None, enable_network_isolation=True, ) + + +def test_pipeline_model_register(sagemaker_session): + sagemaker_session.create_model_package_from_containers = Mock( + name="create_model_package_from_containers", + return_value={ + "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + }, + ) + framework_model = DummyFrameworkModel(sagemaker_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[framework_model, sparkml_model], + role=ROLE, + sagemaker_session=sagemaker_session, + enable_network_isolation=True, + ) + model_package = model.register() + assert ( + model_package.model_package_arn + == "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + )