Skip to content

Commit 3dd4373

Browse files
author
Dan
authored
fix: add description parameter for RegisterModelStep (#2190)
1 parent cc19325 commit 3dd4373

File tree

4 files changed

+10
-0
lines changed

4 files changed

+10
-0
lines changed

src/sagemaker/workflow/_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __init__(
220220
approval_status="PendingManualApproval",
221221
image_uri=None,
222222
compile_model_family=None,
223+
description=None,
223224
**kwargs,
224225
):
225226
"""Constructor of a register model step.
@@ -246,6 +247,7 @@ def __init__(
246247
Estimator's training container image will be used (default: None).
247248
compile_model_family (str): Instance family for compiled model, if specified, a compiled
248249
model will be used (default: None).
250+
description (str): Model Package description (default: None).
249251
**kwargs: additional arguments to `create_model`.
250252
"""
251253
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL)
@@ -261,6 +263,7 @@ def __init__(
261263
self.approval_status = approval_status
262264
self.image_uri = image_uri
263265
self.compile_model_family = compile_model_family
266+
self.description = description
264267
self.kwargs = kwargs
265268

266269
self._properties = Properties(
@@ -314,6 +317,7 @@ def arguments(self) -> RequestType:
314317
model_metrics=self.model_metrics,
315318
metadata_properties=self.metadata_properties,
316319
approval_status=self.approval_status,
320+
description=self.description,
317321
)
318322
request_dict = model.sagemaker_session._get_create_model_package_request(
319323
**model_package_args

src/sagemaker/workflow/step_collections.py

+3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
approval_status=None,
6666
image_uri=None,
6767
compile_model_family=None,
68+
description=None,
6869
**kwargs,
6970
):
7071
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -89,6 +90,7 @@ def __init__(
8990
Estimator's training container image is used (default: None).
9091
compile_model_family (str): The instance family for the compiled model. If
9192
specified, a compiled model is used (default: None).
93+
description (str): Model Package description (default: None).
9294
**kwargs: additional arguments to `create_model`.
9395
"""
9496
steps: List[Step] = []
@@ -120,6 +122,7 @@ def __init__(
120122
approval_status=approval_status,
121123
image_uri=image_uri,
122124
compile_model_family=compile_model_family,
125+
description=description,
123126
**kwargs,
124127
)
125128
steps.append(register_model_step)

tests/integ/test_workflow.py

+1
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def test_conditional_pytorch_training_model_registration(
393393
response_types=["*"],
394394
inference_instances=["*"],
395395
transform_instances=["*"],
396+
description="test-description",
396397
)
397398

398399
model = Model(

tests/unit/sagemaker/workflow/test_step_collections.py

+2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def test_register_model(estimator, model_metrics):
143143
model_package_group_name="mpg",
144144
model_metrics=model_metrics,
145145
approval_status="Approved",
146+
description="description",
146147
)
147148
assert ordered(register_model.request_dicts()) == ordered(
148149
[
@@ -168,6 +169,7 @@ def test_register_model(estimator, model_metrics):
168169
},
169170
},
170171
},
172+
"ModelPackageDescription": "description",
171173
"ModelPackageGroupName": "mpg",
172174
},
173175
},

0 commit comments

Comments
 (0)