diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b2f482a4d7..eeb14518dd 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -195,6 +195,7 @@ def _get_model_package_args( marketplace_cert=False, approval_status=None, description=None, + tags=None, ): """Get arguments for session.create_model_package method. @@ -250,6 +251,8 @@ def _get_model_package_args( model_package_args["approval_status"] = approval_status if description is not None: model_package_args["description"] = description + if tags is not None: + model_package_args["tags"] = tags return model_package_args def _init_sagemaker_session_if_does_not_exist(self, instance_type): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 901d61f086..980a720ac1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2724,6 +2724,7 @@ def _get_create_model_package_request( marketplace_cert=False, approval_status="PendingManualApproval", description=None, + tags=None, ): """Get request dictionary for CreateModelPackage API. @@ -2761,6 +2762,8 @@ def _get_create_model_package_request( request_dict["ModelPackageGroupName"] = model_package_group_name if description is not None: request_dict["ModelPackageDescription"] = description + if tags is not None: + request_dict["Tags"] = tags if model_metrics: request_dict["ModelMetrics"] = model_metrics if metadata_properties: diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ceed4e0dec..a2ab24e3da 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -225,6 +225,7 @@ def __init__( compile_model_family=None, description=None, depends_on: List[str] = None, + tags=None, **kwargs, ): """Constructor of a register model step. @@ -264,6 +265,7 @@ def __init__( self.inference_instances = inference_instances self.transform_instances = transform_instances self.model_package_group_name = model_package_group_name + self.tags = tags self.model_metrics = model_metrics self.metadata_properties = metadata_properties self.approval_status = approval_status @@ -324,10 +326,12 @@ def arguments(self) -> RequestType: metadata_properties=self.metadata_properties, approval_status=self.approval_status, description=self.description, + tags=self.tags, ) request_dict = model.sagemaker_session._get_create_model_package_request( **model_package_args ) + # these are not available in the workflow service and will cause rejection if "CertifyForMarketplace" in request_dict: request_dict.pop("CertifyForMarketplace") diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index dd8f32b7fc..6ee048c0b2 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -67,6 +67,7 @@ def __init__( image_uri=None, compile_model_family=None, description=None, + tags=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -94,6 +95,10 @@ def __init__( compile_model_family (str): The instance family for the compiled model. If specified, a compiled model is used (default: None). description (str): Model Package description (default: None). + tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note + that tags will only be applied to newly created model package groups; if the + name of an existing group is passed to "model_package_group_name", + tags will not be applied. **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -134,6 +139,7 @@ def __init__( image_uri=image_uri, compile_model_family=compile_model_family, description=description, + tags=tags, **kwargs, ) if not repack_model: diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 9719e13aec..9ca14f4aaf 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics): approval_status="Approved", description="description", depends_on=["TestStep"], + tags=[{"Key": "myKey", "Value": "myValue"}], ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "Tags": [{"Key": "myKey", "Value": "myValue"}], }, }, ]