Skip to content

Commit 66c53a3

Browse files
keshav-chandak (Keshav Chandak)
authored and committed
feat: Added update for model package (aws#4309)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 5bd0d86 commit 66c53a3

File tree

4 files changed

+293
-13
lines changed

4 files changed

+293
-13
lines changed

src/sagemaker/model.py

+95-7
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7777
from sagemaker.enums import EndpointType
78+
from sagemaker.session import get_add_model_package_inference_args
7879

7980
LOGGER = logging.getLogger("sagemaker")
8081

@@ -485,12 +486,6 @@ def register(
485486
if response_types is not None:
486487
self.response_types = response_types
487488

488-
if self.content_types is None:
489-
raise ValueError("The supported MIME types for the input data is not set")
490-
491-
if self.response_types is None:
492-
raise ValueError("The supported MIME types for the output data is not set")
493-
494489
if image_uri is not None:
495490
self.image_uri = image_uri
496491

@@ -2181,7 +2176,7 @@ def update_approval_status(self, approval_status, approval_description=None):
21812176
"""Update the approval status for the model package
21822177
21832178
Args:
2184-
approval_status (str or PipelineVariable): Model Approval Status, values can be
2179+
approval_status (str): Model Approval Status, values can be
21852180
"Approved", "Rejected", or "PendingManualApproval".
21862181
approval_description (str): Optional. Description for the approval status of the model
21872182
(default: None).
@@ -2202,3 +2197,96 @@ def update_approval_status(self, approval_status, approval_description=None):
22022197
update_approval_args["ApprovalDescription"] = approval_description
22032198

22042199
sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)
2200+
2201+
def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]):
    """Send customer metadata properties for this model package.

    Calls ``update_model_package`` on the session's SageMaker client with
    this package's ARN and the supplied properties.

    Args:
        customer_metadata_properties (dict[str, str]): Key-value pairs of
            metadata properties to send with the update request.
    """
    request = dict(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataProperties=customer_metadata_properties,
    )

    # Fall back to a default session only when none is attached to the package.
    client = (self.sagemaker_session or sagemaker.Session()).sagemaker_client
    client.update_model_package(**request)
2216+
2217+
def remove_customer_metadata_properties(
    self, customer_metadata_properties_to_remove: List[str]
):
    """Request removal of the given keys from the customer metadata properties.

    Calls ``update_model_package`` on the session's SageMaker client with
    this package's ARN and the list of keys to remove.

    Args:
        customer_metadata_properties_to_remove (list[str]): Keys of the
            customer metadata properties to remove.
    """
    request = dict(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataPropertiesToRemove=customer_metadata_properties_to_remove,
    )

    # Fall back to a default session only when none is attached to the package.
    client = (self.sagemaker_session or sagemaker.Session()).sagemaker_client
    client.update_model_package(**request)
2234+
2235+
def add_inference_specification(
    self,
    name: str,
    containers: List[Dict] = None,
    image_uris: List[str] = None,
    description: str = None,
    content_types: List[str] = None,
    response_types: List[str] = None,
    inference_instances: List[str] = None,
    transform_instances: List[str] = None,
):
    """Add an additional inference specification to the model package.

    Exactly one of ``containers`` or ``image_uris`` must be supplied and it
    must be non-empty. Each entry of ``image_uris`` is turned into a
    ``{"Image": uri}`` container definition.

    Args:
        name (str): Name to identify the additional inference specification.
        containers (list[dict]): Container definitions for the inference
            specification, e.g. ``[{"Image": "<ecr-uri>"}]``.
        image_uris (list[str]): The ECR paths where inference code is stored.
        description (str): Description for the additional inference
            specification.
        content_types (list[str]): The supported MIME types for the input
            data.
        response_types (list[str]): The supported MIME types for the output
            data.
        inference_instances (list[str]): A list of the instance types that
            are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance types on
            which a transformation job can be run or on which an endpoint
            can be deployed (default: None).

    Raises:
        ValueError: If both ``containers`` and ``image_uris`` are given, or
            if neither provides at least one entry.
    """
    # Validate before touching the session so bad input never builds one.
    if containers is not None and image_uris is not None:
        raise ValueError("Cannot have both containers and image_uris.")
    # Truthiness (not "is None") so an empty list cannot slip through and
    # produce an update request without any container.
    if not containers and not image_uris:
        raise ValueError("Should have either containers or image_uris for inference.")

    if image_uris:
        container_def = [{"Image": uri} for uri in image_uris]
    else:
        container_def = containers

    sagemaker_session = self.sagemaker_session or sagemaker.Session()
    model_package_update_args = get_add_model_package_inference_args(
        model_package_arn=self.model_package_arn,
        name=name,
        containers=container_def,
        content_types=content_types,
        description=description,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
    )

    sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)

src/sagemaker/session.py

+82-6
Original file line numberDiff line numberDiff line change
@@ -6557,15 +6557,21 @@ def get_create_model_package_request(
65576557
if task is not None:
65586558
request_dict["Task"] = task
65596559
if containers is not None:
6560-
if not all([content_types, response_types]):
6561-
raise ValueError(
6562-
"content_types and response_types " "must be provided if containers is present."
6563-
)
65646560
inference_specification = {
65656561
"Containers": containers,
6566-
"SupportedContentTypes": content_types,
6567-
"SupportedResponseMIMETypes": response_types,
65686562
}
6563+
if content_types is not None:
6564+
inference_specification.update(
6565+
{
6566+
"SupportedContentTypes": content_types,
6567+
}
6568+
)
6569+
if response_types is not None:
6570+
inference_specification.update(
6571+
{
6572+
"SupportedResponseMIMETypes": response_types,
6573+
}
6574+
)
65696575
if model_package_group_name is not None:
65706576
if inference_instances is not None:
65716577
inference_specification.update(
@@ -6598,6 +6604,76 @@ def get_create_model_package_request(
65986604
return request_dict
65996605

66006606

6607+
def get_add_model_package_inference_args(
    model_package_arn,
    name,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    description=None,
):
    """Get request dictionary for UpdateModelPackage API for additional inference.

    When ``containers`` is given, a single additional inference
    specification is built from it and from every optional argument that is
    not ``None``. ``ModelPackageArn`` is always included in the result.

    Args:
        model_package_arn (str): Arn for the model package.
        name (str): Name to identify the additional inference specification.
        containers (list[dict]): Container definitions for the inference
            specification, e.g. ``[{"Image": "<ecr-uri>"}]``.
        content_types (list[str]): The supported MIME types for the input
            data.
        response_types (list[str]): The supported MIME types for the output
            data.
        inference_instances (list[str]): A list of the instance types that
            are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance types on
            which a transformation job can be run or on which an endpoint
            can be deployed (default: None).
        description (str): Description for the additional inference
            specification.

    Returns:
        dict: Keyword arguments for ``update_model_package``.
    """
    request_dict = {}
    if containers is not None:
        inference_specification = {"Containers": containers}

        # Request key -> value; entries left as None are omitted entirely.
        optional_fields = (
            ("Name", name),
            ("Description", description),
            ("SupportedContentTypes", content_types),
            ("SupportedResponseMIMETypes", response_types),
            ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
            ("SupportedTransformInstanceTypes", transform_instances),
        )
        inference_specification.update(
            (key, value) for key, value in optional_fields if value is not None
        )
        request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]

    request_dict["ModelPackageArn"] = model_package_arn
    return request_dict
6675+
6676+
66016677
def update_args(args: Dict[str, Any], **kwargs):
66026678
"""Updates the request arguments dict with the value if populated.
66036679

tests/integ/test_model_package.py

+43
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sagemaker.utils import unique_name_from_base
1818
from tests.integ import DATA_DIR
1919
from sagemaker.xgboost import XGBoostModel
20+
from sagemaker import image_uris
2021

2122
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
2223

@@ -61,3 +62,45 @@ def test_update_approval_model_package(sagemaker_session):
6162
sagemaker_session.sagemaker_client.delete_model_package_group(
6263
ModelPackageGroupName=model_group_name
6364
)
65+
66+
67+
def test_inference_specification_addition(sagemaker_session):
    """Register an XGBoost model package and add one inference specification.

    Verifies via ``describe_model_package`` that exactly one additional
    inference specification named "Inference" is attached. The model package
    and its group are deleted in a ``finally`` block so that a failing step
    or assertion does not leave resources behind.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    model_group_name = unique_name_from_base("test-model-group")

    sagemaker_client.create_model_package_group(ModelPackageGroupName=model_group_name)

    # Stays None until registration succeeds, so cleanup knows what exists.
    model_package = None
    try:
        xgb_model_data_s3 = sagemaker_session.upload_data(
            path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
            key_prefix="integ-test-data/xgboost/model",
        )
        model = XGBoostModel(
            model_data=xgb_model_data_s3,
            framework_version="1.3-1",
            sagemaker_session=sagemaker_session,
        )

        model_package = model.register(
            content_types=["text/csv"],
            response_types=["text/csv"],
            inference_instances=["ml.m5.large"],
            transform_instances=["ml.m5.large"],
            model_package_group_name=model_group_name,
        )

        xgb_image = image_uris.retrieve(
            "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
        )
        model_package.add_inference_specification(image_uris=[xgb_image], name="Inference")
        desc_model_package = sagemaker_client.describe_model_package(
            ModelPackageName=model_package.model_package_arn
        )
        assert len(desc_model_package["AdditionalInferenceSpecifications"]) == 1
        assert desc_model_package["AdditionalInferenceSpecifications"][0]["Name"] == "Inference"
    finally:
        # The package must be deleted before the group that contains it.
        if model_package is not None:
            sagemaker_client.delete_model_package(
                ModelPackageName=model_package.model_package_arn
            )
        sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_group_name)

tests/unit/sagemaker/model/test_model_package.py

+73
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,76 @@ def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_
326326
update_approval_status.call_args_list[0][1]["approval_status"]
327327
== ModelApprovalStatusEnum.APPROVED
328328
)
329+
330+
331+
def test_update_customer_metadata(sagemaker_session):
    """update_customer_metadata forwards the ARN and properties to the client."""
    metadata = {"Key": "Value"}
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.update_customer_metadata(customer_metadata_properties=metadata)

    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataProperties=metadata,
    )
347+
348+
349+
def test_remove_customer_metadata(sagemaker_session):
    """remove_customer_metadata_properties forwards the ARN and keys to the client."""
    keys_to_remove = ["Key"]
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.remove_customer_metadata_properties(
        customer_metadata_properties_to_remove=keys_to_remove
    )

    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataPropertiesToRemove=keys_to_remove,
    )
366+
367+
368+
def test_add_inference_specification(sagemaker_session):
    """add_inference_specification validates its inputs and builds the request.

    Covers three cases: both ``containers`` and ``image_uris`` given
    (rejected), neither given (rejected), and ``image_uris`` alone
    (forwarded to ``update_model_package`` as container definitions).
    """
    model_package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    image_uris = ["image_uri"]

    containers = [{"Image": "image_uri"}]

    # The ``else`` branches make the test fail when no ValueError is raised;
    # a bare try/except would pass silently in that case.
    try:
        model_package.add_inference_specification(
            image_uris=image_uris, name="Inference", containers=containers
        )
    except ValueError as ve:
        assert "Cannot have both containers and image_uris." in str(ve)
    else:
        raise AssertionError("Expected ValueError when both containers and image_uris are set")

    try:
        model_package.add_inference_specification(name="Inference")
    except ValueError as ve:
        assert "Should have either containers or image_uris for inference." in str(ve)
    else:
        raise AssertionError("Expected ValueError when neither containers nor image_uris is set")

    model_package.add_inference_specification(image_uris=image_uris, name="Inference")

    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        AdditionalInferenceSpecificationsToAdd=[
            {
                "Containers": [{"Image": "image_uri"}],
                "Name": "Inference",
            }
        ],
    )

0 commit comments

Comments
 (0)