Skip to content

Commit f799f15

Browse files
authored
Feature: Update model card on model package request (#4739)
* Feature: Update model card on model package request * Feature: Update model card on model package request * fix: update_model_card input types
1 parent 327b5d9 commit f799f15

File tree

4 files changed

+315
-4
lines changed

4 files changed

+315
-4
lines changed

src/sagemaker/model.py

+42
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
ModelCard,
4949
ModelPackageModelCard,
5050
)
51+
from sagemaker.model_card.helpers import _hash_content_str
5152
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
5253
from sagemaker.session import Session
5354
from sagemaker.model_metrics import ModelMetrics
@@ -2426,3 +2427,44 @@ def add_inference_specification(
24262427
)
24272428

24282429
sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)
2430+
2431+
def update_model_card(self, model_card: Union[ModelCard, ModelPackageModelCard]):
2432+
"""Updates Created model card content which created with model package
2433+
2434+
Args:
2435+
model_card (ModelCard | ModelPackageModelCard): Updated Model Card content
2436+
"""
2437+
2438+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2439+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
2440+
ModelPackageName=self.model_package_arn
2441+
)
2442+
update_model_card_req = model_card._create_request_args()
2443+
if update_model_card_req["ModelCardStatus"] is not None:
2444+
if (
2445+
desc_model_package["ModelCard"]["ModelCardStatus"]
2446+
== update_model_card_req["ModelCardStatus"]
2447+
):
2448+
del update_model_card_req["ModelCardStatus"]
2449+
2450+
if update_model_card_req.get("ModelCardName") is not None:
2451+
del update_model_card_req["ModelCardName"]
2452+
if update_model_card_req.get("Content") is not None:
2453+
previous_content_hash = _hash_content_str(
2454+
desc_model_package["ModelCard"]["ModelCardContent"]
2455+
)
2456+
current_content_hash = _hash_content_str(update_model_card_req["Content"])
2457+
if (
2458+
previous_content_hash == current_content_hash
2459+
or update_model_card_req.get("Content") == "{}"
2460+
or update_model_card_req.get("Content") == "null"
2461+
):
2462+
del update_model_card_req["Content"]
2463+
else:
2464+
update_model_card_req["ModelCardContent"] = update_model_card_req["Content"]
2465+
del update_model_card_req["Content"]
2466+
update_model_package_args = {
2467+
"ModelPackageArn": self.model_package_arn,
2468+
"ModelCard": update_model_card_req,
2469+
}
2470+
sagemaker_session.sagemaker_client.update_model_package(**update_model_package_args)

src/sagemaker/model_card/model_card.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1890,8 +1890,8 @@ class ModelPackageModelCard(object):
18901890

18911891
def __init__(
18921892
self,
1893-
model_card_content: Dict[str, Any],
1894-
model_card_status: str,
1893+
model_card_content: Optional[Dict[str, Any]] = None,
1894+
model_card_status: Optional[str] = None,
18951895
):
18961896

18971897
self.model_card_content = model_card_content

tests/integ/test_model_package.py

+220-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
1516
import os
16-
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
17+
from sagemaker.model_card.model_card import (
18+
AdditionalInformation,
19+
BusinessDetails,
20+
IntendedUses,
21+
ModelCard,
22+
ModelOverview,
23+
ModelPackageModelCard,
24+
)
25+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum
1726
from sagemaker.utils import unique_name_from_base
1827
from tests.integ import DATA_DIR
1928
from sagemaker.xgboost import XGBoostModel
@@ -183,6 +192,216 @@ def test_update_source_uri(sagemaker_session):
183192
assert desc_model_package["SourceUri"] == source_uri
184193

185194

