@@ -272,7 +272,7 @@ def deploy(
272
272
is not None. Otherwise, return None.
273
273
"""
274
274
275
- if not self .image_uri and instance_type .startswith ("ml.inf" ):
275
+ if not self .image_uri and instance_type is not None and instance_type .startswith ("ml.inf" ):
276
276
self .image_uri = self .serving_image_uri (
277
277
region_name = self .sagemaker_session .boto_session .region_name ,
278
278
instance_type = instance_type ,
@@ -365,7 +365,9 @@ def register(
365
365
drift_check_baselines = drift_check_baselines ,
366
366
)
367
367
368
- def prepare_container_def (self , instance_type = None , accelerator_type = None ):
368
+ def prepare_container_def (
369
+ self , instance_type = None , accelerator_type = None , serverless_inference_config = None
370
+ ):
369
371
"""A container definition with framework configuration set in model environment variables.
370
372
371
373
Args:
@@ -381,14 +383,17 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
381
383
"""
382
384
deploy_image = self .image_uri
383
385
if not deploy_image :
384
- if instance_type is None :
386
+ if instance_type is None and serverless_inference_config is None :
385
387
raise ValueError (
386
388
"Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
387
389
)
388
390
389
391
region_name = self .sagemaker_session .boto_session .region_name
390
392
deploy_image = self .serving_image_uri (
391
- region_name , instance_type , accelerator_type = accelerator_type
393
+ region_name ,
394
+ instance_type ,
395
+ accelerator_type = accelerator_type ,
396
+ serverless_inference_config = serverless_inference_config ,
392
397
)
393
398
394
399
deploy_key_prefix = model_code_key_prefix (self .key_prefix , self .name , deploy_image )
@@ -402,7 +407,13 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
402
407
deploy_image , self .repacked_model_data or self .model_data , deploy_env
403
408
)
404
409
405
- def serving_image_uri (self , region_name , instance_type , accelerator_type = None ):
410
+ def serving_image_uri (
411
+ self ,
412
+ region_name ,
413
+ instance_type = None ,
414
+ accelerator_type = None ,
415
+ serverless_inference_config = None ,
416
+ ):
406
417
"""Create a URI for the serving image.
407
418
408
419
Args:
@@ -432,4 +443,5 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
432
443
accelerator_type = accelerator_type ,
433
444
image_scope = "inference" ,
434
445
base_framework_version = base_framework_version ,
446
+ serverless_inference_config = serverless_inference_config ,
435
447
)
0 commit comments