Skip to content

Feat: Added update for model package #4309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 95 additions & 7 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
)
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.enums import EndpointType
from sagemaker.session import get_add_model_package_inference_args

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -485,12 +486,6 @@ def register(
if response_types is not None:
self.response_types = response_types

if self.content_types is None:
raise ValueError("The supported MIME types for the input data is not set")

if self.response_types is None:
raise ValueError("The supported MIME types for the output data is not set")

if image_uri is not None:
self.image_uri = image_uri

Expand Down Expand Up @@ -2181,7 +2176,7 @@ def update_approval_status(self, approval_status, approval_description=None):
"""Update the approval status for the model package

Args:
approval_status (str or PipelineVariable): Model Approval Status, values can be
approval_status (str): Model Approval Status, values can be
"Approved", "Rejected", or "PendingManualApproval".
approval_description (str): Optional. Description for the approval status of the model
(default: None).
Expand All @@ -2202,3 +2197,96 @@ def update_approval_status(self, approval_status, approval_description=None):
update_approval_args["ApprovalDescription"] = approval_description

sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]):
    """Update the customer metadata properties of the model package.

    Calls ``update_model_package`` on the session's SageMaker client with
    this package's ARN and the supplied properties.

    Args:
        customer_metadata_properties (dict[str, str]):
            A dictionary of key-value paired metadata properties.
    """
    # Use the attached session when there is one, otherwise create a default.
    session = self.sagemaker_session or sagemaker.Session()
    session.sagemaker_client.update_model_package(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataProperties=customer_metadata_properties,
    )

def remove_customer_metadata_properties(
    self, customer_metadata_properties_to_remove: List[str]
):
    """Remove the specified keys from the customer metadata properties.

    Calls ``update_model_package`` on the session's SageMaker client with
    this package's ARN and the keys to drop.

    Args:
        customer_metadata_properties_to_remove (list[str]):
            Keys of the customer metadata properties to remove.
    """
    # Use the attached session when there is one, otherwise create a default.
    session = self.sagemaker_session or sagemaker.Session()
    session.sagemaker_client.update_model_package(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataPropertiesToRemove=customer_metadata_properties_to_remove,
    )

def add_inference_specification(
    self,
    name: str,
    containers: Dict = None,
    image_uris: List[str] = None,
    description: str = None,
    content_types: List[str] = None,
    response_types: List[str] = None,
    inference_instances: List[str] = None,
    transform_instances: List[str] = None,
):
    """Add an additional inference specification to the model package.

    Exactly one of ``containers`` or ``image_uris`` must be supplied and must
    be non-empty. Each entry of ``image_uris`` is turned into a container
    definition of the form ``{"Image": uri}``; ``containers`` is passed
    through unchanged.

    Args:
        name (str): Name to identify the additional inference specification.
        containers (dict): Container definitions to use as-is for the
            specification (for example ``[{"Image": "<ecr-uri>"}]``).
        image_uris (List[str]): The ECR paths where inference code is stored.
        description (str): Description for the additional inference specification.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).

    Raises:
        ValueError: If both ``containers`` and ``image_uris`` are given, or if
            neither is given (an empty value counts as not given).
    """
    # Validate before touching the session so bad arguments fail fast.
    if containers is not None and image_uris is not None:
        raise ValueError("Cannot have both containers and image_uris.")
    # An empty list/dict has no container to register; treat it as missing
    # instead of sending an update request without any specification.
    if not containers and not image_uris:
        raise ValueError("Should have either containers or image_uris for inference.")

    if image_uris:
        container_def = [{"Image": uri} for uri in image_uris]
    else:
        container_def = containers

    sagemaker_session = self.sagemaker_session or sagemaker.Session()

    model_package_update_args = get_add_model_package_inference_args(
        model_package_arn=self.model_package_arn,
        name=name,
        containers=container_def,
        content_types=content_types,
        description=description,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
    )

    sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)
88 changes: 82 additions & 6 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6557,15 +6557,21 @@ def get_create_model_package_request(
if task is not None:
request_dict["Task"] = task
if containers is not None:
if not all([content_types, response_types]):
raise ValueError(
"content_types and response_types " "must be provided if containers is present."
)
inference_specification = {
"Containers": containers,
"SupportedContentTypes": content_types,
"SupportedResponseMIMETypes": response_types,
}
if content_types is not None:
inference_specification.update(
{
"SupportedContentTypes": content_types,
}
)
if response_types is not None:
inference_specification.update(
{
"SupportedResponseMIMETypes": response_types,
}
)
if model_package_group_name is not None:
if inference_instances is not None:
inference_specification.update(
Expand Down Expand Up @@ -6598,6 +6604,76 @@ def get_create_model_package_request(
return request_dict


def get_add_model_package_inference_args(
    model_package_arn,
    name,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    description=None,
):
    """Build the UpdateModelPackage request dict for an additional inference spec.

    Args:
        model_package_arn (str): Arn for the model package.
        name (str): Name to identify the additional inference specification.
        containers (dict): Container definitions placed under ``"Containers"``
            in the specification.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).
        description (str): Description for the additional inference specification.

    Returns:
        dict: Request arguments holding a single-element
        ``AdditionalInferenceSpecificationsToAdd`` list and the
        ``ModelPackageArn``. Optional fields that are ``None`` are omitted.
    """
    # NOTE(review): with no containers nothing is built and an empty dict is
    # returned (not even the ARN) - confirm this matches the intended nesting.
    if containers is None:
        return {}

    # Request key -> value, listed in the order the keys are added to the spec.
    optional_fields = (
        ("Name", name),
        ("Description", description),
        ("SupportedContentTypes", content_types),
        ("SupportedResponseMIMETypes", response_types),
        ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
        ("SupportedTransformInstanceTypes", transform_instances),
    )

    spec = {"Containers": containers}
    spec.update({key: value for key, value in optional_fields if value is not None})

    return {
        "AdditionalInferenceSpecificationsToAdd": [spec],
        "ModelPackageArn": model_package_arn,
    }


def update_args(args: Dict[str, Any], **kwargs):
"""Updates the request arguments dict with the value if populated.

