Skip to content

Commit b84301d

Browse files
author
Keshav Chandak
committed
feat: Model Package support for updating approval
1 parent d960e49 commit b84301d

File tree

15 files changed

+208
-34
lines changed

15 files changed

+208
-34
lines changed

src/sagemaker/chainer/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
def register(
150150
self,
151-
content_types: List[Union[str, PipelineVariable]],
152-
response_types: List[Union[str, PipelineVariable]],
151+
content_types: List[Union[str, PipelineVariable]] = None,
152+
response_types: List[Union[str, PipelineVariable]] = None,
153153
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154154
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155155
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def deploy(
16651665

16661666
def register(
16671667
self,
1668-
content_types,
1669-
response_types,
1668+
content_types=None,
1669+
response_types=None,
16701670
inference_instances=None,
16711671
transform_instances=None,
16721672
image_uri=None,

src/sagemaker/huggingface/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def deploy(
332332

333333
def register(
334334
self,
335-
content_types: List[Union[str, PipelineVariable]],
336-
response_types: List[Union[str, PipelineVariable]],
335+
content_types: List[Union[str, PipelineVariable]] = None,
336+
response_types: List[Union[str, PipelineVariable]] = None,
337337
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
338338
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
339339
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/model.py

+64-7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4444
load_sagemaker_config,
4545
)
46+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4647
from sagemaker.session import Session
4748
from sagemaker.model_metrics import ModelMetrics
4849
from sagemaker.deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375
self.dependencies = updates["dependencies"]
375376
self.uploaded_code = None
376377
self.repacked_model_data = None
378+
self.content_types = None
379+
self.response_types = None
377380

378381
@runnable_by_pipeline
379382
def register(
380383
self,
381-
content_types: List[Union[str, PipelineVariable]],
382-
response_types: List[Union[str, PipelineVariable]],
384+
content_types: List[Union[str, PipelineVariable]] = None,
385+
response_types: List[Union[str, PipelineVariable]] = None,
383386
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
384387
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
385388
model_package_name: Optional[Union[str, PipelineVariable]] = None,
@@ -456,16 +459,33 @@ def register(
456459
in case the Model instance is built with
457460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461
"""
459-
if self.model_data is None:
460-
raise ValueError("SageMaker Model Package cannot be created without model data.")
461462
if isinstance(self.model_data, dict):
462463
raise ValueError(
463464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464465
)
465466

467+
if content_types is not None:
468+
self.content_types = content_types
469+
470+
if response_types is not None:
471+
self.response_types = response_types
472+
473+
if self.content_types is None:
474+
raise ValueError("The supported MIME types for the input data is not set")
475+
476+
if self.response_types is None:
477+
raise ValueError("The supported MIME types for the output data is not set")
478+
466479
if image_uri is not None:
467480
self.image_uri = image_uri
468481

482+
if model_package_group_name is None and model_package_name is None:
483+
# If model package group and model package name is not set
484+
# then register to auto-generated model package group
485+
model_package_group_name = utils.base_name_from_image(
486+
self.image_uri, default_base_name=ModelPackage.__name__
487+
)
488+
469489
if model_package_group_name is not None:
470490
container_def = self.prepare_container_def()
471491
container_def = update_container_with_inference_params(
@@ -478,12 +498,14 @@ def register(
478498
else:
479499
container_def = {
480500
"Image": self.image_uri,
481-
"ModelDataUrl": self.model_data,
482501
}
483502

503+
if self.model_data is not None:
504+
container_def["ModelDataUrl"] = self.model_data
505+
484506
model_pkg_args = sagemaker.get_model_package_args(
485-
content_types,
486-
response_types,
507+
self.content_types,
508+
self.response_types,
487509
inference_instances=inference_instances,
488510
transform_instances=transform_instances,
489511
model_package_name=model_package_name,
@@ -1751,6 +1773,7 @@ def __init__(
17511773

17521774
# works for MODEL_PACKAGE_ARN with or without version info.
17531775
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1776+
MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
17541777

17551778

17561779
class ModelPackage(Model):
@@ -1885,6 +1908,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851908
self._ensure_base_name_if_needed(model_package_name)
18861909
self._set_model_name_if_needed()
18871910

1911+
# Quering the approval status for the model package
1912+
# Approving the versioned model package in case it is not approved
1913+
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
1914+
ModelPackageName=self.model_package_arn or model_package_name
1915+
)
1916+
if self.model_package_arn is None:
1917+
self.model_package_arn = model_package_desc["ModelPackageArn"]
1918+
if re.match(MODEL_PACKAGE_VERSIONED_ARN_PATTERN, self.model_package_arn):
1919+
approval_status = model_package_desc.get("ModelApprovalStatus", "")
1920+
if approval_status != ModelApprovalStatusEnum.APPROVED:
1921+
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)
1922+
18881923
self.sagemaker_session.create_model(
18891924
self.name,
18901925
self.role,
@@ -1898,3 +1933,25 @@ def _ensure_base_name_if_needed(self, base_name):
18981933
"""Set the base name if there is no model name provided."""
18991934
if self.name is None:
19001935
self._base_name = base_name
1936+
1937+
def update_approval_status(self, approval_status, approval_description=None):
1938+
"""Update the approval status for the model package
1939+
1940+
Args:
1941+
approval_status (str or PipelineVariable): Model Approval Status, values can be
1942+
"Approved", "Rejected", or "PendingManualApproval".
1943+
approval_description (str): Optional. Description for the approval status of the model
1944+
(default: None).
1945+
"""
1946+
if self.model_package_arn is None:
1947+
raise ValueError("model_package_arn is required to update the status.")
1948+
1949+
update_approval_args = {
1950+
"ModelPackageArn": self.model_package_arn,
1951+
"ModelApprovalStatus": approval_status,
1952+
}
1953+
1954+
if approval_description is not None:
1955+
update_approval_args["ApprovalDescription"] = approval_description
1956+
1957+
self.sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

src/sagemaker/mxnet/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def __init__(
150150

151151
def register(
152152
self,
153-
content_types: List[Union[str, PipelineVariable]],
154-
response_types: List[Union[str, PipelineVariable]],
153+
content_types: List[Union[str, PipelineVariable]] = None,
154+
response_types: List[Union[str, PipelineVariable]] = None,
155155
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156156
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
157157
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pipeline.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
335335
@runnable_by_pipeline
336336
def register(
337337
self,
338-
content_types: List[Union[str, PipelineVariable]],
339-
response_types: List[Union[str, PipelineVariable]],
338+
content_types: List[Union[str, PipelineVariable]] = None,
339+
response_types: List[Union[str, PipelineVariable]] = None,
340340
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
341341
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
342342
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pytorch/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __init__(
152152

153153
def register(
154154
self,
155-
content_types: List[Union[str, PipelineVariable]],
156-
response_types: List[Union[str, PipelineVariable]],
155+
content_types: List[Union[str, PipelineVariable]] = None,
156+
response_types: List[Union[str, PipelineVariable]] = None,
157157
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
158158
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
159159
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/session.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5830,8 +5830,8 @@ def wait_for_inference_recommendations_job(
58305830

58315831

58325832
def get_model_package_args(
5833-
content_types,
5834-
response_types,
5833+
content_types=None,
5834+
response_types=None,
58355835
inference_instances=None,
58365836
transform_instances=None,
58375837
model_package_name=None,
@@ -5899,19 +5899,23 @@ def get_model_package_args(
58995899
else:
59005900
container = {
59015901
"Image": image_uri,
5902-
"ModelDataUrl": model_data,
59035902
}
5903+
if model_data is not None:
5904+
container["ModelDataUrl"] = model_data
5905+
59045906
containers = [container]
59055907

59065908
model_package_args = {
59075909
"containers": containers,
5908-
"content_types": content_types,
5909-
"response_types": response_types,
59105910
"inference_instances": inference_instances,
59115911
"transform_instances": transform_instances,
59125912
"marketplace_cert": marketplace_cert,
59135913
}
59145914

5915+
if content_types is not None:
5916+
model_package_args["content_types"] = content_types
5917+
if response_types is not None:
5918+
model_package_args["response_types"] = response_types
59155919
if model_package_name is not None:
59165920
model_package_args["model_package_name"] = model_package_name
59175921
if model_package_group_name is not None:

src/sagemaker/sklearn/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def __init__(
145145

146146
def register(
147147
self,
148-
content_types: List[Union[str, PipelineVariable]],
149-
response_types: List[Union[str, PipelineVariable]],
148+
content_types: List[Union[str, PipelineVariable]] = None,
149+
response_types: List[Union[str, PipelineVariable]] = None,
150150
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
151151
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
152152
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/tensorflow/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207

208208
def register(
209209
self,
210-
content_types: List[Union[str, PipelineVariable]],
211-
response_types: List[Union[str, PipelineVariable]],
210+
content_types: List[Union[str, PipelineVariable]] = None,
211+
response_types: List[Union[str, PipelineVariable]] = None,
212212
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
213213
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
214214
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/workflow/_utils.py

-3
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def arguments(self) -> RequestType:
443443
model = self.estimator.create_model(**self.kwargs)
444444
self.image_uri = model.image_uri
445445

446-
if self.model_data is None:
447-
self.model_data = model.model_data
448-
449446
# reset placeholder
450447
self.estimator.output_path = output_path
451448

src/sagemaker/xgboost/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def __init__(
133133

134134
def register(
135135
self,
136-
content_types: List[Union[str, PipelineVariable]],
137-
response_types: List[Union[str, PipelineVariable]],
136+
content_types: List[Union[str, PipelineVariable]] = None,
137+
response_types: List[Union[str, PipelineVariable]] = None,
138138
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
139139
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
140140
model_package_name: Optional[Union[str, PipelineVariable]] = None,

tests/integ/test_model_package.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
17+
from sagemaker.utils import unique_name_from_base
18+
from sagemaker.model import ModelPackage
19+
from tests.integ import DATA_DIR
20+
from sagemaker.xgboost import XGBoostModel
21+
22+
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
23+
24+
25+
def test_add_remove_model_groups_in_collection_success(sagemaker_session):
26+
27+
model_group_name = unique_name_from_base("test-model-group")
28+
29+
sagemaker_session.sagemaker_client.create_model_package_group(
30+
ModelPackageGroupName=model_group_name
31+
)
32+
33+
xgb_model_data_s3 = sagemaker_session.upload_data(
34+
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
35+
key_prefix="integ-test-data/xgboost/model",
36+
)
37+
model = XGBoostModel(
38+
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
39+
)
40+
image_uri = model.serving_image_uri(
41+
region_name=sagemaker_session.boto_session.region_name,
42+
instance_type="ml.p2.xlarge",
43+
)
44+
create_model_package_input_dict = {
45+
"ModelPackageGroupName": model_group_name,
46+
"ModelPackageDescription": "Test model package registered for integ test",
47+
"ModelApprovalStatus": ModelApprovalStatusEnum.PENDING_MANUAL_APPROVAL,
48+
"InferenceSpecification": {
49+
"Containers": [{"Image": image_uri}],
50+
"SupportedContentTypes": ["text/csv"],
51+
"SupportedResponseMIMETypes": ["text/csv"],
52+
"SupportedRealtimeInferenceInstanceTypes": ["ml.m5.large"],
53+
"SupportedTransformInstanceTypes": ["ml.m5.large"],
54+
},
55+
}
56+
57+
create_model_package_resp = sagemaker_session.sagemaker_client.create_model_package(
58+
**create_model_package_input_dict
59+
)
60+
model_package_arn = create_model_package_resp["ModelPackageArn"]
61+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
62+
ModelPackageName=model_package_arn
63+
)
64+
model_package_arn = desc_model_package["ModelPackageArn"]
65+
approval_status = desc_model_package["ModelApprovalStatus"]
66+
67+
assert approval_status == ModelApprovalStatusEnum.PENDING_MANUAL_APPROVAL
68+
69+
model_package = ModelPackage(
70+
model_package_arn=model_package_arn,
71+
sagemaker_session=sagemaker_session,
72+
)
73+
74+
model_package.update_approval_status(
75+
approval_status=ModelApprovalStatusEnum.APPROVED, approval_description="dummy"
76+
)
77+
78+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
79+
ModelPackageName=model_package_arn
80+
)
81+
assert desc_model_package["ModelApprovalStatus"] == ModelApprovalStatusEnum.APPROVED
82+
assert desc_model_package["ApprovalDescription"] == "dummy"
83+
84+
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_arn)
85+
sagemaker_session.sagemaker_client.delete_model_package_group(
86+
ModelPackageGroupName=model_group_name
87+
)

tests/unit/sagemaker/model/test_model.py

-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
5555
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
5656

57-
5857
MODEL_DESCRIPTION = "a description"
5958

6059
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"]

0 commit comments

Comments
 (0)