35
35
from sagemaker .serverless import ServerlessInferenceConfig
36
36
from sagemaker .transformer import Transformer
37
37
from sagemaker .jumpstart .utils import add_jumpstart_tags , get_jumpstart_base_name_if_jumpstart_model
38
- from sagemaker .utils import unique_name_from_base
38
+ from sagemaker .utils import (
39
+ unique_name_from_base ,
40
+ update_container_with_inference_params ,
41
+ )
39
42
from sagemaker .async_inference import AsyncInferenceConfig
40
43
from sagemaker .predictor_async import AsyncPredictor
41
44
from sagemaker .workflow import is_pipeline_variable
@@ -310,6 +313,12 @@ def register(
310
313
customer_metadata_properties = None ,
311
314
validation_specification = None ,
312
315
domain = None ,
316
+ task = None ,
317
+ sample_payload_url = None ,
318
+ framework = None ,
319
+ framework_version = None ,
320
+ nearest_model_name = None ,
321
+ data_input_configuration = None ,
313
322
):
314
323
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315
324
@@ -339,6 +348,18 @@ def register(
339
348
metadata properties (default: None).
340
349
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341
350
"MACHINE_LEARNING" (default: None).
351
+ sample_payload_url (str): The S3 path where the sample payload is stored
352
+ (default: None).
353
+ task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
354
+ "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
355
+ "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
356
+ framework (str): Machine learning framework of the model package container image
357
+ (default: None).
358
+ framework_version (str): Framework version of the Model Package Container Image
359
+ (default: None).
360
+ nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
361
+ Amazon SageMaker Inference Recommender (default: None).
362
+ data_input_configuration (str): Input object for the model (default: None).
342
363
343
364
Returns:
344
365
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -349,10 +370,22 @@ def register(
349
370
raise ValueError ("SageMaker Model Package cannot be created without model data." )
350
371
if image_uri is not None :
351
372
self .image_uri = image_uri
373
+
352
374
if model_package_group_name is not None :
353
375
container_def = self .prepare_container_def ()
376
+ update_container_with_inference_params (
377
+ framework = framework ,
378
+ framework_version = framework_version ,
379
+ nearest_model_name = nearest_model_name ,
380
+ data_input_configuration = data_input_configuration ,
381
+ container_obj = container_def ,
382
+ )
354
383
else :
355
- container_def = {"Image" : self .image_uri , "ModelDataUrl" : self .model_data }
384
+ container_def = {
385
+ "Image" : self .image_uri ,
386
+ "ModelDataUrl" : self .model_data ,
387
+ }
388
+
356
389
model_pkg_args = sagemaker .get_model_package_args (
357
390
content_types ,
358
391
response_types ,
@@ -370,6 +403,8 @@ def register(
370
403
customer_metadata_properties = customer_metadata_properties ,
371
404
validation_specification = validation_specification ,
372
405
domain = domain ,
406
+ sample_payload_url = sample_payload_url ,
407
+ task = task ,
373
408
)
374
409
model_package = self .sagemaker_session .create_model_package_from_containers (
375
410
** model_pkg_args
0 commit comments