Skip to content

Commit 9c19524

Browse files
rahven14JoseJuan98
authored andcommitted
feature: include fields to work with inference recommender (aws#3174)
1 parent ae7b372 commit 9c19524

File tree

17 files changed

+510
-4
lines changed

17 files changed

+510
-4
lines changed

src/sagemaker/estimator.py

+24
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,12 @@ def register(
13101310
drift_check_baselines=None,
13111311
customer_metadata_properties=None,
13121312
domain=None,
1313+
sample_payload_url=None,
1314+
task=None,
1315+
framework=None,
1316+
framework_version=None,
1317+
nearest_model_name=None,
1318+
data_input_configuration=None,
13131319
**kwargs,
13141320
):
13151321
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1343,6 +1349,18 @@ def register(
13431349
metadata properties (default: None).
13441350
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
13451351
"MACHINE_LEARNING" (default: None).
1352+
sample_payload_url (str): The S3 path where the sample payload is stored
1353+
(default: None).
1354+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
1355+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
1356+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
1357+
framework (str): Machine learning framework of the model package container image
1358+
(default: None).
1359+
framework_version (str): Framework version of the Model Package Container Image
1360+
(default: None).
1361+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
1362+
Amazon SageMaker Inference Recommender (default: None).
1363+
data_input_configuration (str): Input object for the model (default: None).
13461364
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
13471365
``create_model()`` to accept ``**kwargs`` to customize model creation during
13481366
deploy. For more, see the implementation docs.
@@ -1380,6 +1398,12 @@ def register(
13801398
drift_check_baselines=drift_check_baselines,
13811399
customer_metadata_properties=customer_metadata_properties,
13821400
domain=domain,
1401+
sample_payload_url=sample_payload_url,
1402+
task=task,
1403+
framework=framework,
1404+
framework_version=framework_version,
1405+
nearest_model_name=nearest_model_name,
1406+
data_input_configuration=data_input_configuration,
13831407
)
13841408

13851409
@property

src/sagemaker/huggingface/model.py

+24
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def register(
306306
drift_check_baselines=None,
307307
customer_metadata_properties=None,
308308
domain=None,
309+
sample_payload_url=None,
310+
task=None,
311+
framework=None,
312+
framework_version=None,
313+
nearest_model_name=None,
314+
data_input_configuration=None,
309315
):
310316
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311317
@@ -337,6 +343,18 @@ def register(
337343
metadata properties (default: None).
338344
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
339345
"MACHINE_LEARNING" (default: None).
346+
sample_payload_url (str): The S3 path where the sample payload is stored
347+
(default: None).
348+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
349+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
350+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
351+
framework (str): Machine learning framework of the model package container image
352+
(default: None).
353+
framework_version (str): Framework version of the Model Package Container Image
354+
(default: None).
355+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
356+
Amazon SageMaker Inference Recommender (default: None).
357+
data_input_configuration (str): Input object for the model (default: None).
340358
341359
Returns:
342360
A `sagemaker.model.ModelPackage` instance.
@@ -367,6 +385,12 @@ def register(
367385
drift_check_baselines=drift_check_baselines,
368386
customer_metadata_properties=customer_metadata_properties,
369387
domain=domain,
388+
sample_payload_url=sample_payload_url,
389+
task=task,
390+
framework=framework,
391+
framework_version=framework_version,
392+
nearest_model_name=nearest_model_name,
393+
data_input_configuration=data_input_configuration,
370394
)
371395

372396
def prepare_container_def(

src/sagemaker/model.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
from sagemaker.serverless import ServerlessInferenceConfig
3636
from sagemaker.transformer import Transformer
3737
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
38-
from sagemaker.utils import unique_name_from_base
38+
from sagemaker.utils import (
39+
unique_name_from_base,
40+
update_container_with_inference_params,
41+
)
3942
from sagemaker.async_inference import AsyncInferenceConfig
4043
from sagemaker.predictor_async import AsyncPredictor
4144
from sagemaker.workflow import is_pipeline_variable
@@ -310,6 +313,12 @@ def register(
310313
customer_metadata_properties=None,
311314
validation_specification=None,
312315
domain=None,
316+
task=None,
317+
sample_payload_url=None,
318+
framework=None,
319+
framework_version=None,
320+
nearest_model_name=None,
321+
data_input_configuration=None,
313322
):
314323
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315324
@@ -339,6 +348,18 @@ def register(
339348
metadata properties (default: None).
340349
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341350
"MACHINE_LEARNING" (default: None).
351+
sample_payload_url (str): The S3 path where the sample payload is stored
352+
(default: None).
353+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
354+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
355+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
356+
framework (str): Machine learning framework of the model package container image
357+
(default: None).
358+
framework_version (str): Framework version of the Model Package Container Image
359+
(default: None).
360+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
361+
Amazon SageMaker Inference Recommender (default: None).
362+
data_input_configuration (str): Input object for the model (default: None).
342363
343364
Returns:
344365
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -349,10 +370,22 @@ def register(
349370
raise ValueError("SageMaker Model Package cannot be created without model data.")
350371
if image_uri is not None:
351372
self.image_uri = image_uri
373+
352374
if model_package_group_name is not None:
353375
container_def = self.prepare_container_def()
376+
update_container_with_inference_params(
377+
framework=framework,
378+
framework_version=framework_version,
379+
nearest_model_name=nearest_model_name,
380+
data_input_configuration=data_input_configuration,
381+
container_obj=container_def,
382+
)
354383
else:
355-
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
384+
container_def = {
385+
"Image": self.image_uri,
386+
"ModelDataUrl": self.model_data,
387+
}
388+
356389
model_pkg_args = sagemaker.get_model_package_args(
357390
content_types,
358391
response_types,
@@ -370,6 +403,8 @@ def register(
370403
customer_metadata_properties=customer_metadata_properties,
371404
validation_specification=validation_specification,
372405
domain=domain,
406+
sample_payload_url=sample_payload_url,
407+
task=task,
373408
)
374409
model_package = self.sagemaker_session.create_model_package_from_containers(
375410
**model_pkg_args

src/sagemaker/mxnet/model.py

+24
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def register(
159159
drift_check_baselines=None,
160160
customer_metadata_properties=None,
161161
domain=None,
162+
sample_payload_url=None,
163+
task=None,
164+
framework=None,
165+
framework_version=None,
166+
nearest_model_name=None,
167+
data_input_configuration=None,
162168
):
163169
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164170
@@ -188,6 +194,18 @@ def register(
188194
metadata properties (default: None).
189195
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
190196
"MACHINE_LEARNING" (default: None).
197+
sample_payload_url (str): The S3 path where the sample payload is stored
198+
(default: None).
199+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
200+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
201+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
202+
framework (str): Machine learning framework of the model package container image
203+
(default: None).
204+
framework_version (str): Framework version of the Model Package Container Image
205+
(default: None).
206+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
207+
Amazon SageMaker Inference Recommender (default: None).
208+
data_input_configuration (str): Input object for the model (default: None).
191209
192210
Returns:
193211
A `sagemaker.model.ModelPackage` instance.
@@ -218,6 +236,12 @@ def register(
218236
drift_check_baselines=drift_check_baselines,
219237
customer_metadata_properties=customer_metadata_properties,
220238
domain=domain,
239+
sample_payload_url=sample_payload_url,
240+
task=task,
241+
framework=framework,
242+
framework_version=framework_version,
243+
nearest_model_name=nearest_model_name,
244+
data_input_configuration=data_input_configuration,
221245
)
222246

223247
def prepare_container_def(

src/sagemaker/pipeline.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from sagemaker.drift_check_baselines import DriftCheckBaselines
2121
from sagemaker.metadata_properties import MetadataProperties
2222
from sagemaker.session import Session
23-
from sagemaker.utils import name_from_image
23+
from sagemaker.utils import (
24+
name_from_image,
25+
update_container_with_inference_params,
26+
)
2427
from sagemaker.transformer import Transformer
2528
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2629

@@ -279,6 +282,12 @@ def register(
279282
drift_check_baselines: Optional[DriftCheckBaselines] = None,
280283
customer_metadata_properties: Optional[Dict[str, str]] = None,
281284
domain: Optional[str] = None,
285+
sample_payload_url: Optional[str] = None,
286+
task: Optional[str] = None,
287+
framework: Optional[str] = None,
288+
framework_version: Optional[str] = None,
289+
nearest_model_name: Optional[str] = None,
290+
data_input_configuration: Optional[str] = None,
282291
):
283292
"""Creates a model package for creating SageMaker models or listing on Marketplace.
284293
@@ -308,6 +317,18 @@ def register(
308317
metadata properties (default: None).
309318
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
310319
"MACHINE_LEARNING" (default: None).
320+
sample_payload_url (str): The S3 path where the sample payload is stored
321+
(default: None).
322+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
323+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
324+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
325+
framework (str): Machine learning framework of the model package container image
326+
(default: None).
327+
framework_version (str): Framework version of the Model Package Container Image
328+
(default: None).
329+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
330+
Amazon SageMaker Inference Recommender (default: None).
331+
data_input_configuration (str): Input object for the model (default: None).
311332
312333
Returns:
313334
A `sagemaker.model.ModelPackage` instance.
@@ -319,6 +340,13 @@ def register(
319340
container_def = self.pipeline_container_def(
320341
inference_instances[0] if inference_instances else None
321342
)
343+
update_container_with_inference_params(
344+
framework=framework,
345+
framework_version=framework_version,
346+
nearest_model_name=nearest_model_name,
347+
data_input_configuration=data_input_configuration,
348+
container_list=container_def,
349+
)
322350
else:
323351
container_def = [
324352
{
@@ -344,6 +372,8 @@ def register(
344372
drift_check_baselines=drift_check_baselines,
345373
customer_metadata_properties=customer_metadata_properties,
346374
domain=domain,
375+
sample_payload_url=sample_payload_url,
376+
task=task,
347377
)
348378

349379
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

src/sagemaker/pytorch/model.py

+24
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ def register(
160160
drift_check_baselines=None,
161161
customer_metadata_properties=None,
162162
domain=None,
163+
sample_payload_url=None,
164+
task=None,
165+
framework=None,
166+
framework_version=None,
167+
nearest_model_name=None,
168+
data_input_configuration=None,
163169
):
164170
"""Creates a model package for creating SageMaker models or listing on Marketplace.
165171
@@ -189,6 +195,18 @@ def register(
189195
metadata properties (default: None).
190196
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
191197
"MACHINE_LEARNING" (default: None).
198+
sample_payload_url (str): The S3 path where the sample payload is stored
199+
(default: None).
200+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
201+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
202+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
203+
framework (str): Machine learning framework of the model package container image
204+
(default: None).
205+
framework_version (str): Framework version of the Model Package Container Image
206+
(default: None).
207+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
208+
Amazon SageMaker Inference Recommender (default: None).
209+
data_input_configuration (str): Input object for the model (default: None).
192210
193211
Returns:
194212
A `sagemaker.model.ModelPackage` instance.
@@ -219,6 +237,12 @@ def register(
219237
drift_check_baselines=drift_check_baselines,
220238
customer_metadata_properties=customer_metadata_properties,
221239
domain=domain,
240+
sample_payload_url=sample_payload_url,
241+
task=task,
242+
framework=framework,
243+
framework_version=framework_version,
244+
nearest_model_name=nearest_model_name,
245+
data_input_configuration=data_input_configuration,
222246
)
223247

224248
def prepare_container_def(

0 commit comments

Comments
 (0)