Skip to content

feature: add 'ModelCard' property to Register step #4726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import to_string
Expand Down Expand Up @@ -175,6 +179,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -226,6 +231,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
str: A string of SageMaker Model Package ARN.
Expand Down Expand Up @@ -266,6 +273,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,7 @@ def register(
data_input_configuration=None,
skip_model_validation=None,
source_uri=None,
model_card=None,
**kwargs,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -1772,6 +1773,8 @@ def register(
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1817,6 +1820,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

@property
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
)
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session
Expand Down Expand Up @@ -362,6 +366,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -414,6 +419,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -462,6 +469,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

def prepare_container_def(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
)
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines
Expand Down Expand Up @@ -646,6 +647,7 @@ def get_register_kwargs(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -678,6 +680,7 @@ def get_register_kwargs(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

model_specs = verify_model_region_and_return_specs(
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
)
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.model import (
Model,
Expand Down Expand Up @@ -692,6 +696,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -739,6 +744,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -773,6 +780,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
Expand Down Expand Up @@ -2114,6 +2115,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"data_input_configuration",
"skip_model_validation",
"source_uri",
"model_card",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2155,6 +2157,7 @@ def __init__(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
) -> None:
"""Instantiates JumpStartModelRegisterKwargs object."""

Expand Down Expand Up @@ -2187,3 +2190,4 @@ def __init__(
self.data_input_configuration = data_input_configuration
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
self.model_card = model_card
8 changes: 8 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
load_sagemaker_config,
)
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
from sagemaker.session import Session
from sagemaker.model_metrics import ModelMetrics
Expand Down Expand Up @@ -428,6 +432,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -479,6 +484,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
Expand Down Expand Up @@ -545,6 +552,7 @@ def register(
task=task,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/model_card/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
AdditionalInformation,
ModelCard,
ModelPackage,
ModelPackageModelCard,
)

from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import
Expand Down
26 changes: 26 additions & 0 deletions src/sagemaker/model_card/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,3 +1883,29 @@ def list_export_jobs(
return sagemaker_session.sagemaker_client.list_model_card_export_jobs(
ModelCardName=model_card_name, **kwargs
)


class ModelPackageModelCard(object):
"""Use an Amazon SageMaker Model Card to document qualitative and quantitative information about a model.""" # noqa E501 # pylint: disable=c0301

def __init__(
self,
model_card_content: dict[str],
model_card_status: str,
):

self.model_card_content = model_card_content
self.model_card_status = model_card_status

def _create_request_args(self):
"""Generate the request body for create model card call.

Args:
model_card_content dict[str]: Content of the model card.
model_card_status (str): Status of the model card you want to export.

""" # noqa E501 # pylint: disable=line-too-long
request_args = {}
request_args["ModelCardStatus"] = self.model_card_status
request_args["Content"] = json.dumps(self.model_card_content, cls=_JSONEncoder)
return request_args
8 changes: 8 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
)
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.mxnet import defaults
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
Expand Down Expand Up @@ -177,6 +181,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -228,6 +233,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -268,6 +275,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

def prepare_container_def(
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
)
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.session import Session
from sagemaker.utils import (
name_from_image,
Expand Down Expand Up @@ -361,6 +365,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -412,6 +417,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step
Expand Down Expand Up @@ -460,6 +467,7 @@ def register(
task=task,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
)
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
)
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
Expand Down Expand Up @@ -179,6 +183,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -230,6 +235,8 @@ def register(
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -270,6 +277,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
)

def prepare_container_def(
Expand Down
Loading
Loading