Skip to content

Commit 4e6dd93

Browse files
committed
feature: include fields to work with inference recommender
1 parent d9463d3 commit 4e6dd93

File tree

14 files changed

+372
-2
lines changed

14 files changed

+372
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,8 @@ def register(
13011301
drift_check_baselines=None,
13021302
customer_metadata_properties=None,
13031303
domain=None,
1304+
sample_payload_url=None,
1305+
task=None,
13041306
**kwargs,
13051307
):
13061308
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1334,6 +1336,11 @@ def register(
13341336
metadata properties (default: None).
13351337
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
13361338
"MACHINE_LEARNING" (default: None).
1339+
sample_payload_url (str): The S3 path where the sample payload is stored
1340+
(default: None).
1341+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
1342+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
1343+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
13371344
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
13381345
``create_model()`` to accept ``**kwargs`` to customize model creation during
13391346
deploy. For more, see the implementation docs.
@@ -1371,6 +1378,8 @@ def register(
13711378
drift_check_baselines=drift_check_baselines,
13721379
customer_metadata_properties=customer_metadata_properties,
13731380
domain=domain,
1381+
sample_payload_url=sample_payload_url,
1382+
task=task,
13741383
)
13751384

13761385
@property

src/sagemaker/huggingface/model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def register(
306306
drift_check_baselines=None,
307307
customer_metadata_properties=None,
308308
domain=None,
309+
sample_payload_url=None,
310+
task=None,
311+
framework=None,
312+
framework_version=None,
313+
nearest_model_name=None,
314+
data_input_configuration=None,
309315
):
310316
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311317
@@ -337,6 +343,18 @@ def register(
337343
metadata properties (default: None).
338344
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
339345
"MACHINE_LEARNING" (default: None).
346+
sample_payload_url (str): The S3 path where the sample payload is stored
347+
(default: None).
348+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
349+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
350+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
351+
framework (str): Machine learning framework of the model package container image
352+
(default: None).
353+
framework_version (str): Framework version of the Model Package Container Image
354+
(default: None).
355+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
356+
Amazon SageMaker Inference Recommender (default: None).
357+
data_input_configuration (str): Input object for the model (default: None).
340358
341359
Returns:
342360
A `sagemaker.model.ModelPackage` instance.
@@ -367,6 +385,12 @@ def register(
367385
drift_check_baselines=drift_check_baselines,
368386
customer_metadata_properties=customer_metadata_properties,
369387
domain=domain,
388+
sample_payload_url=sample_payload_url,
389+
task=task,
390+
framework=framework,
391+
framework_version=framework_version,
392+
nearest_model_name=nearest_model_name,
393+
data_input_configuration=data_input_configuration,
370394
)
371395

372396
def prepare_container_def(

src/sagemaker/model.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ def register(
310310
customer_metadata_properties=None,
311311
validation_specification=None,
312312
domain=None,
313+
task=None,
314+
sample_payload_url=None,
315+
framework=None,
316+
framework_version=None,
317+
nearest_model_name=None,
318+
data_input_configuration=None,
313319
):
314320
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315321
@@ -339,6 +345,18 @@ def register(
339345
metadata properties (default: None).
340346
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341347
"MACHINE_LEARNING" (default: None).
348+
sample_payload_url (str): The S3 path where the sample payload is stored
349+
(default: None).
350+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
351+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
352+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
353+
framework (str): Machine learning framework of the model package container image
354+
(default: None).
355+
framework_version (str): Framework version of the Model Package Container Image
356+
(default: None).
357+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
358+
Amazon SageMaker Inference Recommender (default: None).
359+
data_input_configuration (str): Input object for the model (default: None).
342360
343361
Returns:
344362
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -352,7 +370,20 @@ def register(
352370
if model_package_group_name is not None:
353371
container_def = self.prepare_container_def()
354372
else:
355-
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
373+
container_def = {
374+
"Image": self.image_uri,
375+
"ModelDataUrl": self.model_data,
376+
}
377+
container_def.update(
378+
{
379+
"Framework": framework,
380+
"FrameworkVersion": framework_version,
381+
"NearestModelName": nearest_model_name,
382+
"ModelInput": {
383+
"DataInputConfig": data_input_configuration,
384+
},
385+
}
386+
)
356387
model_pkg_args = sagemaker.get_model_package_args(
357388
content_types,
358389
response_types,
@@ -370,6 +401,8 @@ def register(
370401
customer_metadata_properties=customer_metadata_properties,
371402
validation_specification=validation_specification,
372403
domain=domain,
404+
sample_payload_url=sample_payload_url,
405+
task=task,
373406
)
374407
model_package = self.sagemaker_session.create_model_package_from_containers(
375408
**model_pkg_args

src/sagemaker/mxnet/model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def register(
159159
drift_check_baselines=None,
160160
customer_metadata_properties=None,
161161
domain=None,
162+
sample_payload_url=None,
163+
task=None,
164+
framework=None,
165+
framework_version=None,
166+
nearest_model_name=None,
167+
data_input_configuration=None,
162168
):
163169
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164170
@@ -188,6 +194,18 @@ def register(
188194
metadata properties (default: None).
189195
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
190196
"MACHINE_LEARNING" (default: None).
197+
sample_payload_url (str): The S3 path where the sample payload is stored
198+
(default: None).
199+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
200+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
201+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
202+
framework (str): Machine learning framework of the model package container image
203+
(default: None).
204+
framework_version (str): Framework version of the Model Package Container Image
205+
(default: None).
206+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
207+
Amazon SageMaker Inference Recommender (default: None).
208+
data_input_configuration (str): Input object for the model (default: None).
191209
192210
Returns:
193211
A `sagemaker.model.ModelPackage` instance.
@@ -218,6 +236,12 @@ def register(
218236
drift_check_baselines=drift_check_baselines,
219237
customer_metadata_properties=customer_metadata_properties,
220238
domain=domain,
239+
sample_payload_url=sample_payload_url,
240+
task=task,
241+
framework=framework,
242+
framework_version=framework_version,
243+
nearest_model_name=nearest_model_name,
244+
data_input_configuration=data_input_configuration,
221245
)
222246

223247
def prepare_container_def(

src/sagemaker/pipeline.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ def register(
279279
drift_check_baselines: Optional[DriftCheckBaselines] = None,
280280
customer_metadata_properties: Optional[Dict[str, str]] = None,
281281
domain: Optional[str] = None,
282+
sample_payload_url: Optional[str] = None,
283+
task: Optional[str] = None,
284+
framework: Optional[str] = None,
285+
framework_version: Optional[str] = None,
286+
nearest_model_name: Optional[str] = None,
287+
data_input_configuration: Optional[str] = None,
282288
):
283289
"""Creates a model package for creating SageMaker models or listing on Marketplace.
284290
@@ -308,6 +314,18 @@ def register(
308314
metadata properties (default: None).
309315
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
310316
"MACHINE_LEARNING" (default: None).
317+
sample_payload_url (str): The S3 path where the sample payload is stored
318+
(default: None).
319+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
320+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
321+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
322+
framework (str): Machine learning framework of the model package container image
323+
(default: None).
324+
framework_version (str): Framework version of the Model Package Container Image
325+
(default: None).
326+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
327+
Amazon SageMaker Inference Recommender (default: None).
328+
data_input_configuration (str): Input object for the model (default: None).
311329
312330
Returns:
313331
A `sagemaker.model.ModelPackage` instance.
@@ -319,11 +337,28 @@ def register(
319337
container_def = self.pipeline_container_def(
320338
inference_instances[0] if inference_instances else None
321339
)
340+
container_def[0].update(
341+
{
342+
"Framework": framework,
343+
"FrameworkVersion": framework_version,
344+
"NearestModelName": nearest_model_name,
345+
"ModelInput": {
346+
"DataInputConfig": data_input_configuration,
347+
},
348+
}
349+
)
322350
else:
323351
container_def = [
324352
{
325353
"Image": image_uri or model.image_uri,
326354
"ModelDataUrl": model.model_data,
355+
"Framework": framework or model.framework,
356+
"FrameworkVersion": framework_version or model.framework_version,
357+
"NearestModelName": nearest_model_name or model.nearest_model_name,
358+
"ModelInput": {
359+
"DataInputConfig": data_input_configuration
360+
or model.data_input_configuration
361+
},
327362
}
328363
for model in self.models
329364
]
@@ -344,6 +379,8 @@ def register(
344379
drift_check_baselines=drift_check_baselines,
345380
customer_metadata_properties=customer_metadata_properties,
346381
domain=domain,
382+
sample_payload_url=sample_payload_url,
383+
task=task,
347384
)
348385

349386
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

src/sagemaker/pytorch/model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ def register(
160160
drift_check_baselines=None,
161161
customer_metadata_properties=None,
162162
domain=None,
163+
sample_payload_url=None,
164+
task=None,
165+
framework=None,
166+
framework_version=None,
167+
nearest_model_name=None,
168+
data_input_configuration=None,
163169
):
164170
"""Creates a model package for creating SageMaker models or listing on Marketplace.
165171
@@ -189,6 +195,18 @@ def register(
189195
metadata properties (default: None).
190196
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
191197
"MACHINE_LEARNING" (default: None).
198+
sample_payload_url (str): The S3 path where the sample payload is stored
199+
(default: None).
200+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
201+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
202+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
203+
framework (str): Machine learning framework of the model package container image
204+
(default: None).
205+
framework_version (str): Framework version of the Model Package Container Image
206+
(default: None).
207+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
208+
Amazon SageMaker Inference Recommender (default: None).
209+
data_input_configuration (str): Input object for the model (default: None).
192210
193211
Returns:
194212
A `sagemaker.model.ModelPackage` instance.
@@ -219,6 +237,12 @@ def register(
219237
drift_check_baselines=drift_check_baselines,
220238
customer_metadata_properties=customer_metadata_properties,
221239
domain=domain,
240+
sample_payload_url=sample_payload_url,
241+
task=task,
242+
framework=framework,
243+
framework_version=framework_version,
244+
nearest_model_name=nearest_model_name,
245+
data_input_configuration=data_input_configuration,
222246
)
223247

224248
def prepare_container_def(

0 commit comments

Comments
 (0)