From 118f4d9d3b657795e1f6eff41668c5e238d11500 Mon Sep 17 00:00:00 2001 From: selvask Date: Tue, 11 Jun 2024 15:27:34 +0530 Subject: [PATCH 1/3] feature: add 'ModelCard' property to RegisterModel step --- src/sagemaker/chainer/model.py | 8 + src/sagemaker/estimator.py | 4 + src/sagemaker/huggingface/model.py | 8 + src/sagemaker/jumpstart/factory/model.py | 3 + src/sagemaker/jumpstart/model.py | 8 + src/sagemaker/jumpstart/types.py | 4 + src/sagemaker/model.py | 8 + src/sagemaker/model_card/__init__.py | 1 + src/sagemaker/model_card/model_card.py | 26 ++ src/sagemaker/mxnet/model.py | 8 + src/sagemaker/pipeline.py | 8 + src/sagemaker/pytorch/model.py | 8 + src/sagemaker/session.py | 21 + src/sagemaker/sklearn/model.py | 8 + src/sagemaker/tensorflow/model.py | 8 + src/sagemaker/workflow/_utils.py | 5 + src/sagemaker/workflow/step_collections.py | 5 +- src/sagemaker/xgboost/model.py | 8 + .../test_model_create_and_registration.py | 430 ++++++++++++++++++ tests/unit/test_estimator.py | 15 +- tests/unit/test_session.py | 19 + 21 files changed, 611 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 9fce051454..59c8310587 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -28,6 +28,10 @@ from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.chainer import defaults from sagemaker.deserializers import NumpyDeserializer +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.utils import to_string @@ -175,6 +179,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model 
package for creating SageMaker models or listing on Marketplace. @@ -226,6 +231,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -266,6 +273,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 58a5fabc2f..b6af6cf5de 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1724,6 +1724,7 @@ def register( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_card=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1772,6 +1773,8 @@ def register( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. 
@@ -1817,6 +1820,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 662baecae6..8c1978c156 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -26,6 +26,10 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.session import Session @@ -362,6 +366,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -414,6 +419,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -462,6 +469,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 89b0578342..b4bfd8a348 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -33,6 +33,7 @@ JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, ) +from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -646,6 +647,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -678,6 +680,7 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) model_specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 994193de3e..205b3bb08d 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -43,6 +43,10 @@ ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -692,6 +696,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, 
skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -739,6 +744,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -773,6 +780,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index bea125d423..5754704632 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,6 +15,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties @@ -2114,6 +2115,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "model_card", ] SERIALIZATION_EXCLUSION_SET = { @@ -2155,6 +2157,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, ) -> None: """Instantiates 
JumpStartModelRegisterKwargs object.""" @@ -2187,3 +2190,4 @@ def __init__( self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.model_card = model_card diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 1bb6cb2e5c..5c5156c84a 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -44,6 +44,10 @@ ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, load_sagemaker_config, ) +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics @@ -428,6 +432,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -479,6 +484,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). 
class ModelPackageModelCard(object):
    """Model card payload attached to a model package at registration time.

    Holds the qualitative and quantitative information about a model (the
    card content) together with the card's status, and renders them into
    the request fragment consumed when building a model package request.

    Args:
        model_card_content (dict): Content of the model card. It is
            serialized to a JSON string when the request arguments are built.
        model_card_status (str): Status of the model card.
    """

    def __init__(
        self,
        # Plain ``dict`` on purpose: ``dict[str]`` is not a valid mapping
        # annotation (a mapping takes key *and* value types) and subscripting
        # the builtin raises ``TypeError`` at definition time on Python 3.8.
        model_card_content: dict,
        model_card_status: str,
    ):
        self.model_card_content = model_card_content
        self.model_card_status = model_card_status

    def _create_request_args(self):
        """Generate the model card portion of a request body.

        Takes no arguments; it reads ``model_card_status`` and
        ``model_card_content`` from the instance.

        Returns:
            dict: A dictionary with two keys: ``"ModelCardStatus"`` (the
            status as given) and ``"Content"`` (the card content encoded
            as a JSON string using the module's ``_JSONEncoder``).
        """
        request_args = {}
        request_args["ModelCardStatus"] = self.model_card_status
        # The content is sent as a JSON *string*, not as a nested object.
        request_args["Content"] = json.dumps(self.model_card_content, cls=_JSONEncoder)
        return request_args
