Skip to content

Commit 747c239

Browse files
committed
feature: adding customer metadata support to registermodel step
1 parent bb7563f commit 747c239

File tree

11 files changed

+100
-1
lines changed

11 files changed

+100
-1
lines changed

src/sagemaker/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ def register(
12631263
compile_model_family=None,
12641264
model_name=None,
12651265
drift_check_baselines=None,
1266+
customer_metadata_properties=None,
12661267
**kwargs,
12671268
):
12681269
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1292,6 +1293,8 @@ def register(
12921293
model will be used (default: None).
12931294
model_name (str): User defined model name (default: None).
12941295
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
1296+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
1297+
metadata properties (default: None).
12951298
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
12961299
``create_model()`` to accept ``**kwargs`` to customize model creation during
12971300
deploy. For more, see the implementation docs.
@@ -1322,6 +1325,7 @@ def register(
13221325
approval_status,
13231326
description,
13241327
drift_check_baselines=drift_check_baselines,
1328+
customer_metadata_properties=customer_metadata_properties,
13251329
)
13261330

13271331
@property

src/sagemaker/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def register(
304304
approval_status=None,
305305
description=None,
306306
drift_check_baselines=None,
307+
customer_metadata_properties=None,
307308
):
308309
"""Creates a model package for creating SageMaker models or listing on Marketplace.
309310
@@ -329,6 +330,8 @@ def register(
329330
or "PendingManualApproval" (default: "PendingManualApproval").
330331
description (str): Model Package description (default: None).
331332
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
333+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
334+
metadata properties (default: None).
332335
333336
Returns:
334337
A `sagemaker.model.ModelPackage` instance.
@@ -356,6 +359,7 @@ def register(
356359
description=description,
357360
container_def_list=[container_def],
358361
drift_check_baselines=drift_check_baselines,
362+
customer_metadata_properties = customer_metadata_properties,
359363
)
360364
model_package = self.sagemaker_session.create_model_package_from_containers(
361365
**model_pkg_args

src/sagemaker/mxnet/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def register(
158158
approval_status=None,
159159
description=None,
160160
drift_check_baselines=None,
161+
customer_metadata_properties=None,
161162
):
162163
"""Creates a model package for creating SageMaker models or listing on Marketplace.
163164
@@ -183,6 +184,8 @@ def register(
183184
or "PendingManualApproval" (default: "PendingManualApproval").
184185
description (str): Model Package description (default: None).
185186
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
187+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
188+
metadata properties (default: None).
186189
187190
Returns:
188191
A `sagemaker.model.ModelPackage` instance.
@@ -211,6 +214,7 @@ def register(
211214
approval_status,
212215
description,
213216
drift_check_baselines=drift_check_baselines,
217+
customer_metadata_properties=customer_metadata_properties,
214218
)
215219

216220
def prepare_container_def(self, instance_type=None, accelerator_type=None):

src/sagemaker/pytorch/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def register(
157157
approval_status=None,
158158
description=None,
159159
drift_check_baselines=None,
160+
customer_metadata_properties=None,
160161
):
161162
"""Creates a model package for creating SageMaker models or listing on Marketplace.
162163
@@ -182,6 +183,8 @@ def register(
182183
or "PendingManualApproval" (default: "PendingManualApproval").
183184
description (str): Model Package description (default: None).
184185
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
186+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
187+
metadata properties (default: None).
185188
186189
Returns:
187190
A `sagemaker.model.ModelPackage` instance.
@@ -210,6 +213,7 @@ def register(
210213
approval_status,
211214
description,
212215
drift_check_baselines=drift_check_baselines,
216+
customer_metadata_properties=customer_metadata_properties,
213217
)
214218

215219
def prepare_container_def(self, instance_type=None, accelerator_type=None):

src/sagemaker/session.py

+21
Original file line numberDiff line numberDiff line change
@@ -2778,6 +2778,7 @@ def create_model_package_from_containers(
27782778
approval_status="PendingManualApproval",
27792779
description=None,
27802780
drift_check_baselines=None,
2781+
customer_metadata_properties=customer_metadata_properties,
27812782
):
27822783
"""Get request dictionary for CreateModelPackage API.
27832784
@@ -2803,6 +2804,9 @@ def create_model_package_from_containers(
28032804
or "PendingManualApproval" (default: "PendingManualApproval").
28042805
description (str): Model Package description (default: None).
28052806
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
2807+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
2808+
metadata properties (default: None).
2809+
28062810
"""
28072811

28082812
request = get_create_model_package_request(
@@ -2819,7 +2823,14 @@ def create_model_package_from_containers(
28192823
approval_status,
28202824
description,
28212825
drift_check_baselines=drift_check_baselines,
2826+
customer_metadata_properties=customer_metadata_properties,
28222827
)
2828+
try:
2829+
self.sagemaker_client.describe_model_package_group(
2830+
ModelPackageGroupName=request["ModelPackageGroupName"])
2831+
except ClientError as e:
2832+
self.sagemaker_client.create_model_package_group(
2833+
ModelPackageGroupName=request["ModelPackageGroupName"])
28232834
return self.sagemaker_client.create_model_package(**request)
28242835

28252836
def wait_for_model_package(self, model_package_name, poll=5):
@@ -4120,6 +4131,7 @@ def get_model_package_args(
41204131
tags=None,
41214132
container_def_list=None,
41224133
drift_check_baselines=None,
4134+
customer_metadata_properties=None,
41234135
):
41244136
"""Get arguments for create_model_package method.
41254137
@@ -4148,6 +4160,8 @@ def get_model_package_args(
41484160
(default: None).
41494161
container_def_list (list): A list of container defintiions (default: None).
41504162
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4163+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4164+
metadata properties (default: None).
41514165
Returns:
41524166
dict: A dictionary of method argument names and values.
41534167
"""
@@ -4185,6 +4199,8 @@ def get_model_package_args(
41854199
model_package_args["description"] = description
41864200
if tags is not None:
41874201
model_package_args["tags"] = tags
4202+
if customer_metadata_properties is not None:
4203+
model_package_args["customer_metadata_properties"] = customer_metadata_properties
41884204
return model_package_args
41894205

41904206

@@ -4203,6 +4219,7 @@ def get_create_model_package_request(
42034219
description=None,
42044220
tags=None,
42054221
drift_check_baselines=None,
4222+
customer_metadata_properties=None,
42064223
):
42074224
"""Get request dictionary for CreateModelPackage API.
42084225
@@ -4229,6 +4246,8 @@ def get_create_model_package_request(
42294246
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
42304247
(default: None).
42314248
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4249+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4250+
metadata properties (default: None).
42324251
"""
42334252

42344253
if all([model_package_name, model_package_group_name]):
@@ -4250,6 +4269,8 @@ def get_create_model_package_request(
42504269
request_dict["DriftCheckBaselines"] = drift_check_baselines
42514270
if metadata_properties:
42524271
request_dict["MetadataProperties"] = metadata_properties
4272+
if customer_metadata_properties is not None:
4273+
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
42534274
if containers is not None:
42544275
if not all([content_types, response_types, inference_instances, transform_instances]):
42554276
raise ValueError(

src/sagemaker/tensorflow/model.py

+5
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def register(
201201
approval_status=None,
202202
description=None,
203203
drift_check_baselines=None,
204+
customer_metadata_properties=None,
204205
):
205206
"""Creates a model package for creating SageMaker models or listing on Marketplace.
206207
@@ -226,6 +227,9 @@ def register(
226227
or "PendingManualApproval" (default: "PendingManualApproval").
227228
description (str): Model Package description (default: None).
228229
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
230+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
231+
metadata properties (default: None).
232+
229233
230234
Returns:
231235
A `sagemaker.model.ModelPackage` instance.
@@ -254,6 +258,7 @@ def register(
254258
approval_status,
255259
description,
256260
drift_check_baselines=drift_check_baselines,
261+
customer_metadata_properties=customer_metadata_properties,
257262
)
258263

259264
def deploy(

src/sagemaker/workflow/_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def __init__(
310310
tags=None,
311311
container_def_list=None,
312312
drift_check_baselines=None,
313+
customer_metadata_properties=None,
313314
**kwargs,
314315
):
315316
"""Constructor of a register model step.
@@ -347,6 +348,8 @@ def __init__(
347348
this step depends on
348349
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
349350
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
351+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
352+
metadata properties (default: None).
350353
**kwargs: additional arguments to `create_model`.
351354
"""
352355
super(_RegisterModelStep, self).__init__(
@@ -362,6 +365,7 @@ def __init__(
362365
self.tags = tags
363366
self.model_metrics = model_metrics
364367
self.drift_check_baselines = drift_check_baselines
368+
self.customer_metadata_properties = customer_metadata_properties,
365369
self.metadata_properties = metadata_properties
366370
self.approval_status = approval_status
367371
self.image_uri = image_uri
@@ -435,6 +439,7 @@ def arguments(self) -> RequestType:
435439
description=self.description,
436440
tags=self.tags,
437441
container_def_list=self.container_def_list,
442+
customer_metadata_properties=self.customer_metadata_properties
438443
)
439444

440445
request_dict = get_create_model_package_request(**model_package_args)

src/sagemaker/workflow/step_collections.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
tags=None,
7676
model: Union[Model, PipelineModel] = None,
7777
drift_check_baselines=None,
78+
customer_metadata_properties=None,
7879
**kwargs,
7980
):
8081
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -95,7 +96,7 @@ def __init__(
9596
for the repack model step
9697
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
9798
for register model step
98-
model_package_group_name (str): The Model Package Group name, exclusive to
99+
model_package_group_name (str): The Model Package Group name or Arn, exclusive to
99100
`model_package_name`, using `model_package_group_name` makes the Model Package
100101
versioned (default: None).
101102
model_metrics (ModelMetrics): ModelMetrics object (default: None).
@@ -113,6 +114,9 @@ def __init__(
113114
model (object or Model): A PipelineModel object that comprises a list of models
114115
which gets executed as a serial inference pipeline or a Model object.
115116
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
117+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
118+
metadata properties (default: None).
119+
116120
**kwargs: additional arguments to `create_model`.
117121
"""
118122
steps: List[Step] = []
@@ -229,6 +233,7 @@ def __init__(
229233
tags=tags,
230234
container_def_list=self.container_def_list,
231235
retry_policies=register_model_step_retry_policies,
236+
customer_metadata_properties=customer_metadata_properties,
232237
**kwargs,
233238
)
234239
if not repack_model:

tests/integ/test_mxnet.py

+41
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,47 @@ def test_register_model_package(
230230
assert result is not None
231231
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name)
232232

233+
def test_register_model_package_via_group(
234+
mxnet_training_job,
235+
sagemaker_session,
236+
mxnet_inference_latest_version,
237+
mxnet_inference_latest_py_version,
238+
cpu_instance_type,
239+
):
240+
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
241+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
242+
desc = sagemaker_session.sagemaker_client.describe_training_job(
243+
TrainingJobName=mxnet_training_job
244+
)
245+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
246+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
247+
model = MXNetModel(
248+
model_data,
249+
"SageMakerRole",
250+
entry_point=script_path,
251+
py_version=mxnet_inference_latest_py_version,
252+
sagemaker_session=sagemaker_session,
253+
framework_version=mxnet_inference_latest_version,
254+
)
255+
model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp())
256+
model_pkg = model.register(
257+
content_types=["application/json"],
258+
response_types=["application/json"],
259+
inference_instances=["ml.m5.large"],
260+
transform_instances=["ml.m5.large"],
261+
model_package_group_name=model_package_group_name,
262+
)
263+
assert isinstance(model_pkg, ModelPackage)
264+
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
265+
data = numpy.zeros(shape=(1, 1, 28, 28))
266+
result = predictor.predict(data)
267+
assert result is not None
268+
model_packages = \
269+
sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[
270+
'ModelPackageSummaryList']
271+
for model_package in model_packages:
272+
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package['ModelPackageArn'])
273+
sagemaker_session.sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name)
233274

234275
def test_register_model_package_versioned(
235276
mxnet_training_job,

tests/integ/test_workflow.py

+3
Original file line numberDiff line numberDiff line change
@@ -1951,6 +1951,7 @@ def test_model_registration_with_drift_check_baselines(
19511951
content_type="application/json",
19521952
),
19531953
)
1954+
customer_metadata_properties = {"key1": "value1"}
19541955
estimator = XGBoost(
19551956
entry_point="training.py",
19561957
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -1972,6 +1973,7 @@ def test_model_registration_with_drift_check_baselines(
19721973
model_package_group_name="testModelPackageGroup",
19731974
model_metrics=model_metrics,
19741975
drift_check_baselines=drift_check_baselines,
1976+
customer_metadata_properties=customer_metadata_properties,
19751977
)
19761978

19771979
pipeline = Pipeline(
@@ -2042,6 +2044,7 @@ def test_model_registration_with_drift_check_baselines(
20422044
response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"]
20432045
== "application/json"
20442046
)
2047+
assert response["CustomerMetadataProperties"] == customer_metadata_properties
20452048
break
20462049
finally:
20472050
try:

tests/unit/test_session.py

+3
Original file line numberDiff line numberDiff line change
@@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23852385
marketplace_cert = (True,)
23862386
approval_status = ("Approved",)
23872387
description = "description"
2388+
customer_metadata_properties = {"key1": "value1"}
23882389
sagemaker_session.create_model_package_from_containers(
23892390
containers=containers,
23902391
content_types=content_types,
@@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23982399
approval_status=approval_status,
23992400
description=description,
24002401
drift_check_baselines=drift_check_baselines,
2402+
customer_metadata_properties=customer_metadata_properties,
24012403
)
24022404
expected_args = {
24032405
"ModelPackageName": model_package_name,
@@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24142416
"CertifyForMarketplace": marketplace_cert,
24152417
"ModelApprovalStatus": approval_status,
24162418
"DriftCheckBaselines": drift_check_baselines,
2419+
"CustomerMetadataProperties": customer_metadata_properties
24172420
}
24182421
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)
24192422

0 commit comments

Comments
 (0)