Skip to content

Commit 2860730

Browse files
Keshav Chandakkeshav-chandak
Keshav Chandak
authored andcommitted
Fix: Returning ModelPackage object on register of PipelineModel
1 parent dc0860d commit 2860730

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

src/sagemaker/pipeline.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sagemaker.drift_check_baselines import DriftCheckBaselines
2828
from sagemaker.metadata_properties import MetadataProperties
29+
from sagemaker.model import ModelPackage
2930
from sagemaker.model_card import (
3031
ModelCard,
3132
ModelPackageModelCard,
@@ -470,7 +471,18 @@ def register(
470471
model_card=model_card,
471472
)
472473

473-
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
474+
model_package = self.sagemaker_session.create_model_package_from_containers(
475+
**model_pkg_args
476+
)
477+
478+
if model_package is not None and "ModelPackageArn" in model_package:
479+
return ModelPackage(
480+
role=self.role,
481+
model_package_arn=model_package.get("ModelPackageArn"),
482+
sagemaker_session=self.sagemaker_session,
483+
predictor_cls=self.predictor_cls,
484+
)
485+
return None
474486

475487
def transformer(
476488
self,

tests/integ/test_inference_pipeline.py

+34
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
150150
assert "Could not find model" in str(exception.value)
151151

152152

153+
@pytest.mark.release
154+
def test_inference_pipeline_model_register(sagemaker_session):
155+
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
156+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
157+
sparkml_model_data = sagemaker_session.upload_data(
158+
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
159+
key_prefix="integ-test-data/sparkml/model",
160+
)
161+
162+
sparkml_model = SparkMLModel(
163+
model_data=sparkml_model_data,
164+
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
165+
sagemaker_session=sagemaker_session,
166+
)
167+
168+
model = PipelineModel(
169+
models=[sparkml_model],
170+
role="SageMakerRole",
171+
sagemaker_session=sagemaker_session,
172+
name=endpoint_name,
173+
)
174+
model_package_group_name = unique_name_from_base("pipeline-model-package")
175+
model_package = model.register(model_package_group_name=model_package_group_name)
176+
assert model_package.model_package_arn is not None
177+
178+
sagemaker_session.sagemaker_client.delete_model_package(
179+
ModelPackageName=model_package.model_package_arn
180+
)
181+
182+
sagemaker_session.sagemaker_client.delete_model_package_group(
183+
ModelPackageGroupName=model_package_group_name
184+
)
185+
186+
153187
@pytest.mark.slow_test
154188
@pytest.mark.flaky(reruns=5, reruns_delay=2)
155189
def test_inference_pipeline_model_deploy_and_update_endpoint(

tests/unit/test_pipeline_model.py

+24
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,27 @@ def test_network_isolation(tfo, time, sagemaker_session):
420420
vpc_config=None,
421421
enable_network_isolation=True,
422422
)
423+
424+
425+
def test_pipeline_model_register(sagemaker_session):
426+
sagemaker_session.create_model_package_from_containers = Mock(
427+
name="create_model_package_from_containers",
428+
return_value={
429+
"ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
430+
},
431+
)
432+
framework_model = DummyFrameworkModel(sagemaker_session)
433+
sparkml_model = SparkMLModel(
434+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
435+
)
436+
model = PipelineModel(
437+
models=[framework_model, sparkml_model],
438+
role=ROLE,
439+
sagemaker_session=sagemaker_session,
440+
enable_network_isolation=True,
441+
)
442+
model_package = model.register()
443+
assert (
444+
model_package.model_package_arn
445+
== "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
446+
)

0 commit comments

Comments
 (0)