@@ -268,6 +275,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 3bfdb1a594..b5a3cd4357 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -26,6 +26,10 @@ ) from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.session import Session from sagemaker.utils import ( name_from_image, @@ -361,6 +365,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -412,6 +417,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). 
Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -460,6 +467,7 @@ def register( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index f490e49375..6d915772cd 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -29,6 +29,10 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer @@ -179,6 +183,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -230,6 +235,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -270,6 +277,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bf2a736871..5b1894f49e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4055,6 +4055,7 @@ def create_model_package_from_containers( task=None, skip_model_validation="None", source_uri=None, + model_card=None, ): """Get request dictionary for CreateModelPackage API. @@ -4092,6 +4093,8 @@ def create_model_package_from_containers( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4149,6 +4152,7 @@ def create_model_package_from_containers( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def submit(request): @@ -6730,6 +6734,7 @@ def get_model_package_args( task=None, skip_model_validation=None, source_uri=None, + model_card=None, ): """Get arguments for create_model_package method. @@ -6769,6 +6774,8 @@ def get_model_package_args( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: dict: A dictionary of method argument names and values. 
@@ -6825,6 +6832,14 @@ def get_model_package_args( model_package_args["skip_model_validation"] = skip_model_validation if source_uri is not None: model_package_args["source_uri"] = source_uri + if model_card is not None: + original_req = model_card._create_request_args() + if original_req.get("ModelCardName") is not None: + del original_req["ModelCardName"] + if original_req.get("Content") is not None: + original_req["ModelCardContent"] = original_req["Content"] + del original_req["Content"] + model_package_args["model_card"] = original_req return model_package_args @@ -6850,6 +6865,7 @@ def get_create_model_package_request( task=None, skip_model_validation="None", source_uri=None, + model_card=None, ): """Get request dictionary for CreateModelPackage API. @@ -6887,6 +6903,8 @@ def get_create_model_package_request( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). 
""" if all([model_package_name, model_package_group_name]): @@ -6984,6 +7002,9 @@ def get_create_model_package_request( request_dict["CertifyForMarketplace"] = marketplace_cert request_dict["ModelApprovalStatus"] = approval_status request_dict["SkipModelValidation"] = skip_model_validation + if model_card is not None: + request_dict["ModelCard"] = model_card + return request_dict diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 27833c1d9c..82d9510e53 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -23,6 +23,10 @@ from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.sklearn import defaults @@ -172,6 +176,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -223,6 +228,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -263,6 +270,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 77f162207c..4a22f1abcb 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -22,6 +22,10 @@ from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.workflow import is_pipeline_variable @@ -234,6 +238,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -285,6 +290,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -325,6 +332,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 841cd68083..e405d1034a 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -329,6 +329,7 @@ def __init__( task=None, skip_model_validation=None, source_uri=None, + model_card=None, **kwargs, ): """Constructor of a register model step. @@ -381,6 +382,8 @@ def __init__( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -418,6 +421,7 @@ def __init__( self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.model_card = model_card self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -493,6 +497,7 @@ def arguments(self) -> RequestType: task=self.task, skip_model_validation=self.skip_model_validation, source_uri=self.source_uri, + model_card=self.model_card, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 0eedf4aa96..c88c82efa9 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -97,6 +97,7 @@ def __init__( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_card=None, **kwargs, ): """Construct steps `_RepackModelStep` and 
`_RegisterModelStep` based on the estimator. @@ -155,7 +156,8 @@ def __init__( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). - + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). **kwargs: additional arguments to `create_model`. """ super().__init__(name=name, depends_on=depends_on) @@ -294,6 +296,7 @@ def __init__( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 8101f32721..157f3cb8fd 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -23,6 +23,10 @@ from sagemaker.fw_utils import model_code_key_prefix from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import LibSVMSerializer from sagemaker.utils import to_string @@ -160,6 +164,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -211,6 +216,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). 
+ model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -251,6 +258,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 0733649cb2..1b1cadc092 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -18,12 +18,15 @@ # and the RegisterModel and CreateModelStep have been replaced with the new interface - ModelStep from __future__ import absolute_import +import json import logging import os import re import pytest +from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum import tests from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.tensorflow import TensorFlow, TensorFlowModel @@ -56,6 +59,15 @@ ) from tests.integ.kms_utils import get_or_create_kms_key from tests.integ import DATA_DIR +from sagemaker.model_card import ( + IntendedUses, + BusinessDetails, + EvaluationJob, + AdditionalInformation, + Metric, + MetricGroup, + MetricTypeEnum, +) @pytest.fixture @@ -703,6 +715,424 @@ def test_model_registration_with_drift_check_baselines( pass +def test_model_registration_with_model_card_object( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = 
"s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + 
source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + intended_uses = IntendedUses( + purpose_of_model="Test model card.", + intended_uses="Not used except this test.", + factors_affecting_model_efficiency="No.", + risk_rating="Low", + explanations_for_risk_rating="Just an example.", + ) + business_details = BusinessDetails( + business_problem="The business problem that your model is used to solve.", + business_stakeholders="The stakeholders who have the interest in the business that your model is used for.", + line_of_business="Services that the business is offering.", + ) + additional_information = AdditionalInformation( + ethical_considerations="Your model ethical consideration.", + caveats_and_recommendations="Your model's caveats and recommendations.", + custom_details={"custom details1": "details value"}, + ) + manual_metric_group = MetricGroup( + name="binary classification metrics", + metric_data=[Metric(name="accuracy", type=MetricTypeEnum.NUMBER, value=0.5)], + ) + example_evaluation_job = EvaluationJob( + name="Example evaluation job", + evaluation_observation="Evaluation observations.", + datasets=["s3://path/to/evaluation/data"], + metric_groups=[manual_metric_group], + ) + evaluation_details = [example_evaluation_job] + + model_overview = ModelOverview(model_creator="TestCreator") + + my_card = ModelCard( + name="TestName", + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + intended_uses=intended_uses, + business_details=business_details, + evaluation_details=evaluation_details, + additional_information=additional_information, + ) + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", 
"ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_card=my_card, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." 
+ ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT + model_card_content = json.loads(response["ModelCard"]["ModelCardContent"]) + assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator" + assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card." + assert ( + model_card_content["business_details"]["line_of_business"] + ) == "Services that the business is offering." 
+ assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job" + + break + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_model_registration_with_model_card_json( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + 
domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + + model_card_content = { + "model_overview": { + "model_creator": "TestCreator", + }, + "intended_uses": { + "purpose_of_model": "Test model card.", + "intended_uses": "Not used except this test.", + "factors_affecting_model_efficiency": "No.", + "risk_rating": "Low", + "explanations_for_risk_rating": "Just an example.", + }, + "business_details": { + "business_problem": "The business problem that your model is used to solve.", + "business_stakeholders": "The stakeholders who have the interest in the business.", + "line_of_business": "Services that the business is offering.", + }, + "evaluation_details": [ + { + "name": "Example evaluation job", + "evaluation_observation": "Evaluation observations.", + "metric_groups": [ + { + "name": "binary classification metrics", + "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}], + } + ], + } + ], + "additional_information": { + "ethical_considerations": "Your model ethical consideration.", + "caveats_and_recommendations": 'Your model"s caveats and recommendations.', + "custom_details": {"custom details1": "details value"}, + }, + } + my_card = ModelPackageModelCard( + model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content + ) + + step_register = RegisterModel( 
+ name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_card=my_card, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." 
+ ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT + model_card_content = json.loads(response["ModelCard"]["ModelCardContent"]) + assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator" + assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card." + assert ( + model_card_content["business_details"]["line_of_business"] + ) == "Services that the business is offering." 
+ assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job" + + break + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_model_registration_with_model_repack( sagemaker_session_for_pipeline, role, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index fd45601801..295f1a8d24 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -51,6 +51,8 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.interactive_apps import SupportedInteractiveAppTypes from sagemaker.model import FrameworkModel +from sagemaker.model_card.model_card import ModelCard, ModelOverview +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor from sagemaker.pytorch.estimator import PyTorch @@ -4336,6 +4338,12 @@ def test_register_default_image(sagemaker_session): framework_version = "2.9" nearest_model_name = "resnet50" data_input_config = '{"input_1":[1,224,224,3]}' + model_overview = ModelOverview(model_creator="TestCreator") + model_card = ModelCard( + name="TestCard", + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + ) estimator.register( content_types=content_types, @@ -4349,9 +4357,13 @@ def test_register_default_image(sagemaker_session): framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_config, + model_card=model_card, ) sagemaker_session.create_model.assert_not_called() - + exp_model_card = { + "ModelCardStatus": "Draft", + "ModelCardContent": '{"model_overview": {"model_creator": "TestCreator", "model_artifact": []}}', + } expected_create_model_package_request = { "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], "content_types": content_types, @@ -4362,6 +4374,7 @@ def test_register_default_image(sagemaker_session): "marketplace_cert": False, 
"sample_payload_url": sample_payload_url, "task": task, + "model_card": exp_model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index f7dede1ce9..8ab186e27c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -24,6 +24,8 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, call, mock_open +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum + from .common import _raise_unexpected_client_error import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions @@ -5343,6 +5345,21 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + model_card = { + "ModelCardStatus": ModelCardStatusEnum.DRAFT, + "Content": { + "model_overview": { + "model_creator": "TestCreator", + }, + "intended_uses": { + "purpose_of_model": "Test model card.", + "intended_uses": "Not used except this test.", + "factors_affecting_model_efficiency": "No.", + "risk_rating": "Low", + "explanations_for_risk_rating": "Just an example.", + }, + }, + } sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5361,6 +5378,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + model_card=model_card, ) expected_args = { "ModelPackageName": model_package_name, @@ -5382,6 +5400,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "SamplePayloadUrl": sample_payload_url, "Task": task, "SkipModelValidation": skip_model_validation, + "ModelCard": model_card, } 
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From c05c7fb12d78fc13a310816746443887ef65a59f Mon Sep 17 00:00:00 2001 From: selvask Date: Thu, 13 Jun 2024 11:30:20 +0530 Subject: [PATCH 2/3] Updated ModelCard content type --- src/sagemaker/model_card/model_card.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/model_card/model_card.py b/src/sagemaker/model_card/model_card.py index e13905aabf..1765df8d68 100644 --- a/src/sagemaker/model_card/model_card.py +++ b/src/sagemaker/model_card/model_card.py @@ -16,7 +16,7 @@ import json import logging from datetime import datetime -from typing import Optional, Union, List, Any +from typing import Optional, Union, List, Any, Dict from botocore.exceptions import ClientError from boto3.session import Session as boto3_Session from six.moves.urllib.parse import urlparse @@ -1890,7 +1890,7 @@ class ModelPackageModelCard(object): def __init__( self, - model_card_content: dict[str], + model_card_content: Dict[str, Any], model_card_status: str, ): From c6981e2387369cc0523aa32ffc9e968901c8d2a0 Mon Sep 17 00:00:00 2001 From: selvask Date: Thu, 13 Jun 2024 21:54:11 +0530 Subject: [PATCH 3/3] fix: ModelCard Object integ Test fix --- .../sagemaker/workflow/test_model_create_and_registration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 1b1cadc092..0bdbc18c99 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -825,6 +825,7 @@ def test_model_registration_with_model_card_object( my_card = ModelCard( name="TestName", + sagemaker_session=sagemaker_session_for_pipeline, status=ModelCardStatusEnum.DRAFT, model_overview=model_overview, intended_uses=intended_uses,