File tree 5 files changed +18
-0
lines changed
tests/unit/sagemaker/workflow
5 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -195,6 +195,7 @@ def _get_model_package_args(
195
195
marketplace_cert = False ,
196
196
approval_status = None ,
197
197
description = None ,
198
+ tags = None ,
198
199
):
199
200
"""Get arguments for session.create_model_package method.
200
201
@@ -250,6 +251,8 @@ def _get_model_package_args(
250
251
model_package_args ["approval_status" ] = approval_status
251
252
if description is not None :
252
253
model_package_args ["description" ] = description
254
+ if tags is not None :
255
+ model_package_args ["tags" ] = tags
253
256
return model_package_args
254
257
255
258
def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
Original file line number Diff line number Diff line change @@ -2724,6 +2724,7 @@ def _get_create_model_package_request(
2724
2724
marketplace_cert = False ,
2725
2725
approval_status = "PendingManualApproval" ,
2726
2726
description = None ,
2727
+ tags = None ,
2727
2728
):
2728
2729
"""Get request dictionary for CreateModelPackage API.
2729
2730
@@ -2761,6 +2762,8 @@ def _get_create_model_package_request(
2761
2762
request_dict ["ModelPackageGroupName" ] = model_package_group_name
2762
2763
if description is not None :
2763
2764
request_dict ["ModelPackageDescription" ] = description
2765
+ if tags is not None :
2766
+ request_dict ["Tags" ] = tags
2764
2767
if model_metrics :
2765
2768
request_dict ["ModelMetrics" ] = model_metrics
2766
2769
if metadata_properties :
Original file line number Diff line number Diff line change @@ -225,6 +225,7 @@ def __init__(
225
225
compile_model_family = None ,
226
226
description = None ,
227
227
depends_on : List [str ] = None ,
228
+ tags = None ,
228
229
** kwargs ,
229
230
):
230
231
"""Constructor of a register model step.
@@ -264,6 +265,7 @@ def __init__(
264
265
self .inference_instances = inference_instances
265
266
self .transform_instances = transform_instances
266
267
self .model_package_group_name = model_package_group_name
268
+ self .tags = tags
267
269
self .model_metrics = model_metrics
268
270
self .metadata_properties = metadata_properties
269
271
self .approval_status = approval_status
@@ -324,10 +326,12 @@ def arguments(self) -> RequestType:
324
326
metadata_properties = self .metadata_properties ,
325
327
approval_status = self .approval_status ,
326
328
description = self .description ,
329
+ tags = self .tags ,
327
330
)
328
331
request_dict = model .sagemaker_session ._get_create_model_package_request (
329
332
** model_package_args
330
333
)
334
+
331
335
# these are not available in the workflow service and will cause rejection
332
336
if "CertifyForMarketplace" in request_dict :
333
337
request_dict .pop ("CertifyForMarketplace" )
Original file line number Diff line number Diff line change @@ -67,6 +67,7 @@ def __init__(
67
67
image_uri = None ,
68
68
compile_model_family = None ,
69
69
description = None ,
70
+ tags = None ,
70
71
** kwargs ,
71
72
):
72
73
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,10 @@ def __init__(
94
95
compile_model_family (str): The instance family for the compiled model. If
95
96
specified, a compiled model is used (default: None).
96
97
description (str): Model Package description (default: None).
98
+ tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
99
+ that tags will only be applied to newly created model package groups; if the
100
+ name of an existing group is passed to "model_package_group_name",
101
+ tags will not be applied.
97
102
**kwargs: additional arguments to `create_model`.
98
103
"""
99
104
steps : List [Step ] = []
@@ -134,6 +139,7 @@ def __init__(
134
139
image_uri = image_uri ,
135
140
compile_model_family = compile_model_family ,
136
141
description = description ,
142
+ tags = tags ,
137
143
** kwargs ,
138
144
)
139
145
if not repack_model :
Original file line number Diff line number Diff line change @@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics):
182
182
approval_status = "Approved" ,
183
183
description = "description" ,
184
184
depends_on = ["TestStep" ],
185
+ tags = [{"Key" : "myKey" , "Value" : "myValue" }],
185
186
)
186
187
assert ordered (register_model .request_dicts ()) == ordered (
187
188
[
@@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics):
210
211
},
211
212
"ModelPackageDescription" : "description" ,
212
213
"ModelPackageGroupName" : "mpg" ,
214
+ "Tags" : [{"Key" : "myKey" , "Value" : "myValue" }],
213
215
},
214
216
},
215
217
]
You can’t perform that action at this time.
0 commit comments