@@ -133,11 +133,7 @@ def __init__(
133
133
py_version (str): Python version you want to use for executing your
134
134
model training code. Defaults to ``None``. Required unless
135
135
``image_uri`` is provided.
136
- image_uri (str): A Docker image URI. Defaults to None. For serverless
137
- inferece, it is required. More image information can be found in
138
- `Amazon SageMaker provided algorithms and Deep Learning Containers
139
- <https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html>`_.
140
- For instance based inference, if not specified, a
136
+ image_uri (str): A Docker image URI. Defaults to None. If not specified, a
141
137
default image for PyTorch will be used. If ``framework_version``
142
138
or ``py_version`` are ``None``, then ``image_uri`` is required. If
143
139
also ``None``, then a ``ValueError`` will be raised.
@@ -272,7 +268,7 @@ def deploy(
272
268
is not None. Otherwise, return None.
273
269
"""
274
270
275
- if not self .image_uri and instance_type .startswith ("ml.inf" ):
271
+ if not self .image_uri and instance_type is not None and instance_type .startswith ("ml.inf" ):
276
272
self .image_uri = self .serving_image_uri (
277
273
region_name = self .sagemaker_session .boto_session .region_name ,
278
274
instance_type = instance_type ,
@@ -365,7 +361,9 @@ def register(
365
361
drift_check_baselines = drift_check_baselines ,
366
362
)
367
363
368
- def prepare_container_def (self , instance_type = None , accelerator_type = None ):
364
+ def prepare_container_def (
365
+ self , instance_type = None , accelerator_type = None , serverless_inference_config = None
366
+ ):
369
367
"""A container definition with framework configuration set in model environment variables.
370
368
371
369
Args:
@@ -374,21 +372,27 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
374
372
accelerator_type (str): The Elastic Inference accelerator type to
375
373
deploy to the instance for loading and making inferences to the
376
374
model.
375
+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
376
+ Specifies configuration related to serverless endpoint. Instance type is
377
+ not provided in serverless inference. So this is used to find image URIs.
377
378
378
379
Returns:
379
380
dict[str, str]: A container definition object usable with the
380
381
CreateModel API.
381
382
"""
382
383
deploy_image = self .image_uri
383
384
if not deploy_image :
384
- if instance_type is None :
385
+ if instance_type is None and serverless_inference_config is None :
385
386
raise ValueError (
386
387
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
387
388
)
388
389
389
390
region_name = self .sagemaker_session .boto_session .region_name
390
391
deploy_image = self .serving_image_uri (
391
- region_name , instance_type , accelerator_type = accelerator_type
392
+ region_name ,
393
+ instance_type ,
394
+ accelerator_type = accelerator_type ,
395
+ serverless_inference_config = serverless_inference_config ,
392
396
)
393
397
394
398
deploy_key_prefix = model_code_key_prefix (self .key_prefix , self .name , deploy_image )
@@ -402,7 +406,13 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
402
406
deploy_image , self .repacked_model_data or self .model_data , deploy_env
403
407
)
404
408
405
- def serving_image_uri (self , region_name , instance_type , accelerator_type = None ):
409
+ def serving_image_uri (
410
+ self ,
411
+ region_name ,
412
+ instance_type = None ,
413
+ accelerator_type = None ,
414
+ serverless_inference_config = None ,
415
+ ):
406
416
"""Create a URI for the serving image.
407
417
408
418
Args:
@@ -412,6 +422,9 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
412
422
accelerator_type (str): The Elastic Inference accelerator type to
413
423
deploy to the instance for loading and making inferences to the
414
424
model.
425
+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
426
+ Specifies configuration related to serverless endpoint. Instance type is
427
+ not provided in serverless inference. So this is used to determine device type.
415
428
416
429
Returns:
417
430
str: The appropriate image URI based on the given parameters.
@@ -432,4 +445,5 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
432
445
accelerator_type = accelerator_type ,
433
446
image_scope = "inference" ,
434
447
base_framework_version = base_framework_version ,
448
+ serverless_inference_config = serverless_inference_config ,
435
449
)
0 commit comments