Skip to content

Commit 0dbf11c

Browse files
metrizableDan Choi
authored and
Dan Choi
committed
fix: ensure model metrics is available to RegisterModel (aws#493)
1 parent b0d7030 commit 0dbf11c

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,9 @@ def __init__(
215215
inference_instances,
216216
transform_instances,
217217
model_package_group_name=None,
218-
image_uri=None,
219218
model_metrics=None,
220219
approval_status="PendingManualApproval",
220+
image_uri=None,
221221
compile_model_family=None,
222222
**kwargs,
223223
):
@@ -237,11 +237,11 @@ def __init__(
237237
model_package_group_name (str): Model Package Group name, exclusive to
238238
`model_package_name`, using `model_package_group_name` makes the Model Package
239239
versioned (default: None).
240-
image_uri (str): The container image uri for Model Package, if not specified,
241-
Estimator's training container image will be used (default: None).
242240
model_metrics (ModelMetrics): ModelMetrics object (default: None).
243241
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
244242
or "PendingManualApproval" (default: "PendingManualApproval").
243+
image_uri (str): The container image uri for Model Package, if not specified,
244+
Estimator's training container image will be used (default: None).
245245
compile_model_family (str): Instance family for compiled model, if specified, a compiled
246246
model will be used (default: None).
247247
**kwargs: additional arguments to `create_model`.
@@ -254,9 +254,9 @@ def __init__(
254254
self.inference_instances = inference_instances
255255
self.transform_instances = transform_instances
256256
self.model_package_group_name = model_package_group_name
257-
self.image_uri = image_uri
258257
self.model_metrics = model_metrics
259258
self.approval_status = approval_status
259+
self.image_uri = image_uri
260260
self.compile_model_family = compile_model_family
261261
self.kwargs = kwargs
262262

@@ -314,7 +314,7 @@ def arguments(self) -> RequestType:
314314
request_dict = model.sagemaker_session._get_create_model_package_request(
315315
**model_package_args
316316
)
317-
# these are not available in the workflow service
317+
# these are not available in the workflow service and will cause rejection
318318
if "CertifyForMarketplace" in request_dict:
319319
request_dict.pop("CertifyForMarketplace")
320320
if "Description" in request_dict:

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __init__(
6565
inference_instances,
6666
transform_instances,
6767
model_package_group_name=None,
68+
model_metrics=None,
69+
approval_status=None,
6870
image_uri=None,
6971
compile_model_family=None,
7072
**kwargs,
@@ -84,6 +86,9 @@ def __init__(
8486
model_package_group_name (str): The Model Package Group name, exclusive to
8587
`model_package_name`, using `model_package_group_name` makes the Model Package
8688
versioned (default: None).
89+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
90+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
91+
or "PendingManualApproval" (default: "PendingManualApproval").
8792
image_uri (str): The container image uri for Model Package, if not specified,
8893
Estimator's training container image is used (default: None).
8994
compile_model_family (str): The instance family for the compiled model. If
@@ -115,6 +120,8 @@ def __init__(
115120
inference_instances=inference_instances,
116121
transform_instances=transform_instances,
117122
model_package_group_name=model_package_group_name,
123+
model_metrics=model_metrics,
124+
approval_status=approval_status,
118125
image_uri=image_uri,
119126
compile_model_family=compile_model_family,
120127
**kwargs,

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424

2525
from sagemaker.estimator import Estimator
2626
from sagemaker.inputs import CreateModelInput, TransformInput
27+
from sagemaker.model_metrics import (
28+
MetricsSource,
29+
ModelMetrics,
30+
)
2731
from sagemaker.workflow.properties import Properties
2832
from sagemaker.workflow.steps import (
2933
Step,
@@ -108,6 +112,16 @@ def estimator(sagemaker_session):
108112
)
109113

110114

115+
@pytest.fixture
116+
def model_metrics():
117+
return ModelMetrics(
118+
model_statistics=MetricsSource(
119+
s3_uri=f"s3://{BUCKET}/metrics.csv",
120+
content_type="text/csv",
121+
)
122+
)
123+
124+
111125
def test_step_collection():
112126
step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")])
113127
assert step_collection.request_dicts() == [
@@ -116,7 +130,7 @@ def test_step_collection():
116130
]
117131

118132

119-
def test_register_model(estimator):
133+
def test_register_model(estimator, model_metrics):
120134
model_data = f"s3://{BUCKET}/model.tar.gz"
121135
register_model = RegisterModel(
122136
name="RegisterModelStep",
@@ -126,6 +140,9 @@ def test_register_model(estimator):
126140
response_types=["response_type"],
127141
inference_instances=["inference_instance"],
128142
transform_instances=["transform_instance"],
143+
model_package_group_name="mpg",
144+
model_metrics=model_metrics,
145+
approval_status="Approved",
129146
)
130147
assert ordered(register_model.request_dicts()) == ordered(
131148
[
@@ -135,14 +152,23 @@ def test_register_model(estimator):
135152
"Arguments": {
136153
"InferenceSpecification": {
137154
"Containers": [
138-
{"Image": "fakeimage", "ModelDataUrl": "s3://my-bucket/model.tar.gz"}
155+
{"Image": "fakeimage", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz"}
139156
],
140157
"SupportedContentTypes": ["content_type"],
141158
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
142159
"SupportedResponseMIMETypes": ["response_type"],
143160
"SupportedTransformInstanceTypes": ["transform_instance"],
144161
},
145-
"ModelApprovalStatus": "PendingManualApproval",
162+
"ModelApprovalStatus": "Approved",
163+
"ModelMetrics": {
164+
"ModelQuality": {
165+
"Statistics": {
166+
"ContentType": "text/csv",
167+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
168+
},
169+
},
170+
},
171+
"ModelPackageGroupName": "mpg",
146172
},
147173
},
148174
]

0 commit comments

Comments
 (0)