Expand Down
43 changes: 43 additions & 0 deletions tests/integ/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR
from sagemaker.xgboost import XGBoostModel
from sagemaker import image_uris

_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")

Expand Down Expand Up @@ -61,3 +62,45 @@ def test_update_approval_model_package(sagemaker_session):
sagemaker_session.sagemaker_client.delete_model_package_group(
ModelPackageGroupName=model_group_name
)


def test_inference_specification_addition(sagemaker_session):
    """Register an XGBoost model package and add an extra inference specification.

    Creates a model package group, registers a model package in it, adds an
    additional inference specification built from an XGBoost image URI, and
    checks that ``describe_model_package`` reports exactly that specification.
    The model package and the group are deleted even when a step fails.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    model_group_name = unique_name_from_base("test-model-group")

    sagemaker_client.create_model_package_group(ModelPackageGroupName=model_group_name)

    # Stays None until registration succeeds, so cleanup knows what exists.
    model_package = None
    try:
        xgb_model_data_s3 = sagemaker_session.upload_data(
            path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
            key_prefix="integ-test-data/xgboost/model",
        )
        model = XGBoostModel(
            model_data=xgb_model_data_s3,
            framework_version="1.3-1",
            sagemaker_session=sagemaker_session,
        )

        model_package = model.register(
            content_types=["text/csv"],
            response_types=["text/csv"],
            inference_instances=["ml.m5.large"],
            transform_instances=["ml.m5.large"],
            model_package_group_name=model_group_name,
        )

        xgb_image = image_uris.retrieve(
            "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
        )
        model_package.add_inference_specification(image_uris=[xgb_image], name="Inference")
        desc_model_package = sagemaker_client.describe_model_package(
            ModelPackageName=model_package.model_package_arn
        )
        assert len(desc_model_package["AdditionalInferenceSpecifications"]) == 1
        assert desc_model_package["AdditionalInferenceSpecifications"][0]["Name"] == "Inference"
    finally:
        # Clean up in a finally block so a failed step or assertion does not
        # leave the package or the group behind in the account.
        if model_package is not None:
            sagemaker_client.delete_model_package(
                ModelPackageName=model_package.model_package_arn
            )
        sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_group_name)
73 changes: 73 additions & 0 deletions tests/unit/sagemaker/model/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,76 @@ def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_
update_approval_status.call_args_list[0][1]["approval_status"]
== ModelApprovalStatusEnum.APPROVED
)


def test_update_customer_metadata(sagemaker_session):
    """Check that update_customer_metadata forwards the properties to the client."""
    metadata = {"Key": "Value"}
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.update_customer_metadata(customer_metadata_properties=metadata)

    # The client must receive the package ARN together with the metadata dict.
    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataProperties=metadata,
    )


def test_remove_customer_metadata(sagemaker_session):
    """Check that remove_customer_metadata_properties forwards the keys to the client."""
    keys_to_remove = ["Key"]
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.remove_customer_metadata_properties(
        customer_metadata_properties_to_remove=keys_to_remove
    )

    # The client must receive the package ARN together with the keys to drop.
    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataPropertiesToRemove=keys_to_remove,
    )


def test_add_inference_specification(sagemaker_session):
    """Check argument validation and the request built by add_inference_specification.

    Covers three cases: passing both ``containers`` and ``image_uris`` is
    rejected, passing neither is rejected, and passing only ``image_uris``
    results in an ``update_model_package`` call with the expected
    ``AdditionalInferenceSpecificationsToAdd`` payload.
    """
    model_package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    image_uris = ["image_uri"]

    containers = [{"Image": "image_uri"}]

    # The else branches make the test fail when no error is raised; without
    # them a missing ValueError would pass silently.
    try:
        model_package.add_inference_specification(
            image_uris=image_uris, name="Inference", containers=containers
        )
    except ValueError as ve:
        assert "Cannot have both containers and image_uris." in str(ve)
    else:
        raise AssertionError(
            "Expected ValueError when both containers and image_uris are provided"
        )

    try:
        model_package.add_inference_specification(name="Inference")
    except ValueError as ve:
        assert "Should have either containers or image_uris for inference." in str(ve)
    else:
        raise AssertionError(
            "Expected ValueError when neither containers nor image_uris is provided"
        )

    model_package.add_inference_specification(image_uris=image_uris, name="Inference")

    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        AdditionalInferenceSpecificationsToAdd=[
            {
                "Containers": [{"Image": "image_uri"}],
                "Name": "Inference",
            }
        ],
    )