Skip to content

Commit 75efdef

Browse files
feature: add 'ModelCard' property to Register step (#4726)
* feature: add 'ModelCard' property to RegisterModel step * Updated ModelCard content type * fix: ModelCard Object integ Test fix --------- Co-authored-by: Gokul A <[email protected]>
1 parent 42fc662 commit 75efdef

File tree

21 files changed

+613
-3
lines changed

21 files changed

+613
-3
lines changed

src/sagemaker/chainer/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2929
from sagemaker.chainer import defaults
3030
from sagemaker.deserializers import NumpyDeserializer
31+
from sagemaker.model_card import (
32+
ModelCard,
33+
ModelPackageModelCard,
34+
)
3135
from sagemaker.predictor import Predictor
3236
from sagemaker.serializers import NumpySerializer
3337
from sagemaker.utils import to_string
@@ -175,6 +179,7 @@ def register(
175179
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
176180
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
177181
source_uri: Optional[Union[str, PipelineVariable]] = None,
182+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
178183
):
179184
"""Creates a model package for creating SageMaker models or listing on Marketplace.
180185
@@ -226,6 +231,8 @@ def register(
226231
validation. Values can be "All" or "None" (default: None).
227232
source_uri (str or PipelineVariable): The URI of the source for the model package
228233
(default: None).
234+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
235+
quantitative information about a model (default: None).
229236
230237
Returns:
231238
str: A string of SageMaker Model Package ARN.
@@ -266,6 +273,7 @@ def register(
266273
data_input_configuration=data_input_configuration,
267274
skip_model_validation=skip_model_validation,
268275
source_uri=source_uri,
276+
model_card=model_card,
269277
)
270278

271279
def prepare_container_def(

src/sagemaker/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,7 @@ def register(
17241724
data_input_configuration=None,
17251725
skip_model_validation=None,
17261726
source_uri=None,
1727+
model_card=None,
17271728
**kwargs,
17281729
):
17291730
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1772,6 +1773,8 @@ def register(
17721773
skip_model_validation (str): Indicates if you want to skip model validation.
17731774
Values can be "All" or "None" (default: None).
17741775
source_uri (str): The URI of the source for the model package (default: None).
1776+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
1777+
quantitative information about a model (default: None).
17751778
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
17761779
``create_model()`` to accept ``**kwargs`` to customize model creation during
17771780
deploy. For more, see the implementation docs.
@@ -1817,6 +1820,7 @@ def register(
18171820
data_input_configuration=data_input_configuration,
18181821
skip_model_validation=skip_model_validation,
18191822
source_uri=source_uri,
1823+
model_card=model_card,
18201824
)
18211825

18221826
@property

src/sagemaker/huggingface/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
)
2727
from sagemaker.metadata_properties import MetadataProperties
2828
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
29+
from sagemaker.model_card import (
30+
ModelCard,
31+
ModelPackageModelCard,
32+
)
2933
from sagemaker.predictor import Predictor
3034
from sagemaker.serializers import JSONSerializer
3135
from sagemaker.session import Session
@@ -362,6 +366,7 @@ def register(
362366
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
363367
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
364368
source_uri: Optional[Union[str, PipelineVariable]] = None,
369+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
365370
):
366371
"""Creates a model package for creating SageMaker models or listing on Marketplace.
367372
@@ -414,6 +419,8 @@ def register(
414419
validation. Values can be "All" or "None" (default: None).
415420
source_uri (str or PipelineVariable): The URI of the source for the model package
416421
(default: None).
422+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
423+
quantitative information about a model (default: None).
417424
418425
Returns:
419426
A `sagemaker.model.ModelPackage` instance.
@@ -462,6 +469,7 @@ def register(
462469
data_input_configuration=data_input_configuration,
463470
skip_model_validation=skip_model_validation,
464471
source_uri=source_uri,
472+
model_card=model_card,
465473
)
466474

467475
def prepare_container_def(

src/sagemaker/jumpstart/factory/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
JUMPSTART_DEFAULT_REGION_NAME,
3434
JUMPSTART_LOGGER,
3535
)
36+
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
3637
from sagemaker.model_metrics import ModelMetrics
3738
from sagemaker.metadata_properties import MetadataProperties
3839
from sagemaker.drift_check_baselines import DriftCheckBaselines
@@ -646,6 +647,7 @@ def get_register_kwargs(
646647
data_input_configuration: Optional[str] = None,
647648
skip_model_validation: Optional[str] = None,
648649
source_uri: Optional[str] = None,
650+
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
649651
) -> JumpStartModelRegisterKwargs:
650652
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
651653

@@ -678,6 +680,7 @@ def get_register_kwargs(
678680
data_input_configuration=data_input_configuration,
679681
skip_model_validation=skip_model_validation,
680682
source_uri=source_uri,
683+
model_card=model_card,
681684
)
682685

683686
model_specs = verify_model_region_and_return_specs(

src/sagemaker/jumpstart/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
)
4444
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
4545
from sagemaker.jumpstart.enums import JumpStartModelType
46+
from sagemaker.model_card import (
47+
ModelCard,
48+
ModelPackageModelCard,
49+
)
4650
from sagemaker.utils import stringify_object, format_tags, Tags
4751
from sagemaker.model import (
4852
Model,
@@ -692,6 +696,7 @@ def register(
692696
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
693697
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
694698
source_uri: Optional[Union[str, PipelineVariable]] = None,
699+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
695700
):
696701
"""Creates a model package for creating SageMaker models or listing on Marketplace.
697702
@@ -739,6 +744,8 @@ def register(
739744
validation. Values can be "All" or "None" (default: None).
740745
source_uri (str or PipelineVariable): The URI of the source for the model package
741746
(default: None).
747+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
748+
quantitative information about a model (default: None).
742749
743750
Returns:
744751
A `sagemaker.model.ModelPackage` instance.
@@ -773,6 +780,7 @@ def register(
773780
data_input_configuration=data_input_configuration,
774781
skip_model_validation=skip_model_validation,
775782
source_uri=source_uri,
783+
model_card=model_card,
776784
)
777785

778786
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18+
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
1819
from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict
1920
from sagemaker.model_metrics import ModelMetrics
2021
from sagemaker.metadata_properties import MetadataProperties
@@ -2114,6 +2115,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
21142115
"data_input_configuration",
21152116
"skip_model_validation",
21162117
"source_uri",
2118+
"model_card",
21172119
]
21182120

21192121
SERIALIZATION_EXCLUSION_SET = {
@@ -2155,6 +2157,7 @@ def __init__(
21552157
data_input_configuration: Optional[str] = None,
21562158
skip_model_validation: Optional[str] = None,
21572159
source_uri: Optional[str] = None,
2160+
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
21582161
) -> None:
21592162
"""Instantiates JumpStartModelRegisterKwargs object."""
21602163

@@ -2187,3 +2190,4 @@ def __init__(
21872190
self.data_input_configuration = data_input_configuration
21882191
self.skip_model_validation = skip_model_validation
21892192
self.source_uri = source_uri
2193+
self.model_card = model_card

src/sagemaker/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4545
load_sagemaker_config,
4646
)
47+
from sagemaker.model_card import (
48+
ModelCard,
49+
ModelPackageModelCard,
50+
)
4751
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4852
from sagemaker.session import Session
4953
from sagemaker.model_metrics import ModelMetrics
@@ -428,6 +432,7 @@ def register(
428432
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
429433
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
430434
source_uri: Optional[Union[str, PipelineVariable]] = None,
435+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
431436
):
432437
"""Creates a model package for creating SageMaker models or listing on Marketplace.
433438
@@ -479,6 +484,8 @@ def register(
479484
validation. Values can be "All" or "None" (default: None).
480485
source_uri (str or PipelineVariable): The URI of the source for the model package
481486
(default: None).
487+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
488+
quantitative information about a model (default: None).
482489
483490
Returns:
484491
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -545,6 +552,7 @@ def register(
545552
task=task,
546553
skip_model_validation=skip_model_validation,
547554
source_uri=source_uri,
555+
model_card=model_card,
548556
)
549557
model_package = self.sagemaker_session.create_model_package_from_containers(
550558
**model_pkg_args

src/sagemaker/model_card/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
AdditionalInformation,
3030
ModelCard,
3131
ModelPackage,
32+
ModelPackageModelCard,
3233
)
3334

3435
from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import

src/sagemaker/model_card/model_card.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import logging
1818
from datetime import datetime
19-
from typing import Optional, Union, List, Any
19+
from typing import Optional, Union, List, Any, Dict
2020
from botocore.exceptions import ClientError
2121
from boto3.session import Session as boto3_Session
2222
from six.moves.urllib.parse import urlparse
@@ -1883,3 +1883,29 @@ def list_export_jobs(
18831883
return sagemaker_session.sagemaker_client.list_model_card_export_jobs(
18841884
ModelCardName=model_card_name, **kwargs
18851885
)
1886+
1887+
1888+
class ModelPackageModelCard(object):
1889+
"""Use an Amazon SageMaker Model Card to document qualitative and quantitative information about a model.""" # noqa E501 # pylint: disable=c0301
1890+
1891+
def __init__(
1892+
self,
1893+
model_card_content: Dict[str, Any],
1894+
model_card_status: str,
1895+
):
1896+
1897+
self.model_card_content = model_card_content
1898+
self.model_card_status = model_card_status
1899+
1900+
def _create_request_args(self):
1901+
"""Generate the request body for create model card call.
1902+
1903+
Args:
1904+
model_card_content dict[str]: Content of the model card.
1905+
model_card_status (str): Status of the model card you want to export.
1906+
1907+
""" # noqa E501 # pylint: disable=line-too-long
1908+
request_args = {}
1909+
request_args["ModelCardStatus"] = self.model_card_status
1910+
request_args["Content"] = json.dumps(self.model_card_content, cls=_JSONEncoder)
1911+
return request_args

src/sagemaker/mxnet/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
)
3030
from sagemaker.metadata_properties import MetadataProperties
3131
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
32+
from sagemaker.model_card import (
33+
ModelCard,
34+
ModelPackageModelCard,
35+
)
3236
from sagemaker.mxnet import defaults
3337
from sagemaker.predictor import Predictor
3438
from sagemaker.serializers import JSONSerializer
@@ -177,6 +181,7 @@ def register(
177181
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
178182
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
179183
source_uri: Optional[Union[str, PipelineVariable]] = None,
184+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
180185
):
181186
"""Creates a model package for creating SageMaker models or listing on Marketplace.
182187
@@ -228,6 +233,8 @@ def register(
228233
validation. Values can be "All" or "None" (default: None).
229234
source_uri (str or PipelineVariable): The URI of the source for the model package
230235
(default: None).
236+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
237+
quantitative information about a model (default: None).
231238
232239
Returns:
233240
A `sagemaker.model.ModelPackage` instance.
@@ -268,6 +275,7 @@ def register(
268275
data_input_configuration=data_input_configuration,
269276
skip_model_validation=skip_model_validation,
270277
source_uri=source_uri,
278+
model_card=model_card,
271279
)
272280

273281
def prepare_container_def(

src/sagemaker/pipeline.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
)
2727
from sagemaker.drift_check_baselines import DriftCheckBaselines
2828
from sagemaker.metadata_properties import MetadataProperties
29+
from sagemaker.model_card import (
30+
ModelCard,
31+
ModelPackageModelCard,
32+
)
2933
from sagemaker.session import Session
3034
from sagemaker.utils import (
3135
name_from_image,
@@ -361,6 +365,7 @@ def register(
361365
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
362366
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
363367
source_uri: Optional[Union[str, PipelineVariable]] = None,
368+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
364369
):
365370
"""Creates a model package for creating SageMaker models or listing on Marketplace.
366371
@@ -412,6 +417,8 @@ def register(
412417
validation. Values can be "All" or "None" (default: None).
413418
source_uri (str or PipelineVariable): The URI of the source for the model package
414419
(default: None).
420+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
421+
quantitative information about a model (default: None).
415422
416423
Returns:
417424
If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step
@@ -460,6 +467,7 @@ def register(
460467
task=task,
461468
skip_model_validation=skip_model_validation,
462469
source_uri=source_uri,
470+
model_card=model_card,
463471
)
464472

465473
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

src/sagemaker/pytorch/model.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
)
3030
from sagemaker.metadata_properties import MetadataProperties
3131
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
32+
from sagemaker.model_card import (
33+
ModelCard,
34+
ModelPackageModelCard,
35+
)
3236
from sagemaker.pytorch import defaults
3337
from sagemaker.predictor import Predictor
3438
from sagemaker.serializers import NumpySerializer
@@ -179,6 +183,7 @@ def register(
179183
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
180184
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
181185
source_uri: Optional[Union[str, PipelineVariable]] = None,
186+
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
182187
):
183188
"""Creates a model package for creating SageMaker models or listing on Marketplace.
184189
@@ -230,6 +235,8 @@ def register(
230235
validation. Values can be "All" or "None" (default: None).
231236
source_uri (str or PipelineVariable): The URI of the source for the model package
232237
(default: None).
238+
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
239+
quantitative information about a model (default: None).
233240
234241
Returns:
235242
A `sagemaker.model.ModelPackage` instance.
@@ -270,6 +277,7 @@ def register(
270277
data_input_configuration=data_input_configuration,
271278
skip_model_validation=skip_model_validation,
272279
source_uri=source_uri,
280+
model_card=model_card,
273281
)
274282

275283
def prepare_container_def(

0 commit comments

Comments
 (0)