Skip to content

Commit dd76ad7

Browse files
staubhpPayton Staub
and
Payton Staub
authored
change: Register model step tags (#2475)
Co-authored-by: Payton Staub <[email protected]>
1 parent 9e7b4b5 commit dd76ad7

File tree

5 files changed

+18
-0
lines changed

5 files changed

+18
-0
lines changed

src/sagemaker/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _get_model_package_args(
195195
marketplace_cert=False,
196196
approval_status=None,
197197
description=None,
198+
tags=None,
198199
):
199200
"""Get arguments for session.create_model_package method.
200201
@@ -250,6 +251,8 @@ def _get_model_package_args(
250251
model_package_args["approval_status"] = approval_status
251252
if description is not None:
252253
model_package_args["description"] = description
254+
if tags is not None:
255+
model_package_args["tags"] = tags
253256
return model_package_args
254257

255258
def _init_sagemaker_session_if_does_not_exist(self, instance_type):

src/sagemaker/session.py

+3
Original file line numberDiff line numberDiff line change
@@ -2724,6 +2724,7 @@ def _get_create_model_package_request(
27242724
marketplace_cert=False,
27252725
approval_status="PendingManualApproval",
27262726
description=None,
2727+
tags=None,
27272728
):
27282729
"""Get request dictionary for CreateModelPackage API.
27292730
@@ -2761,6 +2762,8 @@ def _get_create_model_package_request(
27612762
request_dict["ModelPackageGroupName"] = model_package_group_name
27622763
if description is not None:
27632764
request_dict["ModelPackageDescription"] = description
2765+
if tags is not None:
2766+
request_dict["Tags"] = tags
27642767
if model_metrics:
27652768
request_dict["ModelMetrics"] = model_metrics
27662769
if metadata_properties:

src/sagemaker/workflow/_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
compile_model_family=None,
226226
description=None,
227227
depends_on: List[str] = None,
228+
tags=None,
228229
**kwargs,
229230
):
230231
"""Constructor of a register model step.
@@ -264,6 +265,7 @@ def __init__(
264265
self.inference_instances = inference_instances
265266
self.transform_instances = transform_instances
266267
self.model_package_group_name = model_package_group_name
268+
self.tags = tags
267269
self.model_metrics = model_metrics
268270
self.metadata_properties = metadata_properties
269271
self.approval_status = approval_status
@@ -324,10 +326,12 @@ def arguments(self) -> RequestType:
324326
metadata_properties=self.metadata_properties,
325327
approval_status=self.approval_status,
326328
description=self.description,
329+
tags=self.tags,
327330
)
328331
request_dict = model.sagemaker_session._get_create_model_package_request(
329332
**model_package_args
330333
)
334+
331335
# these are not available in the workflow service and will cause rejection
332336
if "CertifyForMarketplace" in request_dict:
333337
request_dict.pop("CertifyForMarketplace")

src/sagemaker/workflow/step_collections.py

+6
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
image_uri=None,
6868
compile_model_family=None,
6969
description=None,
70+
tags=None,
7071
**kwargs,
7172
):
7273
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,10 @@ def __init__(
9495
compile_model_family (str): The instance family for the compiled model. If
9596
specified, a compiled model is used (default: None).
9697
description (str): Model Package description (default: None).
98+
tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
99+
that tags will only be applied to newly created model package groups; if the
100+
name of an existing group is passed to "model_package_group_name",
101+
tags will not be applied.
97102
**kwargs: additional arguments to `create_model`.
98103
"""
99104
steps: List[Step] = []
@@ -134,6 +139,7 @@ def __init__(
134139
image_uri=image_uri,
135140
compile_model_family=compile_model_family,
136141
description=description,
142+
tags=tags,
137143
**kwargs,
138144
)
139145
if not repack_model:

tests/unit/sagemaker/workflow/test_step_collections.py

+2
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics):
182182
approval_status="Approved",
183183
description="description",
184184
depends_on=["TestStep"],
185+
tags=[{"Key": "myKey", "Value": "myValue"}],
185186
)
186187
assert ordered(register_model.request_dicts()) == ordered(
187188
[
@@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics):
210211
},
211212
"ModelPackageDescription": "description",
212213
"ModelPackageGroupName": "mpg",
214+
"Tags": [{"Key": "myKey", "Value": "myValue"}],
213215
},
214216
},
215217
]

0 commit comments

Comments
 (0)