Skip to content

Commit 91ff059

Browse files
authored
Merge branch 'dev' into feat/override-jumpstart-content-bucket
2 parents c1a23c6 + 21ac6fd commit 91ff059

File tree

10 files changed

+62
-1
lines changed

10 files changed

+62
-1
lines changed

src/sagemaker/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ def register(
12631263
compile_model_family=None,
12641264
model_name=None,
12651265
drift_check_baselines=None,
1266+
customer_metadata_properties=None,
12661267
**kwargs,
12671268
):
12681269
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1292,6 +1293,8 @@ def register(
12921293
model will be used (default: None).
12931294
model_name (str): User defined model name (default: None).
12941295
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
1296+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
1297+
metadata properties (default: None).
12951298
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
12961299
``create_model()`` to accept ``**kwargs`` to customize model creation during
12971300
deploy. For more, see the implementation docs.
@@ -1322,6 +1325,7 @@ def register(
13221325
approval_status,
13231326
description,
13241327
drift_check_baselines=drift_check_baselines,
1328+
customer_metadata_properties=customer_metadata_properties,
13251329
)
13261330

13271331
@property

src/sagemaker/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def register(
303303
approval_status=None,
304304
description=None,
305305
drift_check_baselines=None,
306+
customer_metadata_properties=None,
306307
):
307308
"""Creates a model package for creating SageMaker models or listing on Marketplace.
308309
@@ -328,6 +329,8 @@ def register(
328329
or "PendingManualApproval" (default: "PendingManualApproval").
329330
description (str): Model Package description (default: None).
330331
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
332+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
333+
metadata properties (default: None).
331334
332335
Returns:
333336
A `sagemaker.model.ModelPackage` instance.
@@ -355,6 +358,7 @@ def register(
355358
description=description,
356359
container_def_list=[container_def],
357360
drift_check_baselines=drift_check_baselines,
361+
customer_metadata_properties=customer_metadata_properties,
358362
)
359363
model_package = self.sagemaker_session.create_model_package_from_containers(
360364
**model_pkg_args

src/sagemaker/mxnet/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def register(
158158
approval_status=None,
159159
description=None,
160160
drift_check_baselines=None,
161+
customer_metadata_properties=None,
161162
):
162163
"""Creates a model package for creating SageMaker models or listing on Marketplace.
163164
@@ -183,6 +184,8 @@ def register(
183184
or "PendingManualApproval" (default: "PendingManualApproval").
184185
description (str): Model Package description (default: None).
185186
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
187+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
188+
metadata properties (default: None).
186189
187190
Returns:
188191
A `sagemaker.model.ModelPackage` instance.
@@ -211,6 +214,7 @@ def register(
211214
approval_status,
212215
description,
213216
drift_check_baselines=drift_check_baselines,
217+
customer_metadata_properties=customer_metadata_properties,
214218
)
215219

216220
def prepare_container_def(self, instance_type=None, accelerator_type=None):

src/sagemaker/pytorch/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def register(
157157
approval_status=None,
158158
description=None,
159159
drift_check_baselines=None,
160+
customer_metadata_properties=None,
160161
):
161162
"""Creates a model package for creating SageMaker models or listing on Marketplace.
162163
@@ -182,6 +183,8 @@ def register(
182183
or "PendingManualApproval" (default: "PendingManualApproval").
183184
description (str): Model Package description (default: None).
184185
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
186+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
187+
metadata properties (default: None).
185188
186189
Returns:
187190
A `sagemaker.model.ModelPackage` instance.
@@ -210,6 +213,7 @@ def register(
210213
approval_status,
211214
description,
212215
drift_check_baselines=drift_check_baselines,
216+
customer_metadata_properties=customer_metadata_properties,
213217
)
214218

215219
def prepare_container_def(self, instance_type=None, accelerator_type=None):

src/sagemaker/session.py

+24
Original file line numberDiff line numberDiff line change
@@ -2778,6 +2778,7 @@ def create_model_package_from_containers(
27782778
approval_status="PendingManualApproval",
27792779
description=None,
27802780
drift_check_baselines=None,
2781+
customer_metadata_properties=None,
27812782
):
27822783
"""Get request dictionary for CreateModelPackage API.
27832784
@@ -2803,6 +2804,9 @@ def create_model_package_from_containers(
28032804
or "PendingManualApproval" (default: "PendingManualApproval").
28042805
description (str): Model Package description (default: None).
28052806
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
2807+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
2808+
metadata properties (default: None).
2809+
28062810
"""
28072811

28082812
request = get_create_model_package_request(
@@ -2819,7 +2823,17 @@ def create_model_package_from_containers(
28192823
approval_status,
28202824
description,
28212825
drift_check_baselines=drift_check_baselines,
2826+
customer_metadata_properties=customer_metadata_properties,
28222827
)
2828+
if model_package_group_name is not None:
2829+
try:
2830+
self.sagemaker_client.describe_model_package_group(
2831+
ModelPackageGroupName=request["ModelPackageGroupName"]
2832+
)
2833+
except ClientError:
2834+
self.sagemaker_client.create_model_package_group(
2835+
ModelPackageGroupName=request["ModelPackageGroupName"]
2836+
)
28232837
return self.sagemaker_client.create_model_package(**request)
28242838

28252839
def wait_for_model_package(self, model_package_name, poll=5):
@@ -4120,6 +4134,7 @@ def get_model_package_args(
41204134
tags=None,
41214135
container_def_list=None,
41224136
drift_check_baselines=None,
4137+
customer_metadata_properties=None,
41234138
):
41244139
"""Get arguments for create_model_package method.
41254140
@@ -4148,6 +4163,8 @@ def get_model_package_args(
41484163
(default: None).
41494164
container_def_list (list): A list of container defintiions (default: None).
41504165
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4166+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4167+
metadata properties (default: None).
41514168
Returns:
41524169
dict: A dictionary of method argument names and values.
41534170
"""
@@ -4185,6 +4202,8 @@ def get_model_package_args(
41854202
model_package_args["description"] = description
41864203
if tags is not None:
41874204
model_package_args["tags"] = tags
4205+
if customer_metadata_properties is not None:
4206+
model_package_args["customer_metadata_properties"] = customer_metadata_properties
41884207
return model_package_args
41894208

41904209

@@ -4203,6 +4222,7 @@ def get_create_model_package_request(
42034222
description=None,
42044223
tags=None,
42054224
drift_check_baselines=None,
4225+
customer_metadata_properties=None,
42064226
):
42074227
"""Get request dictionary for CreateModelPackage API.
42084228
@@ -4229,6 +4249,8 @@ def get_create_model_package_request(
42294249
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
42304250
(default: None).
42314251
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4252+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4253+
metadata properties (default: None).
42324254
"""
42334255

42344256
if all([model_package_name, model_package_group_name]):
@@ -4250,6 +4272,8 @@ def get_create_model_package_request(
42504272
request_dict["DriftCheckBaselines"] = drift_check_baselines
42514273
if metadata_properties:
42524274
request_dict["MetadataProperties"] = metadata_properties
4275+
if customer_metadata_properties is not None:
4276+
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
42534277
if containers is not None:
42544278
if not all([content_types, response_types, inference_instances, transform_instances]):
42554279
raise ValueError(

src/sagemaker/tensorflow/model.py

+5
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def register(
201201
approval_status=None,
202202
description=None,
203203
drift_check_baselines=None,
204+
customer_metadata_properties=None,
204205
):
205206
"""Creates a model package for creating SageMaker models or listing on Marketplace.
206207
@@ -226,6 +227,9 @@ def register(
226227
or "PendingManualApproval" (default: "PendingManualApproval").
227228
description (str): Model Package description (default: None).
228229
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
230+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
231+
metadata properties (default: None).
232+
229233
230234
Returns:
231235
A `sagemaker.model.ModelPackage` instance.
@@ -254,6 +258,7 @@ def register(
254258
approval_status,
255259
description,
256260
drift_check_baselines=drift_check_baselines,
261+
customer_metadata_properties=customer_metadata_properties,
257262
)
258263

259264
def deploy(

src/sagemaker/workflow/_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def __init__(
310310
tags=None,
311311
container_def_list=None,
312312
drift_check_baselines=None,
313+
customer_metadata_properties=None,
313314
**kwargs,
314315
):
315316
"""Constructor of a register model step.
@@ -347,6 +348,8 @@ def __init__(
347348
this step depends on
348349
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
349350
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
351+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
352+
metadata properties (default: None).
350353
**kwargs: additional arguments to `create_model`.
351354
"""
352355
super(_RegisterModelStep, self).__init__(
@@ -362,6 +365,7 @@ def __init__(
362365
self.tags = tags
363366
self.model_metrics = model_metrics
364367
self.drift_check_baselines = drift_check_baselines
368+
self.customer_metadata_properties = customer_metadata_properties
365369
self.metadata_properties = metadata_properties
366370
self.approval_status = approval_status
367371
self.image_uri = image_uri
@@ -435,6 +439,7 @@ def arguments(self) -> RequestType:
435439
description=self.description,
436440
tags=self.tags,
437441
container_def_list=self.container_def_list,
442+
customer_metadata_properties=self.customer_metadata_properties,
438443
)
439444

440445
request_dict = get_create_model_package_request(**model_package_args)

src/sagemaker/workflow/step_collections.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
tags=None,
7676
model: Union[Model, PipelineModel] = None,
7777
drift_check_baselines=None,
78+
customer_metadata_properties=None,
7879
**kwargs,
7980
):
8081
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -95,7 +96,7 @@ def __init__(
9596
for the repack model step
9697
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
9798
for register model step
98-
model_package_group_name (str): The Model Package Group name, exclusive to
99+
model_package_group_name (str): The Model Package Group name or Arn, exclusive to
99100
`model_package_name`, using `model_package_group_name` makes the Model Package
100101
versioned (default: None).
101102
model_metrics (ModelMetrics): ModelMetrics object (default: None).
@@ -113,6 +114,9 @@ def __init__(
113114
model (object or Model): A PipelineModel object that comprises a list of models
114115
which gets executed as a serial inference pipeline or a Model object.
115116
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
117+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
118+
metadata properties (default: None).
119+
116120
**kwargs: additional arguments to `create_model`.
117121
"""
118122
steps: List[Step] = []
@@ -229,6 +233,7 @@ def __init__(
229233
tags=tags,
230234
container_def_list=self.container_def_list,
231235
retry_policies=register_model_step_retry_policies,
236+
customer_metadata_properties=customer_metadata_properties,
232237
**kwargs,
233238
)
234239
if not repack_model:

tests/integ/test_workflow.py

+3
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,7 @@ def test_model_registration_with_drift_check_baselines(
19521952
content_type="application/json",
19531953
),
19541954
)
1955+
customer_metadata_properties = {"key1": "value1"}
19551956
estimator = XGBoost(
19561957
entry_point="training.py",
19571958
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -1973,6 +1974,7 @@ def test_model_registration_with_drift_check_baselines(
19731974
model_package_group_name="testModelPackageGroup",
19741975
model_metrics=model_metrics,
19751976
drift_check_baselines=drift_check_baselines,
1977+
customer_metadata_properties=customer_metadata_properties,
19761978
)
19771979

19781980
pipeline = Pipeline(
@@ -2043,6 +2045,7 @@ def test_model_registration_with_drift_check_baselines(
20432045
response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"]
20442046
== "application/json"
20452047
)
2048+
assert response["CustomerMetadataProperties"] == customer_metadata_properties
20462049
break
20472050
finally:
20482051
try:

tests/unit/test_session.py

+3
Original file line numberDiff line numberDiff line change
@@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23852385
marketplace_cert = (True,)
23862386
approval_status = ("Approved",)
23872387
description = "description"
2388+
customer_metadata_properties = {"key1": "value1"}
23882389
sagemaker_session.create_model_package_from_containers(
23892390
containers=containers,
23902391
content_types=content_types,
@@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23982399
approval_status=approval_status,
23992400
description=description,
24002401
drift_check_baselines=drift_check_baselines,
2402+
customer_metadata_properties=customer_metadata_properties,
24012403
)
24022404
expected_args = {
24032405
"ModelPackageName": model_package_name,
@@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24142416
"CertifyForMarketplace": marketplace_cert,
24152417
"ModelApprovalStatus": approval_status,
24162418
"DriftCheckBaselines": drift_check_baselines,
2419+
"CustomerMetadataProperties": customer_metadata_properties,
24172420
}
24182421
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)
24192422

0 commit comments

Comments
 (0)