@@ -310,13 +310,32 @@ def _is_valid_model_id_hook():
310
310
311
311
super (JumpStartModel , self ).__init__ (** model_init_kwargs .to_kwargs_dict ())
312
312
313
- def _create_sagemaker_model (self , * args , ** kwargs ): # pylint: disable=unused-argument
313
+ def _create_sagemaker_model (
314
+ self ,
315
+ instance_type = None ,
316
+ accelerator_type = None ,
317
+ tags = None ,
318
+ serverless_inference_config = None ,
319
+ ** kwargs ,
320
+ ):
314
321
"""Create a SageMaker Model Entity
315
322
316
323
Args:
317
- args: Positional arguments coming from the caller. This class does not require
318
- any so they are ignored.
319
-
324
+ instance_type (str): The EC2 instance type that this Model will be
325
+ used for. This is only used to determine whether the image needs
326
+ GPU support.
327
+ accelerator_type (str): Type of Elastic Inference accelerator to
328
+ attach to an endpoint for model loading and inference, for
329
+ example, 'ml.eia1.medium'. If not specified, no Elastic
330
+ Inference accelerator will be attached to the endpoint.
331
+ tags (List[dict[str, str]]): Optional. The list of tags to add to
332
+ the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
333
+ 'tagvalue'}] For more information about tags, see
334
+ https://boto3.amazonaws.com/v1/documentation
335
+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
336
+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
337
+ Specifies configuration related to the serverless endpoint. Because no instance
338
+ type is provided for serverless inference, this is used to find image URIs.
320
339
kwargs: Keyword arguments coming from the caller. This class does not use them
321
340
itself; they are forwarded to the parent implementation when it is called.
322
341
"""
@@ -347,10 +366,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
347
366
container_def ,
348
367
vpc_config = self .vpc_config ,
349
368
enable_network_isolation = self .enable_network_isolation (),
350
- tags = kwargs . get ( " tags" ) ,
369
+ tags = tags ,
351
370
)
352
371
else :
353
- super (JumpStartModel , self )._create_sagemaker_model (* args , ** kwargs )
372
+ super (JumpStartModel , self )._create_sagemaker_model (
373
+ instance_type = instance_type ,
374
+ accelerator_type = accelerator_type ,
375
+ tags = tags ,
376
+ serverless_inference_config = serverless_inference_config ,
377
+ ** kwargs ,
378
+ )
354
379
355
380
def deploy (
356
381
self ,
0 commit comments