From 4e6dd93d97086baf5917536767c067cbf6b9b47b Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Thu, 16 Jun 2022 15:13:24 +0530 Subject: [PATCH 1/9] feature: include fields to work with inference recommender --- src/sagemaker/estimator.py | 9 ++++ src/sagemaker/huggingface/model.py | 24 +++++++++ src/sagemaker/model.py | 35 ++++++++++++- src/sagemaker/mxnet/model.py | 24 +++++++++ src/sagemaker/pipeline.py | 37 +++++++++++++ src/sagemaker/pytorch/model.py | 24 +++++++++ src/sagemaker/session.py | 52 +++++++++++++++++++ src/sagemaker/sklearn/model.py | 24 +++++++++ src/sagemaker/tensorflow/model.py | 24 +++++++++ src/sagemaker/workflow/_utils.py | 11 ++++ src/sagemaker/workflow/step_collections.py | 30 +++++++++++ .../test_model_create_and_registration.py | 47 +++++++++++++++++ .../workflow/test_pipeline_session.py | 21 +++++++- tests/unit/test_session.py | 12 +++++ 14 files changed, 372 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 208dc208b4..ecdc5b117c 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1301,6 +1301,8 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1334,6 +1336,11 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1371,6 +1378,8 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index b72f1b1af2..8814b72175 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -306,6 +306,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -337,6 +343,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -367,6 +385,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index bfa4caa6e0..4dec833bdf 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -310,6 +310,12 @@ def register( customer_metadata_properties=None, validation_specification=None, domain=None, + task=None, + sample_payload_url=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -339,6 +345,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -352,7 +370,20 @@ def register( if model_package_group_name is not None: container_def = self.prepare_container_def() else: - container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data} + container_def = { + "Image": self.image_uri, + "ModelDataUrl": self.model_data, + } + container_def.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, @@ -370,6 +401,8 @@ def register( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index fa2773bebb..60fc1d60d2 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -159,6 +159,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -188,6 +194,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -218,6 +236,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 75fae3bfc4..6f5e144ca6 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -279,6 +279,12 @@ def register( drift_check_baselines: Optional[DriftCheckBaselines] = None, customer_metadata_properties: Optional[Dict[str, str]] = None, domain: Optional[str] = None, + sample_payload_url: Optional[str] = None, + task: Optional[str] = None, + framework: Optional[str] = None, + framework_version: Optional[str] = None, + nearest_model_name: Optional[str] = None, + data_input_configuration: Optional[str] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -308,6 +314,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -319,11 +337,28 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) + container_def[0].update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) else: container_def = [ { "Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data, + "Framework": framework or model.framework, + "FrameworkVersion": framework_version or model.framework_version, + "NearestModelName": nearest_model_name or model.nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration + or model.data_input_configuration + }, } for model in self.models ] @@ -344,6 +379,8 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 6e5d63c14d..b5e019f492 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -160,6 +160,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -189,6 +195,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -219,6 +237,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 461dfd8bab..f0827d2e73 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2822,6 +2822,8 @@ def create_model_package_from_containers( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -2851,6 +2853,11 @@ def create_model_package_from_containers( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ model_pkg_request = get_create_model_package_request( @@ -2870,6 +2877,8 @@ def create_model_package_from_containers( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) def submit(request): @@ -4241,6 +4250,8 @@ def get_model_package_args( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get arguments for create_model_package method. @@ -4273,15 +4284,42 @@ def get_model_package_args( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: dict: A dictionary of method argument names and values. """ if container_def_list is not None: + container_def_list[0].update( + { + "Framework": container_def_list[0]["Framework"], + "FrameworkVersion": container_def_list[0]["FrameworkVersion"], + "NearestModelName": container_def_list[0]["NearestModelName"], + "ModelInput": { + "DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"], + }, + } + ) containers = container_def_list else: container = { "Image": image_uri, "ModelDataUrl": model_data, + "Framework": container_def_list[0]["Framework"], + "FrameworkVersion": container_def_list[0]["FrameworkVersion"], + "NearestModelName": container_def_list[0]["NearestModelName"], + "ModelInput": { + "DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"], + }, } containers = [container] @@ -4316,6 +4354,10 @@ def get_model_package_args( model_package_args["validation_specification"] = validation_specification if domain is not None: model_package_args["domain"] = domain + if sample_payload_url is not None: + model_package_args["sample_payload_url"] = sample_payload_url + if task is not None: + model_package_args["task"] = task return model_package_args @@ -4337,6 +4379,8 @@ def get_create_model_package_request( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -4367,6 +4411,10 @@ def get_create_model_package_request( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4394,6 +4442,10 @@ def get_create_model_package_request( request_dict["ValidationSpecification"] = validation_specification if domain is not None: request_dict["Domain"] = domain + if sample_payload_url is not None: + request_dict["SamplePayloadUrl"] = sample_payload_url + if task is not None: + request_dict["Task"] = task if containers is not None: if not all([content_types, response_types]): raise ValueError( diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 71fe048bf1..67f9d60175 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -154,6 +154,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +189,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -213,6 +231,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 3c4bf3343a..e5e6798a63 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -206,6 +206,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -235,6 +241,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -265,6 +283,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 8a0523c73a..a9ea95371f 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -285,6 +285,8 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, **kwargs, ): """Constructor of a register model step. @@ -329,6 +331,11 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -360,6 +367,8 @@ def __init__( self.drift_check_baselines = drift_check_baselines self.customer_metadata_properties = customer_metadata_properties self.domain = domain + self.sample_payload_url = sample_payload_url + self.task = task self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -438,6 +447,8 @@ def arguments(self) -> RequestType: container_def_list=self.container_def_list, customer_metadata_properties=self.customer_metadata_properties, domain=self.domain, + sample_payload_url=self.sample_payload_url, + task=self.task, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index bc7deb4fa3..3356a054ee 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -81,6 +81,12 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -124,6 +130,18 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: additional arguments to `create_model`. """ @@ -228,6 +246,16 @@ def __init__( inference_instances[0] if inference_instances else None ) ] + self.container_def_list[0].update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) register_model_step = _RegisterModelStep( name=name, @@ -250,6 +278,8 @@ def __init__( retry_policies=register_model_step_retry_policies, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, **kwargs, ) if not repack_model: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index d8f1d9ab6c..265ba724ad 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -555,6 +555,12 @@ def test_model_registration_with_drift_check_baselines( ) customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -568,8 +574,28 @@ def test_model_registration_with_drift_check_baselines( py_version="py3", role=role, ) + + base_dir = os.path.join(DATA_DIR, "mxnet_mnist") + source_dir = os.path.join(base_dir, "code") + entry_point = os.path.join(source_dir, "inference.py") + mx_mnist_model_data = os.path.join(base_dir, "model.tar.gz") + + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + + model = MXNetModel( + entry_point=entry_point, + source_dir=source_dir, + role=role, + model_data=mx_mnist_model_data, + framework_version="1.4.0", + py_version="py3", + sagemaker_session=sagemaker_session, + ) + step_register = RegisterModel( name="MyRegisterModelStep", + model=model, estimator=estimator, model_data=model_uri_param, content_types=["application/json"], @@ -581,6 +607,12 @@ def test_model_registration_with_drift_check_baselines( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -652,6 +684,21 @@ def test_model_registration_with_drift_check_baselines( ) assert response["CustomerMetadataProperties"] == customer_metadata_properties assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["InferenceSpecification"]["Containers"][0]["Framework"] == framework + assert ( + response["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] + == framework_version + ) + assert ( + response["InferenceSpecification"]["Containers"][0]["NearestModelName"] + == nearest_model_name + ) + assert ( + response["InferenceSpecification"]["Containers"][0]["ModelInput"]["DataInputConfig"] + == data_input_configuration + ) break finally: try: diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index d2954ede7b..90a9116c07 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -116,6 +116,12 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock): inference_instances=["ml.t2.medium", "ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) # The context should be cleaned up before return assert not pipeline_session_mock.context @@ -136,11 +142,16 @@ def test_pipeline_session_context_for_model_step_without_instance_types( source_dir=f"{DATA_DIR}", role=_ROLE, ) - register_step_args = model.register( content_types=["text/csv"], response_types=["text/csv"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) expected_output = { @@ -159,6 +170,12 @@ def test_pipeline_session_context_for_model_step_without_instance_types( name="ModelData", default_value="s3://my-bucket/file", ), + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, } ], "SupportedContentTypes": ["text/csv"], @@ -168,6 +185,8 @@ def test_pipeline_session_context_for_model_step_without_instance_types( }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } assert register_step_args.create_model_package_request == expected_output diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2040ed5d80..a02ea6eeca 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1705,6 +1705,12 @@ def test_create_model_with_both(expand_container_def, sagemaker_session): "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json"}, "Image": "mi-1", "ModelDataUrl": "s3://bucket/model_1.tar.gz", + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, {"Environment": {}, "Image": "mi-2", "ModelDataUrl": "s3://bucket/model_2.tar.gz"}, ] @@ -2387,6 +2393,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): description = "description" customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2402,6 +2410,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) expected_args = { "ModelPackageName": model_package_name, @@ -2420,6 +2430,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "DriftCheckBaselines": drift_check_baselines, "CustomerMetadataProperties": customer_metadata_properties, "Domain": domain, + "SamplePayloadUrl": sample_payload_url, + "Task": task, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From 27a67e09dd67599e24ad0cd0d2bab18354173a3b Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Thu, 16 Jun 2022 22:40:13 +0530 Subject: [PATCH 2/9] fix: fixed failing UT's --- src/sagemaker/estimator.py | 15 +++++ src/sagemaker/session.py | 8 +-- src/sagemaker/workflow/step_collections.py | 22 +++---- .../workflow/test_step_collections.py | 65 +++++++++++++++++++ tests/unit/test_estimator.py | 60 +++++++++++++++++ 5 files changed, 155 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ecdc5b117c..98c43f348d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1303,6 +1303,10 @@ def register( domain=None, sample_payload_url=None, task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1341,6 +1345,13 @@ def register( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1380,6 +1391,10 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) @property diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f0827d2e73..a5c40c4e97 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4314,11 +4314,11 @@ def get_model_package_args( container = { "Image": image_uri, "ModelDataUrl": model_data, - "Framework": container_def_list[0]["Framework"], - "FrameworkVersion": container_def_list[0]["FrameworkVersion"], - "NearestModelName": container_def_list[0]["NearestModelName"], + "Framework": None, + "FrameworkVersion": None, + "NearestModelName": None, "ModelInput": { - "DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"], + "DataInputConfig": None, }, } containers = [container] diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 3356a054ee..b4825fd7a9 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -246,17 +246,17 @@ def __init__( inference_instances[0] if inference_instances else None ) ] - self.container_def_list[0].update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) - + for container_obj in self.container_def_list: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) register_model_step = _RegisterModelStep( name=name, estimator=estimator, diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 9d41e70aca..747038b081 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -368,6 +368,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): display_name="RegisterModelStep", depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -383,6 +389,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): { "Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", + "Framework": None, + "FrameworkVersion": None, + "NearestModelName": None, + "ModelInput": { + "DataInputConfig": None, + }, } ], "SupportedContentTypes": ["content_type"], @@ -412,6 +424,8 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -433,6 +447,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): drift_check_baselines=drift_check_baselines, approval_status="Approved", description="description", + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -446,6 +466,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): { "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", + "Framework": None, + "FrameworkVersion": None, + "NearestModelName": None, + "ModelInput": { + "DataInputConfig": None, + }, } ], "SupportedContentTypes": ["content_type"], @@ -474,6 +500,8 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -502,6 +530,12 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): description="description", model=pipeline_model, depends_on=["TestStep"], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -517,11 +551,23 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Image": "fakeimage1", "ModelDataUrl": "Url1", "Environment": [{"k1": "v1"}, {"k2": "v2"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, { "Image": "fakeimage2", "ModelDataUrl": "Url2", "Environment": [{"k3": "v3"}, {"k4": "v4"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, ], "SupportedContentTypes": ["content_type"], @@ -550,6 +596,8 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -578,6 +626,12 @@ def test_register_model_with_model_repack_with_estimator( dependencies=[dummy_requirements], depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) request_dicts = register_model.request_dicts() @@ -649,6 +703,15 @@ def test_register_model_with_model_repack_with_estimator( assert isinstance( arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties ) + assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] == None + assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] == None + assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] == None + assert ( + arguments["InferenceSpecification"]["Containers"][0]["ModelInput"][ + "DataInputConfig" + ] + == None + ) del arguments["InferenceSpecification"]["Containers"] assert ordered(arguments) == ordered( { @@ -680,6 +743,8 @@ def test_register_model_with_model_repack_with_estimator( "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } ) else: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 7e2e985845..6d37f0b443 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3109,6 +3109,12 @@ def test_register_default_image(sagemaker_session): response_types = ["application/json"] inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3116,6 +3122,12 @@ def test_register_default_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3124,6 +3136,12 @@ def test_register_default_image(sagemaker_session): { "Image": estimator.image_uri, "ModelDataUrl": estimator.model_data, + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_config, + }, } ], "content_types": content_types, @@ -3132,6 +3150,8 @@ def test_register_default_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3153,11 +3173,23 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): model_package_name = "test-estimator-register-model" content_types = ["application/json"] response_types = ["application/json"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, response_types=response_types, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3166,6 +3198,12 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): { "Image": estimator.image_uri, "ModelDataUrl": estimator.model_data, + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_config, + }, } ], "content_types": content_types, @@ -3174,6 +3212,8 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): "transform_instances": None, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3198,6 +3238,12 @@ def test_register_inference_image(sagemaker_session): inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] inference_image = "fake-inference-image" + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3205,7 +3251,13 @@ def test_register_inference_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, image_uri=inference_image, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3214,6 +3266,12 @@ def test_register_inference_image(sagemaker_session): { "Image": inference_image, "ModelDataUrl": estimator.model_data, + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_config, + }, } ], "content_types": content_types, @@ -3222,6 +3280,8 @@ def test_register_inference_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request From e6fe39c98a0c699a3b1c832f911172ce1cea6408 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Thu, 16 Jun 2022 23:11:33 +0530 Subject: [PATCH 3/9] fix: fix flake8 error --- tests/unit/sagemaker/workflow/test_step_collections.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 747038b081..0d7133aa02 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -703,14 +703,14 @@ def test_register_model_with_model_repack_with_estimator( assert isinstance( arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties ) - assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] == None - assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] == None - assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] == None + assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] is None + assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] is None + assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] is None assert ( arguments["InferenceSpecification"]["Containers"][0]["ModelInput"][ "DataInputConfig" ] - == None + is None ) del arguments["InferenceSpecification"]["Containers"] assert ordered(arguments) == ordered( From c6d44da02696a1a5dd97efac55890d3a916cfdec Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Fri, 17 Jun 2022 18:20:44 +0530 Subject: [PATCH 4/9] fix: add conditionals to include container variables --- src/sagemaker/model.py | 26 ++++++++------ src/sagemaker/pipeline.py | 34 +++++++++---------- src/sagemaker/session.py | 33 +++++++++--------- src/sagemaker/workflow/step_collections.py | 28 +++++++++------ .../workflow/test_step_collections.py | 21 ------------ 5 files changed, 67 insertions(+), 75 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4dec833bdf..872e2e85c7 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -374,16 +374,22 @@ def register( "Image": self.image_uri, "ModelDataUrl": self.model_data, } - container_def.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + container_def.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 6f5e144ca6..6d28763944 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -337,31 +337,31 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) - container_def[0].update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) else: container_def = [ { "Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data, - "Framework": framework or model.framework, - "FrameworkVersion": framework_version or model.framework_version, - "NearestModelName": nearest_model_name or model.nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration - or model.data_input_configuration - }, } for model in self.models ] + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + for container_obj in container_def: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) model_pkg_args = sagemaker.get_model_package_args( content_types, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index a5c40c4e97..a26555d21b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4299,27 +4299,28 @@ def get_model_package_args( dict: A dictionary of method argument names and values. """ if container_def_list is not None: - container_def_list[0].update( - { - "Framework": container_def_list[0]["Framework"], - "FrameworkVersion": container_def_list[0]["FrameworkVersion"], - "NearestModelName": container_def_list[0]["NearestModelName"], - "ModelInput": { - "DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"], - }, - } - ) + container_fields = container_def_list[0] + if ( + container_fields.get("Framework") is not None + and container_fields.get("FrameworkVersion") is not None + and container_fields.get("NearestModelName") is not None + and container_fields.get("ModelInput").get("DataInputConfig") is not None + ): + container_def_list[0].update( + { + "Framework": container_fields["Framework"], + "FrameworkVersion": container_fields["FrameworkVersion"], + "NearestModelName": container_fields["NearestModelName"], + "ModelInput": { + "DataInputConfig": container_fields["ModelInput"]["DataInputConfig"], + }, + } + ) containers = container_def_list else: container = { "Image": image_uri, "ModelDataUrl": model_data, - "Framework": None, - "FrameworkVersion": None, - "NearestModelName": None, - "ModelInput": { - "DataInputConfig": None, - }, } containers = [container] diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index b4825fd7a9..f11fede6c4 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -246,17 +246,23 @@ def __init__( inference_instances[0] if inference_instances else None ) ] - for container_obj in self.container_def_list: - container_obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + for container_obj in self.container_def_list: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) register_model_step = _RegisterModelStep( name=name, estimator=estimator, diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 0d7133aa02..4aa55fd068 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -389,12 +389,6 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): { "Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", - "Framework": None, - "FrameworkVersion": None, - "NearestModelName": None, - "ModelInput": { - "DataInputConfig": None, - }, } ], "SupportedContentTypes": ["content_type"], @@ -466,12 +460,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): { "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", - "Framework": None, - "FrameworkVersion": None, - "NearestModelName": None, - "ModelInput": { - "DataInputConfig": None, - }, } ], "SupportedContentTypes": ["content_type"], @@ -703,15 +691,6 @@ def test_register_model_with_model_repack_with_estimator( assert isinstance( arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties ) - assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] is None - assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] is None - assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] is None - assert ( - arguments["InferenceSpecification"]["Containers"][0]["ModelInput"][ - "DataInputConfig" - ] - is None - ) del arguments["InferenceSpecification"]["Containers"] assert ordered(arguments) == ordered( { From f194f7a5003b39fe72f98f5984ac9b64bedc1cf8 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Sat, 18 Jun 2022 22:42:22 +0530 Subject: [PATCH 5/9] fix: fixed failing Integration Test --- .../test_model_create_and_registration.py | 71 ++++++++++--------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 265ba724ad..d0f617a266 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -94,6 +94,13 @@ def test_conditional_pytorch_training_model_registration( good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1) in_condition_input = ParameterString(name="Foo", default_value="Foo") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) pytorch_estimator = PyTorch( @@ -120,6 +127,12 @@ def test_conditional_pytorch_training_model_registration( inference_instances=["*"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) model = Model( @@ -201,6 +214,13 @@ def test_mxnet_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + model = MXNetModel( entry_point=entry_point, source_dir=source_dir, @@ -219,6 +239,12 @@ def test_mxnet_model_registration( inference_instances=["ml.m5.xlarge"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -262,6 +288,13 @@ def test_sklearn_xgboost_sip_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = "ml.m5.xlarge" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # The instance_type should not be a pipeline variable # since it is used to retrieve image_uri in compile time (PySDK) sklearn_processor = SKLearnProcessor( @@ -412,6 +445,12 @@ def test_sklearn_xgboost_sip_model_registration( inference_instances=["ml.t2.medium", "ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="windturbine", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -575,27 +614,8 @@ def test_model_registration_with_drift_check_baselines( role=role, ) - base_dir = os.path.join(DATA_DIR, "mxnet_mnist") - source_dir = os.path.join(base_dir, "code") - entry_point = os.path.join(source_dir, "inference.py") - mx_mnist_model_data = os.path.join(base_dir, "model.tar.gz") - - instance_count = ParameterInteger(name="InstanceCount", default_value=1) - instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") - - model = MXNetModel( - entry_point=entry_point, - source_dir=source_dir, - role=role, - model_data=mx_mnist_model_data, - framework_version="1.4.0", - py_version="py3", - sagemaker_session=sagemaker_session, - ) - step_register = RegisterModel( name="MyRegisterModelStep", - model=model, estimator=estimator, model_data=model_uri_param, content_types=["application/json"], @@ -686,19 +706,6 @@ def test_model_registration_with_drift_check_baselines( assert response["Domain"] == domain assert response["Task"] == task assert response["SamplePayloadUrl"] == sample_payload_url - assert response["InferenceSpecification"]["Containers"][0]["Framework"] == framework - assert ( - response["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] - == framework_version - ) - assert ( - response["InferenceSpecification"]["Containers"][0]["NearestModelName"] - == nearest_model_name - ) - assert ( - response["InferenceSpecification"]["Containers"][0]["ModelInput"]["DataInputConfig"] - == data_input_configuration - ) break finally: try: From bb76610794ebc224def005364c06a4253ad35fe9 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Tue, 21 Jun 2022 19:55:36 +0530 Subject: [PATCH 6/9] fix: added util function for inference recommender fields --- src/sagemaker/model.py | 32 ++++++------- src/sagemaker/pipeline.py | 35 +++++++------- src/sagemaker/session.py | 25 +--------- src/sagemaker/utils.py | 56 ++++++++++++++++++++++ src/sagemaker/workflow/step_collections.py | 23 ++++----- tests/unit/test_estimator.py | 18 ------- 6 files changed, 99 insertions(+), 90 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 872e2e85c7..c16e6ee8ec 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -35,7 +35,11 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model -from sagemaker.utils import unique_name_from_base +from sagemaker.utils import ( + unique_name_from_base, + inference_recommender_params_exist, + update_container_object, +) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable @@ -367,29 +371,23 @@ def register( raise ValueError("SageMaker Model Package cannot be created without model data.") if image_uri is not None: self.image_uri = image_uri + if model_package_group_name is not None: container_def = self.prepare_container_def() + if inference_recommender_params_exist( + framework, framework_version, nearest_model_name, data_input_configuration + ): + container_def.update( + update_container_object( + framework, framework_version, nearest_model_name, data_input_configuration + ) + ) else: container_def = { "Image": self.image_uri, "ModelDataUrl": self.model_data, } - if ( - framework is not None - and framework_version is not None - and nearest_model_name is not None - and data_input_configuration is not None - ): - container_def.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) + model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 6d28763944..803a76f7a8 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -20,7 +20,11 @@ from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties from sagemaker.session import Session -from sagemaker.utils import name_from_image +from sagemaker.utils import ( + name_from_image, + inference_recommender_params_exist, + update_container_object, +) from sagemaker.transformer import Transformer from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -337,6 +341,18 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) + if inference_recommender_params_exist( + framework, framework_version, nearest_model_name, data_input_configuration + ): + for container_obj in container_def: + container_obj.update( + update_container_object( + framework, + framework_version, + nearest_model_name, + data_input_configuration, + ) + ) else: container_def = [ { @@ -345,23 +361,6 @@ def register( } for model in self.models ] - if ( - framework is not None - and framework_version is not None - and nearest_model_name is not None - and data_input_configuration is not None - ): - for container_obj in container_def: - container_obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } - ) model_pkg_args = sagemaker.get_model_package_args( content_types, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index a26555d21b..eb158eab3d 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4288,34 +4288,11 @@ def get_model_package_args( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). - framework (str): Machine learning framework of the model package container image - (default: None). - framework_version (str): Framework version of the Model Package Container Image - (default: None). - nearest_model_name (str): Name of a pre-trained machine learning benchmarked by - Amazon SageMaker Inference Recommender (default: None). - data_input_configuration (str): Input object for the model (default: None). + Returns: dict: A dictionary of method argument names and values. """ if container_def_list is not None: - container_fields = container_def_list[0] - if ( - container_fields.get("Framework") is not None - and container_fields.get("FrameworkVersion") is not None - and container_fields.get("NearestModelName") is not None - and container_fields.get("ModelInput").get("DataInputConfig") is not None - ): - container_def_list[0].update( - { - "Framework": container_fields["Framework"], - "FrameworkVersion": container_fields["FrameworkVersion"], - "NearestModelName": container_fields["NearestModelName"], - "ModelInput": { - "DataInputConfig": container_fields["ModelInput"]["DataInputConfig"], - }, - } - ) containers = container_def_list else: container = { diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1d2e9fe5cb..6fc294b2a6 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -722,3 +722,59 @@ def get_data_bucket(self, region_requested=None): get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix") + + +def inference_recommender_params_exist( + framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None +): + """ + Function to check if inference recommender parameters exist. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + + Returns: + bool: all required fields exist or not + """ + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + return True + return False + + +def update_container_object( + framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None +): + """ + Update the container_def object with inference recommedender parameters. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + + Returns: + dict: inference recommender key, value pairs which updates the object. + """ + return { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 53840d06ce..bf7fae82c3 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -27,6 +27,7 @@ from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy +from sagemaker.utils import inference_recommender_params_exist, update_container_object @attr.s @@ -245,23 +246,19 @@ def __init__( inference_instances[0] if inference_instances else None ) ] - if ( - framework is not None - and framework_version is not None - and nearest_model_name is not None - and data_input_configuration is not None + if inference_recommender_params_exist( + framework, framework_version, nearest_model_name, data_input_configuration ): for container_obj in self.container_def_list: container_obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + update_container_object( + framework, + framework_version, + nearest_model_name, + data_input_configuration, + ) ) + register_model_step = _RegisterModelStep( name=name, estimator=estimator, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 6d37f0b443..78298025ea 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3136,12 +3136,6 @@ def test_register_default_image(sagemaker_session): { "Image": estimator.image_uri, "ModelDataUrl": estimator.model_data, - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_config, - }, } ], "content_types": content_types, @@ -3198,12 +3192,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): { "Image": estimator.image_uri, "ModelDataUrl": estimator.model_data, - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_config, - }, } ], "content_types": content_types, @@ -3266,12 +3254,6 @@ def test_register_inference_image(sagemaker_session): { "Image": inference_image, "ModelDataUrl": estimator.model_data, - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_config, - }, } ], "content_types": content_types, From 750d213cf239fcac579eb29b11b116d89e7b6fe5 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Tue, 21 Jun 2022 20:05:16 +0530 Subject: [PATCH 7/9] fix: remove extra space --- src/sagemaker/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index c16e6ee8ec..a52090e0c2 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -371,7 +371,7 @@ def register( raise ValueError("SageMaker Model Package cannot be created without model data.") if image_uri is not None: self.image_uri = image_uri - + if model_package_group_name is not None: container_def = self.prepare_container_def() if inference_recommender_params_exist( From a538beff93b6985d5e3b6d7538018895471d6f47 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Tue, 21 Jun 2022 20:38:02 +0530 Subject: [PATCH 8/9] fix: fix docstyle error --- src/sagemaker/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 6fc294b2a6..a4cff57b7a 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -727,8 +727,7 @@ def get_data_bucket(self, region_requested=None): def inference_recommender_params_exist( framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None ): - """ - Function to check if inference recommender parameters exist. + """Function to check if inference recommender parameters exist. Args: framework (str): Machine learning framework of the model package container image @@ -755,8 +754,7 @@ def inference_recommender_params_exist( def update_container_object( framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None ): - """ - Update the container_def object with inference recommedender parameters. + """Update the container_def object with inference recommedender parameters. Args: framework (str): Machine learning framework of the model package container image From c857aca4af2424d27d087ec986cdba976a127651 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh Date: Wed, 22 Jun 2022 19:26:26 +0530 Subject: [PATCH 9/9] fix: refactored util function --- src/sagemaker/model.py | 18 +++--- src/sagemaker/pipeline.py | 22 +++---- src/sagemaker/utils.py | 68 +++++++++++----------- src/sagemaker/workflow/step_collections.py | 22 +++---- 4 files changed, 60 insertions(+), 70 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a52090e0c2..60c766379b 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -37,8 +37,7 @@ from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model from sagemaker.utils import ( unique_name_from_base, - inference_recommender_params_exist, - update_container_object, + update_container_with_inference_params, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -374,14 +373,13 @@ def register( if model_package_group_name is not None: container_def = self.prepare_container_def() - if inference_recommender_params_exist( - framework, framework_version, nearest_model_name, data_input_configuration - ): - container_def.update( - update_container_object( - framework, framework_version, nearest_model_name, data_input_configuration - ) - ) + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_obj=container_def, + ) else: container_def = { "Image": self.image_uri, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 803a76f7a8..8cdb82ffe7 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -22,8 +22,7 @@ from sagemaker.session import Session from sagemaker.utils import ( name_from_image, - inference_recommender_params_exist, - update_container_object, + update_container_with_inference_params, ) from sagemaker.transformer import Transformer from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -341,18 +340,13 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) - if inference_recommender_params_exist( - framework, framework_version, nearest_model_name, data_input_configuration - ): - for container_obj in container_def: - container_obj.update( - update_container_object( - framework, - framework_version, - nearest_model_name, - data_input_configuration, - ) - ) + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=container_def, + ) else: container_def = [ { diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index a4cff57b7a..ed5b3c5e75 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -724,10 +724,15 @@ def get_data_bucket(self, region_requested=None): get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix") -def inference_recommender_params_exist( - framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None +def update_container_with_inference_params( + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, + container_obj=None, + container_list=None, ): - """Function to check if inference recommender parameters exist. + """Function to check if inference recommender parameters exist and update container. Args: framework (str): Machine learning framework of the model package container image @@ -737,42 +742,39 @@ def inference_recommender_params_exist( nearest_model_name (str): Name of a pre-trained machine learning benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str): Input object for the model (default: None). + container_obj (dict): object to be updated. + container_list (list): list to be updated. Returns: - bool: all required fields exist or not + dict: dict with inference recommender params """ + if ( framework is not None and framework_version is not None and nearest_model_name is not None and data_input_configuration is not None ): - return True - return False - - -def update_container_object( - framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None -): - """Update the container_def object with inference recommedender parameters. - - Args: - framework (str): Machine learning framework of the model package container image - (default: None). - framework_version (str): Framework version of the Model Package Container Image - (default: None). - nearest_model_name (str): Name of a pre-trained machine learning benchmarked by - Amazon SageMaker Inference Recommender (default: None). - data_input_configuration (str): Input object for the model (default: None). - - Returns: - dict: inference recommender key, value pairs which updates the object. - """ - return { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + if container_list is not None: + for obj in container_list: + obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) + if container_obj is not None: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index bf7fae82c3..dd9529916e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -27,7 +27,7 @@ from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy -from sagemaker.utils import inference_recommender_params_exist, update_container_object +from sagemaker.utils import update_container_with_inference_params @attr.s @@ -246,18 +246,14 @@ def __init__( inference_instances[0] if inference_instances else None ) ] - if inference_recommender_params_exist( - framework, framework_version, nearest_model_name, data_input_configuration - ): - for container_obj in self.container_def_list: - container_obj.update( - update_container_object( - framework, - framework_version, - nearest_model_name, - data_input_configuration, - ) - ) + + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=self.container_def_list, + ) register_model_step = _RegisterModelStep( name=name,