Skip to content

Feature: Update model card on model package request #4739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ModelCard,
ModelPackageModelCard,
)
from sagemaker.model_card.helpers import _hash_content_str
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.session import Session
from sagemaker.model_metrics import ModelMetrics
Expand Down Expand Up @@ -2426,3 +2427,44 @@ def add_inference_specification(
)

sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)

def update_model_card(self, model_card: Union[ModelCard, ModelPackageModelCard]):
"""Updates Created model card content which created with model package

Args:
model_card (ModelCard | ModelPackageModelCard): Updated Model Card content
"""

sagemaker_session = self.sagemaker_session or sagemaker.Session()
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=self.model_package_arn
)
update_model_card_req = model_card._create_request_args()
if update_model_card_req["ModelCardStatus"] is not None:
if (
desc_model_package["ModelCard"]["ModelCardStatus"]
== update_model_card_req["ModelCardStatus"]
):
del update_model_card_req["ModelCardStatus"]

if update_model_card_req.get("ModelCardName") is not None:
del update_model_card_req["ModelCardName"]
if update_model_card_req.get("Content") is not None:
previous_content_hash = _hash_content_str(
desc_model_package["ModelCard"]["ModelCardContent"]
)
current_content_hash = _hash_content_str(update_model_card_req["Content"])
if (
previous_content_hash == current_content_hash
or update_model_card_req.get("Content") == "{}"
or update_model_card_req.get("Content") == "null"
):
del update_model_card_req["Content"]
else:
update_model_card_req["ModelCardContent"] = update_model_card_req["Content"]
del update_model_card_req["Content"]
update_model_package_args = {
"ModelPackageArn": self.model_package_arn,
"ModelCard": update_model_card_req,
}
sagemaker_session.sagemaker_client.update_model_package(**update_model_package_args)
4 changes: 2 additions & 2 deletions src/sagemaker/model_card/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,8 +1890,8 @@ class ModelPackageModelCard(object):

def __init__(
self,
model_card_content: Dict[str, Any],
model_card_status: str,
model_card_content: Optional[Dict[str, Any]] = None,
model_card_status: Optional[str] = None,
):

self.model_card_content = model_card_content
Expand Down
221 changes: 220 additions & 1 deletion tests/integ/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.model_card.model_card import (
AdditionalInformation,
BusinessDetails,
IntendedUses,
ModelCard,
ModelOverview,
ModelPackageModelCard,
)
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR
from sagemaker.xgboost import XGBoostModel
Expand Down Expand Up @@ -183,6 +192,216 @@ def test_update_source_uri(sagemaker_session):
assert desc_model_package["SourceUri"] == source_uri


def test_update_model_card_with_model_card_object(sagemaker_session):
model_group_name = unique_name_from_base("test-model-group")
intended_uses = IntendedUses(
purpose_of_model="Test model card.",
intended_uses="Not used except this test.",
factors_affecting_model_efficiency="No.",
risk_rating="Low",
explanations_for_risk_rating="Just an example.",
)
business_details = BusinessDetails(
business_problem="The business problem that your model is used to solve.",
business_stakeholders="The stakeholders who have the interest in the business that your model is used for.",
line_of_business="Services that the business is offering.",
)
additional_information = AdditionalInformation(
ethical_considerations="Your model ethical consideration.",
caveats_and_recommendations="Your model's caveats and recommendations.",
custom_details={"custom details1": "details value"},
)

model_overview = ModelOverview(model_creator="TestCreator")

my_card = ModelCard(
name="TestName",
sagemaker_session=sagemaker_session,
status=ModelCardStatusEnum.DRAFT,
model_overview=model_overview,
intended_uses=intended_uses,
business_details=business_details,
additional_information=additional_information,
)

sagemaker_session.sagemaker_client.create_model_package_group(
ModelPackageGroupName=model_group_name
)

xgb_model_data_s3 = sagemaker_session.upload_data(
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
key_prefix="integ-test-data/xgboost/model",
)
model = XGBoostModel(
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
)

model_package = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
inference_instances=["ml.m5.large"],
transform_instances=["ml.m5.large"],
model_package_group_name=model_group_name,
model_card=my_card,
)

desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

updated_model_overview = ModelOverview(model_creator="updatedCreator")
updated_intended_uses = IntendedUses(
purpose_of_model="Updated Test model card.",
)
updated_my_card = ModelCard(
name="TestName",
sagemaker_session=sagemaker_session,
model_overview=updated_model_overview,
intended_uses=updated_intended_uses,
)
model_package.update_model_card(updated_my_card)
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
updated_my_card_status = ModelCard(
name="TestName",
sagemaker_session=sagemaker_session,
status=ModelCardStatusEnum.PENDING_REVIEW,
)
model_package.update_model_card(updated_my_card_status)
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW


