Skip to content

feature: adding customer metadata support to registermodel step #2935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 18, 2022
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,7 @@ def register(
compile_model_family=None,
model_name=None,
drift_check_baselines=None,
customer_metadata_properties=None,
**kwargs,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -1292,6 +1293,8 @@ def register(
model will be used (default: None).
model_name (str): User defined model name (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1322,6 +1325,7 @@ def register(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def register(
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand All @@ -329,6 +330,8 @@ def register(
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -356,6 +359,7 @@ def register(
description=description,
container_def_list=[container_def],
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def register(
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand All @@ -183,6 +184,8 @@ def register(
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -211,6 +214,7 @@ def register(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

def prepare_container_def(self, instance_type=None, accelerator_type=None):
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def register(
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand All @@ -182,6 +183,8 @@ def register(
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -210,6 +213,7 @@ def register(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

def prepare_container_def(self, instance_type=None, accelerator_type=None):
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2778,6 +2778,7 @@ def create_model_package_from_containers(
approval_status="PendingManualApproval",
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Get request dictionary for CreateModelPackage API.

Expand All @@ -2803,6 +2804,9 @@ def create_model_package_from_containers(
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

"""

request = get_create_model_package_request(
Expand All @@ -2819,7 +2823,17 @@ def create_model_package_from_containers(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)
if model_package_group_name is not None:
try:
self.sagemaker_client.describe_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
except ClientError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way this exception can be made more specific? IE check the error code, or test for sagemaker_client.exceptions.ResourceNotFound if that API returns ResourceNotFound

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be a validation exception which comes under client error. This is the only type of clienterror we see for describe_model_package_group.

self.sagemaker_client.create_model_package_group(
ModelPackageGroupName=request["ModelPackageGroupName"]
)
return self.sagemaker_client.create_model_package(**request)

def wait_for_model_package(self, model_package_name, poll=5):
Expand Down Expand Up @@ -4120,6 +4134,7 @@ def get_model_package_args(
tags=None,
container_def_list=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Get arguments for create_model_package method.

Expand Down Expand Up @@ -4148,6 +4163,8 @@ def get_model_package_args(
(default: None).
container_def_list (list): A list of container defintiions (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
Returns:
dict: A dictionary of method argument names and values.
"""
Expand Down Expand Up @@ -4185,6 +4202,8 @@ def get_model_package_args(
model_package_args["description"] = description
if tags is not None:
model_package_args["tags"] = tags
if customer_metadata_properties is not None:
model_package_args["customer_metadata_properties"] = customer_metadata_properties
return model_package_args


Expand All @@ -4203,6 +4222,7 @@ def get_create_model_package_request(
description=None,
tags=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Get request dictionary for CreateModelPackage API.

Expand All @@ -4229,6 +4249,8 @@ def get_create_model_package_request(
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
(default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
"""

if all([model_package_name, model_package_group_name]):
Expand All @@ -4250,6 +4272,8 @@ def get_create_model_package_request(
request_dict["DriftCheckBaselines"] = drift_check_baselines
if metadata_properties:
request_dict["MetadataProperties"] = metadata_properties
if customer_metadata_properties is not None:
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
raise ValueError(
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def register(
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand All @@ -226,6 +227,9 @@ def register(
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).


Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -254,6 +258,7 @@ def register(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

def deploy(
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(
tags=None,
container_def_list=None,
drift_check_baselines=None,
customer_metadata_properties=None,
**kwargs,
):
"""Constructor of a register model step.
Expand Down Expand Up @@ -347,6 +348,8 @@ def __init__(
this step depends on
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
**kwargs: additional arguments to `create_model`.
"""
super(_RegisterModelStep, self).__init__(
Expand All @@ -362,6 +365,7 @@ def __init__(
self.tags = tags
self.model_metrics = model_metrics
self.drift_check_baselines = drift_check_baselines
self.customer_metadata_properties = customer_metadata_properties
self.metadata_properties = metadata_properties
self.approval_status = approval_status
self.image_uri = image_uri
Expand Down Expand Up @@ -435,6 +439,7 @@ def arguments(self) -> RequestType:
description=self.description,
tags=self.tags,
container_def_list=self.container_def_list,
customer_metadata_properties=self.customer_metadata_properties,
)

request_dict = get_create_model_package_request(**model_package_args)
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
tags=None,
model: Union[Model, PipelineModel] = None,
drift_check_baselines=None,
customer_metadata_properties=None,
**kwargs,
):
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
Expand All @@ -95,7 +96,7 @@ def __init__(
for the repack model step
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
for register model step
model_package_group_name (str): The Model Package Group name, exclusive to
model_package_group_name (str): The Model Package Group name or Arn, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
model_metrics (ModelMetrics): ModelMetrics object (default: None).
Expand All @@ -113,6 +114,9 @@ def __init__(
model (object or Model): A PipelineModel object that comprises a list of models
which gets executed as a serial inference pipeline or a Model object.
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

**kwargs: additional arguments to `create_model`.
"""
steps: List[Step] = []
Expand Down Expand Up @@ -229,6 +233,7 @@ def __init__(
tags=tags,
container_def_list=self.container_def_list,
retry_policies=register_model_step_retry_policies,
customer_metadata_properties=customer_metadata_properties,
**kwargs,
)
if not repack_model:
Expand Down
3 changes: 3 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,7 @@ def test_model_registration_with_drift_check_baselines(
content_type="application/json",
),
)
customer_metadata_properties = {"key1": "value1"}
estimator = XGBoost(
entry_point="training.py",
source_dir=os.path.join(DATA_DIR, "sip"),
Expand All @@ -1972,6 +1973,7 @@ def test_model_registration_with_drift_check_baselines(
model_package_group_name="testModelPackageGroup",
model_metrics=model_metrics,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)

pipeline = Pipeline(
Expand Down Expand Up @@ -2042,6 +2044,7 @@ def test_model_registration_with_drift_check_baselines(
response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"]
== "application/json"
)
assert response["CustomerMetadataProperties"] == customer_metadata_properties
break
finally:
try:
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
marketplace_cert = (True,)
approval_status = ("Approved",)
description = "description"
customer_metadata_properties = {"key1": "value1"}
sagemaker_session.create_model_package_from_containers(
containers=containers,
content_types=content_types,
Expand All @@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
approval_status=approval_status,
description=description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
)
expected_args = {
"ModelPackageName": model_package_name,
Expand All @@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
"CertifyForMarketplace": marketplace_cert,
"ModelApprovalStatus": approval_status,
"DriftCheckBaselines": drift_check_baselines,
"CustomerMetadataProperties": customer_metadata_properties,
}
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)

Expand Down