diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 208dc208b4..98c43f348d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1301,6 +1301,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1334,6 +1340,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1371,6 +1389,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index b72f1b1af2..8814b72175 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -306,6 +306,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -337,6 +343,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -367,6 +385,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index bfa4caa6e0..60c766379b 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -35,7 +35,10 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model -from sagemaker.utils import unique_name_from_base +from sagemaker.utils import ( + unique_name_from_base, + update_container_with_inference_params, +) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable @@ -310,6 +313,12 @@ def register( customer_metadata_properties=None, validation_specification=None, domain=None, + task=None, + sample_payload_url=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -339,6 +348,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -349,10 +370,22 @@ def register( raise ValueError("SageMaker Model Package cannot be created without model data.") if image_uri is not None: self.image_uri = image_uri + if model_package_group_name is not None: container_def = self.prepare_container_def() + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_obj=container_def, + ) else: - container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data} + container_def = { + "Image": self.image_uri, + "ModelDataUrl": self.model_data, + } + model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, @@ -370,6 +403,8 @@ def register( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index fa2773bebb..60fc1d60d2 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -159,6 +159,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -188,6 +194,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -218,6 +236,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 75fae3bfc4..8cdb82ffe7 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -20,7 +20,10 @@ from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties from sagemaker.session import Session -from sagemaker.utils import name_from_image +from sagemaker.utils import ( + name_from_image, + update_container_with_inference_params, +) from sagemaker.transformer import Transformer from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -279,6 +282,12 @@ def register( drift_check_baselines: Optional[DriftCheckBaselines] = None, customer_metadata_properties: Optional[Dict[str, str]] = None, domain: Optional[str] = None, + sample_payload_url: Optional[str] = None, + task: Optional[str] = None, + framework: Optional[str] = None, + framework_version: Optional[str] = None, + nearest_model_name: Optional[str] = None, + data_input_configuration: Optional[str] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -308,6 +317,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -319,6 +340,13 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=container_def, + ) else: container_def = [ { @@ -344,6 +372,8 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 6e5d63c14d..b5e019f492 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -160,6 +160,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -189,6 +195,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -219,6 +237,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 461dfd8bab..eb158eab3d 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2822,6 +2822,8 @@ def create_model_package_from_containers( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -2851,6 +2853,11 @@ def create_model_package_from_containers( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ model_pkg_request = get_create_model_package_request( @@ -2870,6 +2877,8 @@ def create_model_package_from_containers( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) def submit(request): @@ -4241,6 +4250,8 @@ def get_model_package_args( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get arguments for create_model_package method. @@ -4273,6 +4284,11 @@ def get_model_package_args( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + Returns: dict: A dictionary of method argument names and values. """ @@ -4316,6 +4332,10 @@ def get_model_package_args( model_package_args["validation_specification"] = validation_specification if domain is not None: model_package_args["domain"] = domain + if sample_payload_url is not None: + model_package_args["sample_payload_url"] = sample_payload_url + if task is not None: + model_package_args["task"] = task return model_package_args @@ -4337,6 +4357,8 @@ def get_create_model_package_request( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -4367,6 +4389,10 @@ def get_create_model_package_request( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4394,6 +4420,10 @@ def get_create_model_package_request( request_dict["ValidationSpecification"] = validation_specification if domain is not None: request_dict["Domain"] = domain + if sample_payload_url is not None: + request_dict["SamplePayloadUrl"] = sample_payload_url + if task is not None: + request_dict["Task"] = task if containers is not None: if not all([content_types, response_types]): raise ValueError( diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 71fe048bf1..67f9d60175 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -154,6 +154,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +189,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -213,6 +231,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 3c4bf3343a..e5e6798a63 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -206,6 +206,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -235,6 +241,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -265,6 +283,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def deploy( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1d2e9fe5cb..ed5b3c5e75 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -722,3 +722,59 @@ def get_data_bucket(self, region_requested=None): get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix") + + +def update_container_with_inference_params( + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, + container_obj=None, + container_list=None, +): + """Function to check if inference recommender parameters exist and update container. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + container_obj (dict): object to be updated. + container_list (list): list to be updated. + + Returns: + dict: dict with inference recommender params + """ + + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + if container_list is not None: + for obj in container_list: + obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) + if container_obj is not None: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 7b8a3cdc25..7a0a399299 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -285,6 +285,8 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, **kwargs, ): """Constructor of a register model step. @@ -329,6 +331,11 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -360,6 +367,8 @@ def __init__( self.drift_check_baselines = drift_check_baselines self.customer_metadata_properties = customer_metadata_properties self.domain = domain + self.sample_payload_url = sample_payload_url + self.task = task self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -438,6 +447,8 @@ def arguments(self) -> RequestType: container_def_list=self.container_def_list, customer_metadata_properties=self.customer_metadata_properties, domain=self.domain, + sample_payload_url=self.sample_payload_url, + task=self.task, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index d52ddace87..dd9529916e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -27,6 +27,7 @@ from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy +from sagemaker.utils import update_container_with_inference_params @attr.s @@ -80,6 +81,12 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -123,6 +130,18 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: additional arguments to `create_model`. """ @@ -228,6 +247,14 @@ def __init__( ) ] + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=self.container_def_list, + ) + register_model_step = _RegisterModelStep( name=name, estimator=estimator, @@ -249,6 +276,8 @@ def __init__( retry_policies=register_model_step_retry_policies, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, **kwargs, ) if not repack_model: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index d8f1d9ab6c..d0f617a266 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -94,6 +94,13 @@ def test_conditional_pytorch_training_model_registration( good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1) in_condition_input = ParameterString(name="Foo", default_value="Foo") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) pytorch_estimator = PyTorch( @@ -120,6 +127,12 @@ def test_conditional_pytorch_training_model_registration( inference_instances=["*"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) model = Model( @@ -201,6 +214,13 @@ def test_mxnet_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + model = MXNetModel( entry_point=entry_point, source_dir=source_dir, @@ -219,6 +239,12 @@ def test_mxnet_model_registration( inference_instances=["ml.m5.xlarge"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -262,6 +288,13 @@ def test_sklearn_xgboost_sip_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = "ml.m5.xlarge" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # The instance_type should not be a pipeline variable # since it is used to retrieve image_uri in compile time (PySDK) sklearn_processor = SKLearnProcessor( @@ -412,6 +445,12 @@ def test_sklearn_xgboost_sip_model_registration( inference_instances=["ml.t2.medium", "ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="windturbine", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -555,6 +594,12 @@ def test_model_registration_with_drift_check_baselines( ) customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -568,6 +613,7 @@ def test_model_registration_with_drift_check_baselines( py_version="py3", role=role, ) + step_register = RegisterModel( name="MyRegisterModelStep", estimator=estimator, @@ -581,6 +627,12 @@ def test_model_registration_with_drift_check_baselines( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -652,6 +704,8 @@ def test_model_registration_with_drift_check_baselines( ) assert response["CustomerMetadataProperties"] == customer_metadata_properties assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url break finally: try: diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index d2954ede7b..90a9116c07 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -116,6 +116,12 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock): inference_instances=["ml.t2.medium", "ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) # The context should be cleaned up before return assert not pipeline_session_mock.context @@ -136,11 +142,16 @@ def test_pipeline_session_context_for_model_step_without_instance_types( source_dir=f"{DATA_DIR}", role=_ROLE, ) - register_step_args = model.register( content_types=["text/csv"], response_types=["text/csv"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) expected_output = { @@ -159,6 +170,12 @@ def test_pipeline_session_context_for_model_step_without_instance_types( name="ModelData", default_value="s3://my-bucket/file", ), + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, } ], "SupportedContentTypes": ["text/csv"], @@ -168,6 +185,8 @@ def test_pipeline_session_context_for_model_step_without_instance_types( }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } assert register_step_args.create_model_package_request == expected_output diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 9d41e70aca..4aa55fd068 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -368,6 +368,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): display_name="RegisterModelStep", depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -412,6 +418,8 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -433,6 +441,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): drift_check_baselines=drift_check_baselines, approval_status="Approved", description="description", + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -474,6 +488,8 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -502,6 +518,12 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): description="description", model=pipeline_model, depends_on=["TestStep"], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -517,11 +539,23 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Image": "fakeimage1", "ModelDataUrl": "Url1", "Environment": [{"k1": "v1"}, {"k2": "v2"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, { "Image": "fakeimage2", "ModelDataUrl": "Url2", "Environment": [{"k3": "v3"}, {"k4": "v4"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, ], "SupportedContentTypes": ["content_type"], @@ -550,6 +584,8 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -578,6 +614,12 @@ def test_register_model_with_model_repack_with_estimator( dependencies=[dummy_requirements], depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) request_dicts = register_model.request_dicts() @@ -680,6 +722,8 @@ def test_register_model_with_model_repack_with_estimator( "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } ) else: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 7e2e985845..78298025ea 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3109,6 +3109,12 @@ def test_register_default_image(sagemaker_session): response_types = ["application/json"] inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3116,6 +3122,12 @@ def test_register_default_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3132,6 +3144,8 @@ def test_register_default_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3153,11 +3167,23 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): model_package_name = "test-estimator-register-model" content_types = ["application/json"] response_types = ["application/json"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, response_types=response_types, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3174,6 +3200,8 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): "transform_instances": None, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3198,6 +3226,12 @@ def test_register_inference_image(sagemaker_session): inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] inference_image = "fake-inference-image" + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3205,7 +3239,13 @@ def test_register_inference_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, image_uri=inference_image, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3222,6 +3262,8 @@ def test_register_inference_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2040ed5d80..a02ea6eeca 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1705,6 +1705,12 @@ def test_create_model_with_both(expand_container_def, sagemaker_session): "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json"}, "Image": "mi-1", "ModelDataUrl": "s3://bucket/model_1.tar.gz", + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, {"Environment": {}, "Image": "mi-2", "ModelDataUrl": "s3://bucket/model_2.tar.gz"}, ] @@ -2387,6 +2393,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): description = "description" customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2402,6 +2410,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) expected_args = { "ModelPackageName": model_package_name, @@ -2420,6 +2430,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "DriftCheckBaselines": drift_check_baselines, "CustomerMetadataProperties": customer_metadata_properties, "Domain": domain, + "SamplePayloadUrl": sample_payload_url, + "Task": task, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)