def test_update_model_card_with_model_card_json(sagemaker_session):
model_group_name = unique_name_from_base("test-model-group")
model_card_content = {
"model_overview": {
"model_creator": "TestCreator",
},
"intended_uses": {
"purpose_of_model": "Test model card.",
"intended_uses": "Not used except this test.",
"factors_affecting_model_efficiency": "No.",
"risk_rating": "Low",
"explanations_for_risk_rating": "Just an example.",
},
"business_details": {
"business_problem": "The business problem that your model is used to solve.",
"business_stakeholders": "The stakeholders who have the interest in the business.",
"line_of_business": "Services that the business is offering.",
},
"evaluation_details": [
{
"name": "Example evaluation job",
"evaluation_observation": "Evaluation observations.",
"metric_groups": [
{
"name": "binary classification metrics",
"metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
}
],
}
],
"additional_information": {
"ethical_considerations": "Your model ethical consideration.",
"caveats_and_recommendations": 'Your model"s caveats and recommendations.',
"custom_details": {"custom details1": "details value"},
},
}
my_card = ModelPackageModelCard(
model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content
)

sagemaker_session.sagemaker_client.create_model_package_group(
ModelPackageGroupName=model_group_name
)

xgb_model_data_s3 = sagemaker_session.upload_data(
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
key_prefix="integ-test-data/xgboost/model",
)
model = XGBoostModel(
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
)

model_package = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
inference_instances=["ml.m5.large"],
transform_instances=["ml.m5.large"],
model_package_group_name=model_group_name,
model_card=my_card,
)

desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

updated_model_card_content = {
"model_overview": {
"model_creator": "updatedCreator",
},
"intended_uses": {
"purpose_of_model": "Updated Test model card.",
"intended_uses": "Not used except this test.",
"factors_affecting_model_efficiency": "No.",
"risk_rating": "Low",
"explanations_for_risk_rating": "Just an example.",
},
"business_details": {
"business_problem": "The business problem that your model is used to solve.",
"business_stakeholders": "The stakeholders who have the interest in the business.",
"line_of_business": "Services that the business is offering.",
},
"evaluation_details": [
{
"name": "Example evaluation job",
"evaluation_observation": "Evaluation observations.",
"metric_groups": [
{
"name": "binary classification metrics",
"metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
}
],
}
],
"additional_information": {
"ethical_considerations": "Your model ethical consideration.",
"caveats_and_recommendations": 'Your model"s caveats and recommendations.',
"custom_details": {"custom details1": "details value"},
},
}
updated_my_card = ModelPackageModelCard(
model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=updated_model_card_content
)
model_package.update_model_card(updated_my_card)
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
updated_my_card_status = ModelPackageModelCard(
model_card_status=ModelCardStatusEnum.PENDING_REVIEW,
)
model_package.update_model_card(updated_my_card_status)
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package.model_package_arn
)

assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW


def test_clone_model_package_using_source_uri(sagemaker_session):
model_group_name = unique_name_from_base("test-model-group")

Expand Down
52 changes: 51 additions & 1 deletion tests/unit/sagemaker/model/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import sagemaker
from sagemaker.model import ModelPackage
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.model_card.model_card import ModelCard, ModelOverview
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum

MODEL_PACKAGE_VERSIONED_ARN = (
"arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1"
Expand Down Expand Up @@ -56,6 +57,10 @@
"ModelPackageStatus": "Completed",
"ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502",
"CertifyForMarketplace": False,
"ModelCard": {
"ModelCardStatus": "Draft",
"ModelCardContent": '{"model_overview": {"model_creator": "updatedCreator", "model_artifact": []}}',
},
}

MODEL_DATA = {
Expand Down Expand Up @@ -442,3 +447,48 @@ def test_update_source_uri(sagemaker_session):
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri
)


def test_update_model_card(sagemaker_session):
model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE)

sagemaker_session.sagemaker_client.describe_model_package = Mock(
return_value=model_package_response
)
model_package = ModelPackage(
role="role",
model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
sagemaker_session=sagemaker_session,
)

update_my_card = ModelCard(
name="UpdateTestName",
sagemaker_session=sagemaker_session,
status=ModelCardStatusEnum.PENDING_REVIEW,
)
model_package.update_model_card(update_my_card)
update_my_card_req = update_my_card._create_request_args()
del update_my_card_req["ModelCardName"]
del update_my_card_req["Content"]
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req
)

model_overview = ModelOverview(
model_creator="UpdatedNewCreator",
)
update_my_card_1 = ModelCard(
name="UpdateTestName",
sagemaker_session=sagemaker_session,
status=ModelCardStatusEnum.DRAFT,
model_overview=model_overview,
)
model_package.update_model_card(update_my_card_1)
update_my_card_req_1 = update_my_card_1._create_request_args()
del update_my_card_req_1["ModelCardName"]
del update_my_card_req_1["ModelCardStatus"]
update_my_card_req_1["ModelCardContent"] = update_my_card_req_1["Content"]
del update_my_card_req_1["Content"]
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req_1
)