195+
def test_update_model_card_with_model_card_object(sagemaker_session):
196+
model_group_name = unique_name_from_base("test-model-group")
197+
intended_uses = IntendedUses(
198+
purpose_of_model="Test model card.",
199+
intended_uses="Not used except this test.",
200+
factors_affecting_model_efficiency="No.",
201+
risk_rating="Low",
202+
explanations_for_risk_rating="Just an example.",
203+
)
204+
business_details = BusinessDetails(
205+
business_problem="The business problem that your model is used to solve.",
206+
business_stakeholders="The stakeholders who have the interest in the business that your model is used for.",
207+
line_of_business="Services that the business is offering.",
208+
)
209+
additional_information = AdditionalInformation(
210+
ethical_considerations="Your model ethical consideration.",
211+
caveats_and_recommendations="Your model's caveats and recommendations.",
212+
custom_details={"custom details1": "details value"},
213+
)
214+
215+
model_overview = ModelOverview(model_creator="TestCreator")
216+
217+
my_card = ModelCard(
218+
name="TestName",
219+
sagemaker_session=sagemaker_session,
220+
status=ModelCardStatusEnum.DRAFT,
221+
model_overview=model_overview,
222+
intended_uses=intended_uses,
223+
business_details=business_details,
224+
additional_information=additional_information,
225+
)
226+
227+
sagemaker_session.sagemaker_client.create_model_package_group(
228+
ModelPackageGroupName=model_group_name
229+
)
230+
231+
xgb_model_data_s3 = sagemaker_session.upload_data(
232+
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
233+
key_prefix="integ-test-data/xgboost/model",
234+
)
235+
model = XGBoostModel(
236+
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
237+
)
238+
239+
model_package = model.register(
240+
content_types=["text/csv"],
241+
response_types=["text/csv"],
242+
inference_instances=["ml.m5.large"],
243+
transform_instances=["ml.m5.large"],
244+
model_package_group_name=model_group_name,
245+
model_card=my_card,
246+
)
247+
248+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
249+
ModelPackageName=model_package.model_package_arn
250+
)
251+
252+
updated_model_overview = ModelOverview(model_creator="updatedCreator")
253+
updated_intended_uses = IntendedUses(
254+
purpose_of_model="Updated Test model card.",
255+
)
256+
updated_my_card = ModelCard(
257+
name="TestName",
258+
sagemaker_session=sagemaker_session,
259+
model_overview=updated_model_overview,
260+
intended_uses=updated_intended_uses,
261+
)
262+
model_package.update_model_card(updated_my_card)
263+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
264+
ModelPackageName=model_package.model_package_arn
265+
)
266+
267+
model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
268+
assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
269+
assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
270+
updated_my_card_status = ModelCard(
271+
name="TestName",
272+
sagemaker_session=sagemaker_session,
273+
status=ModelCardStatusEnum.PENDING_REVIEW,
274+
)
275+
model_package.update_model_card(updated_my_card_status)
276+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
277+
ModelPackageName=model_package.model_package_arn
278+
)
279+
280+
model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
281+
assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
282+
283+
284+
def test_update_model_card_with_model_card_json(sagemaker_session):
285+
model_group_name = unique_name_from_base("test-model-group")
286+
model_card_content = {
287+
"model_overview": {
288+
"model_creator": "TestCreator",
289+
},
290+
"intended_uses": {
291+
"purpose_of_model": "Test model card.",
292+
"intended_uses": "Not used except this test.",
293+
"factors_affecting_model_efficiency": "No.",
294+
"risk_rating": "Low",
295+
"explanations_for_risk_rating": "Just an example.",
296+
},
297+
"business_details": {
298+
"business_problem": "The business problem that your model is used to solve.",
299+
"business_stakeholders": "The stakeholders who have the interest in the business.",
300+
"line_of_business": "Services that the business is offering.",
301+
},
302+
"evaluation_details": [
303+
{
304+
"name": "Example evaluation job",
305+
"evaluation_observation": "Evaluation observations.",
306+
"metric_groups": [
307+
{
308+
"name": "binary classification metrics",
309+
"metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
310+
}
311+
],
312+
}
313+
],
314+
"additional_information": {
315+
"ethical_considerations": "Your model ethical consideration.",
316+
"caveats_and_recommendations": 'Your model"s caveats and recommendations.',
317+
"custom_details": {"custom details1": "details value"},
318+
},
319+
}
320+
my_card = ModelPackageModelCard(
321+
model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content
322+
)
323+
324+
sagemaker_session.sagemaker_client.create_model_package_group(
325+
ModelPackageGroupName=model_group_name
326+
)
327+
328+
xgb_model_data_s3 = sagemaker_session.upload_data(
329+
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
330+
key_prefix="integ-test-data/xgboost/model",
331+
)
332+
model = XGBoostModel(
333+
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
334+
)
335+
336+
model_package = model.register(
337+
content_types=["text/csv"],
338+
response_types=["text/csv"],
339+
inference_instances=["ml.m5.large"],
340+
transform_instances=["ml.m5.large"],
341+
model_package_group_name=model_group_name,
342+
model_card=my_card,
343+
)
344+
345+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
346+
ModelPackageName=model_package.model_package_arn
347+
)
348+
349+
updated_model_card_content = {
350+
"model_overview": {
351+
"model_creator": "updatedCreator",
352+
},
353+
"intended_uses": {
354+
"purpose_of_model": "Updated Test model card.",
355+
"intended_uses": "Not used except this test.",
356+
"factors_affecting_model_efficiency": "No.",
357+
"risk_rating": "Low",
358+
"explanations_for_risk_rating": "Just an example.",
359+
},
360+
"business_details": {
361+
"business_problem": "The business problem that your model is used to solve.",
362+
"business_stakeholders": "The stakeholders who have the interest in the business.",
363+
"line_of_business": "Services that the business is offering.",
364+
},
365+
"evaluation_details": [
366+
{
367+
"name": "Example evaluation job",
368+
"evaluation_observation": "Evaluation observations.",
369+
"metric_groups": [
370+
{
371+
"name": "binary classification metrics",
372+
"metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
373+
}
374+
],
375+
}
376+
],
377+
"additional_information": {
378+
"ethical_considerations": "Your model ethical consideration.",
379+
"caveats_and_recommendations": 'Your model"s caveats and recommendations.',
380+
"custom_details": {"custom details1": "details value"},
381+
},
382+
}
383+
updated_my_card = ModelPackageModelCard(
384+
model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=updated_model_card_content
385+
)
386+
model_package.update_model_card(updated_my_card)
387+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
388+
ModelPackageName=model_package.model_package_arn
389+
)
390+
391+
model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
392+
assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
393+
assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
394+
updated_my_card_status = ModelPackageModelCard(
395+
model_card_status=ModelCardStatusEnum.PENDING_REVIEW,
396+
)
397+
model_package.update_model_card(updated_my_card_status)
398+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
399+
ModelPackageName=model_package.model_package_arn
400+
)
401+
402+
assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
403+
404+
186405
def test_clone_model_package_using_source_uri(sagemaker_session):
187406
model_group_name = unique_name_from_base("test-model-group")
188407

