Skip to content

Commit 0872dae

Browse files
authored
Merge branch 'master' into feat/s3-prefix-model-data-for-jumpstart-model
2 parents 770d6a7 + e8c42ae commit 0872dae

File tree

15 files changed

+189
-34
lines changed

15 files changed

+189
-34
lines changed

src/sagemaker/chainer/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
def register(
150150
self,
151-
content_types: List[Union[str, PipelineVariable]],
152-
response_types: List[Union[str, PipelineVariable]],
151+
content_types: List[Union[str, PipelineVariable]] = None,
152+
response_types: List[Union[str, PipelineVariable]] = None,
153153
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154154
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155155
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def deploy(
16651665

16661666
def register(
16671667
self,
1668-
content_types,
1669-
response_types,
1668+
content_types=None,
1669+
response_types=None,
16701670
inference_instances=None,
16711671
transform_instances=None,
16721672
image_uri=None,

src/sagemaker/huggingface/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def deploy(
332332

333333
def register(
334334
self,
335-
content_types: List[Union[str, PipelineVariable]],
336-
response_types: List[Union[str, PipelineVariable]],
335+
content_types: List[Union[str, PipelineVariable]] = None,
336+
response_types: List[Union[str, PipelineVariable]] = None,
337337
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
338338
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
339339
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/model.py

+69-7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4444
load_sagemaker_config,
4545
)
46+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4647
from sagemaker.session import Session
4748
from sagemaker.model_metrics import ModelMetrics
4849
from sagemaker.deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375
self.dependencies = updates["dependencies"]
375376
self.uploaded_code = None
376377
self.repacked_model_data = None
378+
self.content_types = None
379+
self.response_types = None
377380

378381
@runnable_by_pipeline
379382
def register(
380383
self,
381-
content_types: List[Union[str, PipelineVariable]],
382-
response_types: List[Union[str, PipelineVariable]],
384+
content_types: List[Union[str, PipelineVariable]] = None,
385+
response_types: List[Union[str, PipelineVariable]] = None,
383386
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
384387
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
385388
model_package_name: Optional[Union[str, PipelineVariable]] = None,
@@ -456,16 +459,33 @@ def register(
456459
in case the Model instance is built with
457460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461
"""
459-
if self.model_data is None:
460-
raise ValueError("SageMaker Model Package cannot be created without model data.")
461462
if isinstance(self.model_data, dict):
462463
raise ValueError(
463464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464465
)
465466

467+
if content_types is not None:
468+
self.content_types = content_types
469+
470+
if response_types is not None:
471+
self.response_types = response_types
472+
473+
if self.content_types is None:
474+
raise ValueError("The supported MIME types for the input data is not set")
475+
476+
if self.response_types is None:
477+
raise ValueError("The supported MIME types for the output data is not set")
478+
466479
if image_uri is not None:
467480
self.image_uri = image_uri
468481

482+
if model_package_group_name is None and model_package_name is None:
483+
# If model package group and model package name is not set
484+
# then register to auto-generated model package group
485+
model_package_group_name = utils.base_name_from_image(
486+
self.image_uri, default_base_name=ModelPackage.__name__
487+
)
488+
469489
if model_package_group_name is not None:
470490
container_def = self.prepare_container_def()
471491
container_def = update_container_with_inference_params(
@@ -478,12 +498,14 @@ def register(
478498
else:
479499
container_def = {
480500
"Image": self.image_uri,
481-
"ModelDataUrl": self.model_data,
482501
}
483502

503+
if self.model_data is not None:
504+
container_def["ModelDataUrl"] = self.model_data
505+
484506
model_pkg_args = sagemaker.get_model_package_args(
485-
content_types,
486-
response_types,
507+
self.content_types,
508+
self.response_types,
487509
inference_instances=inference_instances,
488510
transform_instances=transform_instances,
489511
model_package_name=model_package_name,
@@ -511,6 +533,7 @@ def register(
511533
role=self.role,
512534
model_data=self.model_data,
513535
model_package_arn=model_package.get("ModelPackageArn"),
536+
sagemaker_session=self.sagemaker_session,
514537
)
515538

516539
@runnable_by_pipeline
@@ -1751,6 +1774,7 @@ def __init__(
17511774

17521775
# works for MODEL_PACKAGE_ARN with or without version info.
17531776
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1777+
MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
17541778

17551779

17561780
class ModelPackage(Model):
@@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851909
self._ensure_base_name_if_needed(model_package_name)
18861910
self._set_model_name_if_needed()
18871911

1912+
# Quering the approval status for the model package
1913+
# Approving the versioned model package in case it is not approved
1914+
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
1915+
ModelPackageName=self.model_package_arn or model_package_name
1916+
)
1917+
if self.model_package_arn is None:
1918+
self.model_package_arn = model_package_desc["ModelPackageArn"]
1919+
if re.match(MODEL_PACKAGE_VERSIONED_ARN_PATTERN, self.model_package_arn):
1920+
approval_status = model_package_desc.get("ModelApprovalStatus", "")
1921+
if approval_status != ModelApprovalStatusEnum.APPROVED:
1922+
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)
1923+
18881924
self.sagemaker_session.create_model(
18891925
self.name,
18901926
self.role,
@@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
18981934
"""Set the base name if there is no model name provided."""
18991935
if self.name is None:
19001936
self._base_name = base_name
1937+
1938+
def update_approval_status(self, approval_status, approval_description=None):
1939+
"""Update the approval status for the model package
1940+
1941+
Args:
1942+
approval_status (str or PipelineVariable): Model Approval Status, values can be
1943+
"Approved", "Rejected", or "PendingManualApproval".
1944+
approval_description (str): Optional. Description for the approval status of the model
1945+
(default: None).
1946+
"""
1947+
1948+
# Models can lazy-init sagemaker_session until deploy() is called to support
1949+
# LocalMode so we must make sure we have an actual session
1950+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
1951+
if self.model_package_arn is None:
1952+
raise ValueError("model_package_arn is required to update the status.")
1953+
1954+
update_approval_args = {
1955+
"ModelPackageArn": self.model_package_arn,
1956+
"ModelApprovalStatus": approval_status,
1957+
}
1958+
1959+
if approval_description is not None:
1960+
update_approval_args["ApprovalDescription"] = approval_description
1961+
1962+
sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

src/sagemaker/mxnet/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def __init__(
150150

151151
def register(
152152
self,
153-
content_types: List[Union[str, PipelineVariable]],
154-
response_types: List[Union[str, PipelineVariable]],
153+
content_types: List[Union[str, PipelineVariable]] = None,
154+
response_types: List[Union[str, PipelineVariable]] = None,
155155
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156156
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
157157
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pipeline.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
335335
@runnable_by_pipeline
336336
def register(
337337
self,
338-
content_types: List[Union[str, PipelineVariable]],
339-
response_types: List[Union[str, PipelineVariable]],
338+
content_types: List[Union[str, PipelineVariable]] = None,
339+
response_types: List[Union[str, PipelineVariable]] = None,
340340
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
341341
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
342342
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pytorch/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __init__(
152152

153153
def register(
154154
self,
155-
content_types: List[Union[str, PipelineVariable]],
156-
response_types: List[Union[str, PipelineVariable]],
155+
content_types: List[Union[str, PipelineVariable]] = None,
156+
response_types: List[Union[str, PipelineVariable]] = None,
157157
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
158158
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
159159
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/session.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5830,8 +5830,8 @@ def wait_for_inference_recommendations_job(
58305830

58315831

58325832
def get_model_package_args(
5833-
content_types,
5834-
response_types,
5833+
content_types=None,
5834+
response_types=None,
58355835
inference_instances=None,
58365836
transform_instances=None,
58375837
model_package_name=None,
@@ -5899,19 +5899,23 @@ def get_model_package_args(
58995899
else:
59005900
container = {
59015901
"Image": image_uri,
5902-
"ModelDataUrl": model_data,
59035902
}
5903+
if model_data is not None:
5904+
container["ModelDataUrl"] = model_data
5905+
59045906
containers = [container]
59055907

59065908
model_package_args = {
59075909
"containers": containers,
5908-
"content_types": content_types,
5909-
"response_types": response_types,
59105910
"inference_instances": inference_instances,
59115911
"transform_instances": transform_instances,
59125912
"marketplace_cert": marketplace_cert,
59135913
}
59145914

5915+
if content_types is not None:
5916+
model_package_args["content_types"] = content_types
5917+
if response_types is not None:
5918+
model_package_args["response_types"] = response_types
59155919
if model_package_name is not None:
59165920
model_package_args["model_package_name"] = model_package_name
59175921
if model_package_group_name is not None:

src/sagemaker/sklearn/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def __init__(
145145

146146
def register(
147147
self,
148-
content_types: List[Union[str, PipelineVariable]],
149-
response_types: List[Union[str, PipelineVariable]],
148+
content_types: List[Union[str, PipelineVariable]] = None,
149+
response_types: List[Union[str, PipelineVariable]] = None,
150150
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
151151
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
152152
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/tensorflow/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207

208208
def register(
209209
self,
210-
content_types: List[Union[str, PipelineVariable]],
211-
response_types: List[Union[str, PipelineVariable]],
210+
content_types: List[Union[str, PipelineVariable]] = None,
211+
response_types: List[Union[str, PipelineVariable]] = None,
212212
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
213213
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
214214
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/workflow/_utils.py

-3
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def arguments(self) -> RequestType:
443443
model = self.estimator.create_model(**self.kwargs)
444444
self.image_uri = model.image_uri
445445

446-
if self.model_data is None:
447-
self.model_data = model.model_data
448-
449446
# reset placeholder
450447
self.estimator.output_path = output_path
451448

src/sagemaker/xgboost/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def __init__(
133133

134134
def register(
135135
self,
136-
content_types: List[Union[str, PipelineVariable]],
137-
response_types: List[Union[str, PipelineVariable]],
136+
content_types: List[Union[str, PipelineVariable]] = None,
137+
response_types: List[Union[str, PipelineVariable]] = None,
138138
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
139139
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
140140
model_package_name: Optional[Union[str, PipelineVariable]] = None,

tests/integ/test_model_package.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
17+
from sagemaker.utils import unique_name_from_base
18+
from tests.integ import DATA_DIR
19+
from sagemaker.xgboost import XGBoostModel
20+
21+
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
22+
23+
24+
def test_update_approval_model_package(sagemaker_session):
25+
26+
model_group_name = unique_name_from_base("test-model-group")
27+
28+
sagemaker_session.sagemaker_client.create_model_package_group(
29+
ModelPackageGroupName=model_group_name
30+
)
31+
32+
xgb_model_data_s3 = sagemaker_session.upload_data(
33+
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
34+
key_prefix="integ-test-data/xgboost/model",
35+
)
36+
model = XGBoostModel(
37+
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
38+
)
39+
40+
model_package = model.register(
41+
content_types=["text/csv"],
42+
response_types=["text/csv"],
43+
inference_instances=["ml.m5.large"],
44+
transform_instances=["ml.m5.large"],
45+
model_package_group_name=model_group_name,
46+
)
47+
48+
model_package.update_approval_status(
49+
approval_status=ModelApprovalStatusEnum.APPROVED, approval_description="dummy"
50+
)
51+
52+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
53+
ModelPackageName=model_package.model_package_arn
54+
)
55+
assert desc_model_package["ModelApprovalStatus"] == ModelApprovalStatusEnum.APPROVED
56+
assert desc_model_package["ApprovalDescription"] == "dummy"
57+
58+
sagemaker_session.sagemaker_client.delete_model_package(
59+
ModelPackageName=model_package.model_package_arn
60+
)
61+
sagemaker_session.sagemaker_client.delete_model_package_group(
62+
ModelPackageGroupName=model_group_name
63+
)

tests/unit/sagemaker/model/test_model.py

-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
5555
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
5656

57-
5857
MODEL_DESCRIPTION = "a description"
5958

6059
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"]

0 commit comments

Comments
 (0)