From f56f7d0e4defddffe5449bcd8e4b4bf9c9909d14 Mon Sep 17 00:00:00 2001
From: Keshav Chandak
Date: Tue, 5 Dec 2023 10:34:56 +0000
Subject: [PATCH] Added update for model package

---
 src/sagemaker/model.py                        | 102 ++++++++++++++++--
 src/sagemaker/session.py                      |  88 +++++++++++++--
 tests/integ/test_model_package.py             |  43 ++++++++
 .../sagemaker/model/test_model_package.py     |  73 +++++++++++++
 4 files changed, 293 insertions(+), 13 deletions(-)

diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index d9122cacf1..9caca5feff 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -75,6 +75,7 @@
 )
 from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
 from sagemaker.enums import EndpointType
+from sagemaker.session import get_add_model_package_inference_args
 
 LOGGER = logging.getLogger("sagemaker")
 
@@ -485,12 +486,6 @@ def register(
         if response_types is not None:
             self.response_types = response_types
 
-        if self.content_types is None:
-            raise ValueError("The supported MIME types for the input data is not set")
-
-        if self.response_types is None:
-            raise ValueError("The supported MIME types for the output data is not set")
-
         if image_uri is not None:
             self.image_uri = image_uri
 
@@ -2181,7 +2176,7 @@ def update_approval_status(self, approval_status, approval_description=None):
         """Update the approval status for the model package
 
         Args:
-            approval_status (str or PipelineVariable): Model Approval Status, values can be
+            approval_status (str): Model Approval Status, values can be
                 "Approved", "Rejected", or "PendingManualApproval".
             approval_description (str): Optional. Description for the approval status of the
                 model (default: None).
@@ -2202,3 +2197,96 @@
             update_approval_args["ApprovalDescription"] = approval_description
 
         sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)
+
+    def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]):
+        """Updates customer metadata properties for the model package
+
+        Args:
+            customer_metadata_properties (dict[str, str]):
+                A dictionary of key-value paired metadata properties to add or update.
+        """
+
+        update_metadata_args = {
+            "ModelPackageArn": self.model_package_arn,
+            "CustomerMetadataProperties": customer_metadata_properties,
+        }
+
+        sagemaker_session = self.sagemaker_session or sagemaker.Session()
+        sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args)
+
+    def remove_customer_metadata_properties(
+        self, customer_metadata_properties_to_remove: List[str]
+    ):
+        """Removes the specified keys from customer metadata properties
+
+        Args:
+            customer_metadata_properties_to_remove (list[str]):
+                list of keys of customer metadata properties to remove.
+        """
+
+        delete_metadata_args = {
+            "ModelPackageArn": self.model_package_arn,
+            "CustomerMetadataPropertiesToRemove": customer_metadata_properties_to_remove,
+        }
+
+        sagemaker_session = self.sagemaker_session or sagemaker.Session()
+        sagemaker_session.sagemaker_client.update_model_package(**delete_metadata_args)
+
+    def add_inference_specification(
+        self,
+        name: str,
+        containers: List[Dict] = None,
+        image_uris: List[str] = None,
+        description: str = None,
+        content_types: List[str] = None,
+        response_types: List[str] = None,
+        inference_instances: List[str] = None,
+        transform_instances: List[str] = None,
+    ):
+        """Adds an additional inference specification to the model package
+
+        Args:
+            name (str): Name to identify the additional inference specification
+            containers (list[dict]): A list of container definitions, each including the
+                Amazon ECR path of the Docker image that contains the inference code.
+            image_uris (list[str]): A list of ECR paths to images that contain the inference code.
+            description (str): Description for the additional inference specification
+            content_types (list[str]): The supported MIME types
+                for the input data.
+            response_types (list[str]): The supported MIME types
+                for the output data.
+            inference_instances (list[str]): A list of the instance
+                types that are used to generate inferences in real-time (default: None).
+            transform_instances (list[str]): A list of the instance
+                types on which a transformation job can be run or on which an endpoint can be
+                deployed (default: None).
+
+        """
+        sagemaker_session = self.sagemaker_session or sagemaker.Session()
+        if containers is not None and image_uris is not None:
+            raise ValueError("Cannot have both containers and image_uris.")
+        if containers is None and image_uris is None:
+            raise ValueError("Should have either containers or image_uris for inference.")
+        container_def = []
+        if image_uris:
+            for uri in image_uris:
+                container_def.append(
+                    {
+                        "Image": uri,
+                    }
+                )
+        else:
+            container_def = containers
+
+        model_package_update_args = get_add_model_package_inference_args(
+            model_package_arn=self.model_package_arn,
+            name=name,
+            containers=container_def,
+            content_types=content_types,
+            description=description,
+            response_types=response_types,
+            inference_instances=inference_instances,
+            transform_instances=transform_instances,
+        )
+
+        sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 6b7a8dc2c7..3b2de0239e 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -6557,15 +6557,21 @@ def get_create_model_package_request(
     if task is not None:
         request_dict["Task"] = task
     if containers is not None:
-        if not all([content_types, response_types]):
-            raise ValueError(
-                "content_types and response_types " "must be provided if containers is present."
-            )
         inference_specification = {
             "Containers": containers,
-            "SupportedContentTypes": content_types,
-            "SupportedResponseMIMETypes": response_types,
         }
+        if content_types is not None:
+            inference_specification.update(
+                {
+                    "SupportedContentTypes": content_types,
+                }
+            )
+        if response_types is not None:
+            inference_specification.update(
+                {
+                    "SupportedResponseMIMETypes": response_types,
+                }
+            )
         if model_package_group_name is not None:
             if inference_instances is not None:
                 inference_specification.update(
@@ -6598,6 +6604,76 @@
     return request_dict
 
 
+def get_add_model_package_inference_args(
+    model_package_arn,
+    name,
+    containers=None,
+    content_types=None,
+    response_types=None,
+    inference_instances=None,
+    transform_instances=None,
+    description=None,
+):
+    """Get request dictionary for UpdateModelPackage API for additional inference.
+
+    Args:
+        model_package_arn (str): ARN of the model package to update.
+        name (str): Name to identify the additional inference specification
+        containers (list[dict]): A list of container definitions for the additional
+            inference specification, each including the Amazon ECR path of the Docker
+            image that contains the inference code.
+        description (str): Description for the additional inference specification
+        content_types (list[str]): The supported MIME types
+            for the input data.
+        response_types (list[str]): The supported MIME types
+            for the output data.
+        inference_instances (list[str]): A list of the instance
+            types that are used to generate inferences in real-time (default: None).
+        transform_instances (list[str]): A list of the instance
+            types on which a transformation job can be run or on which an endpoint can be
+            deployed (default: None).
+    """
+
+    request_dict = {}
+    inference_specification = {}
+
+    if containers is not None:
+        inference_specification.update({"Containers": containers})
+
+    if name is not None:
+        inference_specification.update({"Name": name})
+
+    if description is not None:
+        inference_specification.update({"Description": description})
+    if content_types is not None:
+        inference_specification.update(
+            {
+                "SupportedContentTypes": content_types,
+            }
+        )
+    if response_types is not None:
+        inference_specification.update(
+            {
+                "SupportedResponseMIMETypes": response_types,
+            }
+        )
+    if inference_instances is not None:
+        inference_specification.update(
+            {
+                "SupportedRealtimeInferenceInstanceTypes": inference_instances,
+            }
+        )
+    if transform_instances is not None:
+        inference_specification.update(
+            {
+                "SupportedTransformInstanceTypes": transform_instances,
+            }
+        )
+    request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]
+    request_dict.update({"ModelPackageArn": model_package_arn})
+    return request_dict
+
+
 def update_args(args: Dict[str, Any], **kwargs):
     """Updates the request arguments dict with the value if populated.
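
For reference, a minimal sketch of the request this helper assembles for the UpdateModelPackage API; the ARN, image URI, and MIME types below are placeholders for illustration, not values taken from this patch:

from sagemaker.session import get_add_model_package_inference_args

# Placeholder inputs for illustration only.
request = get_add_model_package_inference_args(
    model_package_arn="arn:aws:sagemaker:us-west-2:123456789012:model-package/example/1",
    name="Inference",
    containers=[{"Image": "<ecr-image-uri>"}],
    content_types=["text/csv"],
    response_types=["text/csv"],
)

# The helper returns the keyword arguments passed to sagemaker_client.update_model_package:
# {
#     "AdditionalInferenceSpecificationsToAdd": [
#         {
#             "Containers": [{"Image": "<ecr-image-uri>"}],
#             "Name": "Inference",
#             "SupportedContentTypes": ["text/csv"],
#             "SupportedResponseMIMETypes": ["text/csv"],
#         }
#     ],
#     "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package/example/1",
# }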
diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py
index 641056265e..1554825fc2 100644
--- a/tests/integ/test_model_package.py
+++ b/tests/integ/test_model_package.py
@@ -17,6 +17,7 @@
 from sagemaker.utils import unique_name_from_base
 from tests.integ import DATA_DIR
 from sagemaker.xgboost import XGBoostModel
+from sagemaker import image_uris
 
 _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
 
@@ -61,3 +62,45 @@ def test_update_approval_model_package(sagemaker_session):
     sagemaker_session.sagemaker_client.delete_model_package_group(
         ModelPackageGroupName=model_group_name
     )
+
+
+def test_inference_specification_addition(sagemaker_session):
+
+    model_group_name = unique_name_from_base("test-model-group")
+
+    sagemaker_session.sagemaker_client.create_model_package_group(
+        ModelPackageGroupName=model_group_name
+    )
+
+    xgb_model_data_s3 = sagemaker_session.upload_data(
+        path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    model = XGBoostModel(
+        model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
+    )
+
+    model_package = model.register(
+        content_types=["text/csv"],
+        response_types=["text/csv"],
+        inference_instances=["ml.m5.large"],
+        transform_instances=["ml.m5.large"],
+        model_package_group_name=model_group_name,
+    )
+
+    xgb_image = image_uris.retrieve(
+        "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
+    )
+    model_package.add_inference_specification(image_uris=[xgb_image], name="Inference")
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+    assert len(desc_model_package["AdditionalInferenceSpecifications"]) == 1
+    assert desc_model_package["AdditionalInferenceSpecifications"][0]["Name"] == "Inference"
+
+    sagemaker_session.sagemaker_client.delete_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+    sagemaker_session.sagemaker_client.delete_model_package_group(
+        ModelPackageGroupName=model_group_name
+    )
diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py
index cd2c3d1637..8be561030e 100644
--- a/tests/unit/sagemaker/model/test_model_package.py
+++ b/tests/unit/sagemaker/model/test_model_package.py
@@ -326,3 +326,76 @@ def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_session):
         update_approval_status.call_args_list[0][1]["approval_status"]
         == ModelApprovalStatusEnum.APPROVED
     )
+
+
+def test_update_customer_metadata(sagemaker_session):
+    model_package = ModelPackage(
+        role="role",
+        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
+        sagemaker_session=sagemaker_session,
+    )
+
+    customer_metadata_to_update = {
+        "Key": "Value",
+    }
+    model_package.update_customer_metadata(customer_metadata_properties=customer_metadata_to_update)
+
+    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
+        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
+        CustomerMetadataProperties=customer_metadata_to_update,
+    )
+
+
+def test_remove_customer_metadata(sagemaker_session):
+    model_package = ModelPackage(
+        role="role",
+        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
+        sagemaker_session=sagemaker_session,
+    )
+
+    customer_metadata_to_remove = ["Key"]
+
+    model_package.remove_customer_metadata_properties(
+        customer_metadata_properties_to_remove=customer_metadata_to_remove
+    )
+
+    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
+        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
+        CustomerMetadataPropertiesToRemove=customer_metadata_to_remove,
+    )
+
+
+def test_add_inference_specification(sagemaker_session):
+    model_package = ModelPackage(
+        role="role",
+        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
+        sagemaker_session=sagemaker_session,
+    )
+
+    image_uris = ["image_uri"]
+
+    containers = [{"Image": "image_uri"}]
+
+    try:
+        model_package.add_inference_specification(
+            image_uris=image_uris, name="Inference", containers=containers
+        )
+        assert False, "ValueError expected when both containers and image_uris are passed"
+    except ValueError as ve:
+        assert "Cannot have both containers and image_uris." in str(ve)
+
+    try:
+        model_package.add_inference_specification(name="Inference")
+        assert False, "ValueError expected when neither containers nor image_uris is passed"
+    except ValueError as ve:
+        assert "Should have either containers or image_uris for inference." in str(ve)
+
+    model_package.add_inference_specification(image_uris=image_uris, name="Inference")
+
+    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
+        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
+        AdditionalInferenceSpecificationsToAdd=[
+            {
+                "Containers": [{"Image": "image_uri"}],
+                "Name": "Inference",
+            }
+        ],
+    )
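
Taken together, a minimal usage sketch of the ModelPackage methods added by this patch; the role name, ARN, image URI, metadata key, and MIME types below are placeholders for illustration, not values from the change:

import sagemaker
from sagemaker import ModelPackage

# Placeholder identifiers for illustration only.
model_package = ModelPackage(
    role="SageMakerRole",
    model_package_arn="arn:aws:sagemaker:us-west-2:123456789012:model-package/example/1",
    sagemaker_session=sagemaker.Session(),
)

# Add or update customer metadata keys on the model package version.
model_package.update_customer_metadata({"Team": "ml-platform"})

# Remove a metadata key that is no longer needed.
model_package.remove_customer_metadata_properties(["Team"])

# Attach an additional inference specification built from an image URI;
# a list of container definitions could be passed via containers= instead, but not both.
model_package.add_inference_specification(
    image_uris=["<ecr-image-uri>"],
    name="Inference",
    content_types=["text/csv"],
    response_types=["text/csv"],
)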