tests/unit/sagemaker/model/test_model_package.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
import sagemaker
2121
from sagemaker.model import ModelPackage
22-
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
22+
from sagemaker.model_card.model_card import ModelCard, ModelOverview
23+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum
2324

2425
MODEL_PACKAGE_VERSIONED_ARN = (
2526
"arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1"
@@ -56,6 +57,10 @@
5657
"ModelPackageStatus": "Completed",
5758
"ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502",
5859
"CertifyForMarketplace": False,
60+
"ModelCard": {
61+
"ModelCardStatus": "Draft",
62+
"ModelCardContent": '{"model_overview": {"model_creator": "updatedCreator", "model_artifact": []}}',
63+
},
5964
}
6065

6166
MODEL_DATA = {
@@ -442,3 +447,48 @@ def test_update_source_uri(sagemaker_session):
442447
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
443448
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri
444449
)
450+
451+
452+
def test_update_model_card(sagemaker_session):
453+
model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE)
454+
455+
sagemaker_session.sagemaker_client.describe_model_package = Mock(
456+
return_value=model_package_response
457+
)
458+
model_package = ModelPackage(
459+
role="role",
460+
model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
461+
sagemaker_session=sagemaker_session,
462+
)
463+
464+
update_my_card = ModelCard(
465+
name="UpdateTestName",
466+
sagemaker_session=sagemaker_session,
467+
status=ModelCardStatusEnum.PENDING_REVIEW,
468+
)
469+
model_package.update_model_card(update_my_card)
470+
update_my_card_req = update_my_card._create_request_args()
471+
del update_my_card_req["ModelCardName"]
472+
del update_my_card_req["Content"]
473+
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
474+
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req
475+
)
476+
477+
model_overview = ModelOverview(
478+
model_creator="UpdatedNewCreator",
479+
)
480+
update_my_card_1 = ModelCard(
481+
name="UpdateTestName",
482+
sagemaker_session=sagemaker_session,
483+
status=ModelCardStatusEnum.DRAFT,
484+
model_overview=model_overview,
485+
)
486+
model_package.update_model_card(update_my_card_1)
487+
update_my_card_req_1 = update_my_card_1._create_request_args()
488+
del update_my_card_req_1["ModelCardName"]
489+
del update_my_card_req_1["ModelCardStatus"]
490+
update_my_card_req_1["ModelCardContent"] = update_my_card_req_1["Content"]
491+
del update_my_card_req_1["Content"]
492+
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
493+
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req_1
494+
)

0 commit comments

Comments
 (0)