Skip to content

Commit 5b54ffe

Browse files
author
Keshav Chandak
committed
Fix: Returning ModelPackage object on register of PipelineModel
1 parent 66d5fdf commit 5b54ffe

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/sagemaker/pipeline.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sagemaker.drift_check_baselines import DriftCheckBaselines
2828
from sagemaker.metadata_properties import MetadataProperties
29+
from sagemaker.model import ModelPackage
2930
from sagemaker.model_card import (
3031
ModelCard,
3132
ModelPackageModelCard,
@@ -470,7 +471,16 @@ def register(
470471
model_card=model_card,
471472
)
472473

473-
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
474+
model_package = self.sagemaker_session.create_model_package_from_containers(
475+
**model_pkg_args
476+
)
477+
478+
return ModelPackage(
479+
role=self.role,
480+
model_package_arn=model_package.get("ModelPackageArn"),
481+
sagemaker_session=self.sagemaker_session,
482+
predictor_cls=self.predictor_cls,
483+
)
474484

475485
def transformer(
476486
self,

tests/integ/test_inference_pipeline.py

+34
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
150150
assert "Could not find model" in str(exception.value)
151151

152152

153+
@pytest.mark.release
154+
def test_inference_pipeline_model_register(sagemaker_session):
155+
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
156+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
157+
sparkml_model_data = sagemaker_session.upload_data(
158+
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
159+
key_prefix="integ-test-data/sparkml/model",
160+
)
161+
162+
sparkml_model = SparkMLModel(
163+
model_data=sparkml_model_data,
164+
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
165+
sagemaker_session=sagemaker_session,
166+
)
167+
168+
model = PipelineModel(
169+
models=[sparkml_model],
170+
role="SageMakerRole",
171+
sagemaker_session=sagemaker_session,
172+
name=endpoint_name,
173+
)
174+
model_package_group_name = unique_name_from_base("pipeline-model-package")
175+
model_package = model.register(model_package_group_name=model_package_group_name)
176+
assert model_package.model_package_arn is not None
177+
178+
sagemaker_session.sagemaker_client.delete_model_package(
179+
ModelPackageName=model_package.model_package_arn
180+
)
181+
182+
sagemaker_session.sagemaker_client.delete_model_package_group(
183+
ModelPackageGroupName=model_package_group_name
184+
)
185+
186+
153187
@pytest.mark.slow_test
154188
@pytest.mark.flaky(reruns=5, reruns_delay=2)
155189
def test_inference_pipeline_model_deploy_and_update_endpoint(

0 commit comments

Comments
 (0)