Skip to content

Commit c1157bd

Browse files
committed
fix: tags for jumpstart model package models
1 parent 569e85c commit c1157bd

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

src/sagemaker/jumpstart/model.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,32 @@ def _is_valid_model_id_hook():
310310

311311
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
312312

313-
def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument
313+
def _create_sagemaker_model(
314+
self,
315+
instance_type=None,
316+
accelerator_type=None,
317+
tags=None,
318+
serverless_inference_config=None,
319+
**kwargs,
320+
):
314321
"""Create a SageMaker Model Entity
315322
316323
Args:
317-
args: Positional arguments coming from the caller. This class does not require
318-
any so they are ignored.
319-
324+
instance_type (str): The EC2 instance type that this Model will be
325+
used for, this is only used to determine if the image needs GPU
326+
support or not.
327+
accelerator_type (str): Type of Elastic Inference accelerator to
328+
attach to an endpoint for model loading and inference, for
329+
example, 'ml.eia1.medium'. If not specified, no Elastic
330+
Inference accelerator will be attached to the endpoint.
331+
tags (List[dict[str, str]]): Optional. The list of tags to add to
332+
the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
333+
'tagvalue'}] For more information about tags, see
334+
https://boto3.amazonaws.com/v1/documentation
335+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
336+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
337+
Specifies configuration related to serverless endpoint. Instance type is
338+
not provided in serverless inference. So this is used to find image URIs.
320339
kwargs: Keyword arguments coming from the caller. This class does not require
321340
any so they are ignored.
322341
"""
@@ -347,10 +366,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
347366
container_def,
348367
vpc_config=self.vpc_config,
349368
enable_network_isolation=self.enable_network_isolation(),
350-
tags=kwargs.get("tags"),
369+
tags=tags,
351370
)
352371
else:
353-
super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
372+
super(JumpStartModel, self)._create_sagemaker_model(
373+
instance_type=instance_type,
374+
accelerator_type=accelerator_type,
375+
tags=tags,
376+
serverless_inference_config=serverless_inference_config,
377+
**kwargs,
378+
)
354379

355380
def deploy(
356381
self,

tests/unit/sagemaker/jumpstart/model/test_model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,10 @@ def test_jumpstart_model_package_arn(
589589

590590
model = JumpStartModel(model_id=model_id)
591591

592-
model.deploy()
592+
tag = {"Key": "foo", "Value": "bar"}
593+
tags = [tag]
594+
595+
model.deploy(tags=tags)
593596

594597
self.assertEqual(
595598
mock_session.return_value.create_model.call_args[0][2],
@@ -599,6 +602,8 @@ def test_jumpstart_model_package_arn(
599602
},
600603
)
601604

605+
self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"])
606+
602607
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
603608
@mock.patch("sagemaker.jumpstart.factory.model.Session")
604609
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")

0 commit comments

Comments
 (0)