fix: tags for jumpstart model package models #4061


Merged
merged 6 commits on Aug 14, 2023
Changes from 3 commits
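
For context, a minimal usage sketch of the behavior this PR fixes, based on the updated unit test below (the model_id value here is hypothetical; any JumpStart model backed by a model package ARN applies):

from sagemaker.jumpstart.model import JumpStartModel

# Hypothetical model ID; the fix applies to JumpStart models backed by a model package ARN.
model = JumpStartModel(model_id="example-model-package-model-id")

# Before this change, tags passed to deploy() did not reach the SageMaker Model entity
# created from the model package ARN; with this fix they are forwarded to create_model.
predictor = model.deploy(
    tags=[{"Key": "foo", "Value": "bar"}],
)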
39 changes: 33 additions & 6 deletions src/sagemaker/jumpstart/model.py
@@ -310,13 +310,34 @@ def _is_valid_model_id_hook():

super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())

def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument
def _create_sagemaker_model(
self,
instance_type=None,
accelerator_type=None,
tags=None,
serverless_inference_config=None,
**kwargs,
):
"""Create a SageMaker Model Entity

Args:
args: Positional arguments coming from the caller. This class does not require
any so they are ignored.

instance_type (str): Optional. The EC2 instance type that this Model will be
used for, this is only used to determine if the image needs GPU
support or not. (Default: None).
accelerator_type (str): Optional. Type of Elastic Inference accelerator to
attach to an endpoint for model loading and inference, for
example, 'ml.eia1.medium'. If not specified, no Elastic
Inference accelerator will be attached to the endpoint. (Default: None).
tags (List[dict[str, str]]): Optional. The list of tags to add to
the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
'tagvalue'}] For more information about tags, see
https://boto3.amazonaws.com/v1/documentation
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
(Default: None).
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Optional. Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
(Default: None).
kwargs: Keyword arguments coming from the caller. This class does not require
any so they are ignored.
"""
@@ -347,10 +368,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
container_def,
vpc_config=self.vpc_config,
enable_network_isolation=self.enable_network_isolation(),
tags=kwargs.get("tags"),
tags=tags,
)
else:
super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
super(JumpStartModel, self)._create_sagemaker_model(
instance_type=instance_type,
accelerator_type=accelerator_type,
tags=tags,
Contributor
Sanity check: will tags value end up being None here? Or will the values in kwargs cause a keyword argument name conflict?

Member Author
tags cannot be in kwargs since it's an explicit argument. kwargs will only contain new parameters should they get added in the future, to ensure it gets passed to the parent class.

Contributor
Just confirming: in the previous behavior, kwargs.get("tags") returned None because the tags key would not have existed.

Member
+1. Just to make sure on the above: if it does, maybe we should append the tags from kwargs.

serverless_inference_config=serverless_inference_config,
**kwargs,
)
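
To illustrate the exchange above, a minimal standalone sketch (toy function names, not the SDK code) of why the previous signature lost the tags and why an explicit parameter cannot collide with kwargs:

def old_create(*args, **kwargs):
    # deploy() used to pass arguments positionally, so tags landed in args
    # and kwargs.get("tags") returned None.
    return kwargs.get("tags")


def new_create(instance_type=None, accelerator_type=None, tags=None,
               serverless_inference_config=None, **kwargs):
    # An explicit keyword parameter is bound by name and can never also appear
    # in **kwargs, so there is no keyword-argument name conflict.
    return tags


tags = [{"Key": "foo", "Value": "bar"}]
print(old_create("ml.p2.xlarge", None, tags, None))          # prints None: tags were dropped
print(new_create(instance_type="ml.p2.xlarge", tags=tags))   # prints the tags list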

def deploy(
self,
5 changes: 4 additions & 1 deletion src/sagemaker/model.py
@@ -1375,7 +1375,10 @@ def deploy(
self._base_name = "-".join((self._base_name, compiled_model_suffix))

self._create_sagemaker_model(
instance_type, accelerator_type, tags, serverless_inference_config
instance_type=instance_type,
accelerator_type=accelerator_type,
tags=tags,
serverless_inference_config=serverless_inference_config,
)

serverless_inference_config_dict = (
7 changes: 6 additions & 1 deletion tests/unit/sagemaker/jumpstart/model/test_model.py
@@ -589,7 +589,10 @@ def test_jumpstart_model_package_arn(

model = JumpStartModel(model_id=model_id)

model.deploy()
tag = {"Key": "foo", "Value": "bar"}
tags = [tag]

model.deploy(tags=tags)

self.assertEqual(
mock_session.return_value.create_model.call_args[0][2],
@@ -599,6 +602,8 @@
},
)

self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"])

@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/model/test_model_package.py
@@ -197,7 +197,7 @@ def test_create_sagemaker_model_include_tags(sagemaker_session):
sagemaker_session=sagemaker_session,
)

model_package._create_sagemaker_model(tags=tags)
model_package.deploy(tags=tags, instance_type="ml.p2.xlarge", initial_instance_count=1)

sagemaker_session.create_model.assert_called_with(
model_name,