diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 0dad6bd1cc..208dc208b4 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1286,8 +1286,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, image_uri=None, model_package_name=None, model_package_group_name=None, @@ -1309,9 +1309,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). image_uri (str): The container image uri for Model Package, if not specified, Estimator's training container image will be used (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 48e3681563..b72f1b1af2 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -293,8 +293,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -313,9 +313,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned. Defaults to ``None``. @@ -341,7 +341,7 @@ def register( Returns: A `sagemaker.model.ModelPackage` instance. """ - instance_type = inference_instances[0] + instance_type = inference_instances[0] if inference_instances else None self._init_sagemaker_session_if_does_not_exist(instance_type) if image_uri: diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b456b452f8..9c66c57c7f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -296,8 +296,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -317,9 +317,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -351,12 +351,11 @@ def register( container_def = self.prepare_container_def() else: container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data} - model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, - inference_instances, - transform_instances, + inference_instances=inference_instances, + transform_instances=transform_instances, model_package_name=model_package_name, model_package_group_name=model_package_group_name, model_metrics=model_metrics, diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index c7eba903cb..fa2773bebb 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -146,8 +146,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -166,9 +166,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -192,7 +192,7 @@ def register( Returns: A `sagemaker.model.ModelPackage` instance. """ - instance_type = inference_instances[0] + instance_type = inference_instances[0] if inference_instances else None self._init_sagemaker_session_if_does_not_exist(instance_type) if image_uri: diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 4dab7ccc1c..75fae3bfc4 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -84,7 +84,7 @@ def __init__( self.enable_network_isolation = enable_network_isolation self.endpoint_name = None - def pipeline_container_def(self, instance_type): + def pipeline_container_def(self, instance_type=None): """The pipeline definition for deploying this model. This is the dict created by ``sagemaker.pipeline_container_def()``. @@ -266,8 +266,8 @@ def register( self, content_types: list, response_types: list, - inference_instances: list, - transform_instances: list, + inference_instances: Optional[list] = None, + transform_instances: Optional[list] = None, model_package_name: Optional[str] = None, model_package_group_name: Optional[str] = None, image_uri: Optional[str] = None, @@ -286,9 +286,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -316,18 +316,23 @@ def register( if model.model_data is None: raise ValueError("SageMaker Model Package cannot be created without model data.") if model_package_group_name is not None: - container_def = self.pipeline_container_def(inference_instances[0]) + container_def = self.pipeline_container_def( + inference_instances[0] if inference_instances else None + ) else: container_def = [ - {"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data} + { + "Image": image_uri or model.image_uri, + "ModelDataUrl": model.model_data, + } for model in self.models ] model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, - inference_instances, - transform_instances, + inference_instances=inference_instances, + transform_instances=transform_instances, model_package_name=model_package_name, model_package_group_name=model_package_group_name, model_metrics=model_metrics, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 879fd9da15..6e5d63c14d 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -147,8 +147,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -167,9 +167,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -193,7 +193,7 @@ def register( Returns: A `sagemaker.model.ModelPackage` instance. """ - instance_type = inference_instances[0] + instance_type = inference_instances[0] if inference_instances else None self._init_sagemaker_session_if_does_not_exist(instance_type) if image_uri: diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b9e6221218..44458438c4 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4206,8 +4206,8 @@ def _intercept_create_request( def get_model_package_args( content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, model_data=None, @@ -4230,9 +4230,9 @@ def get_model_package_args( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -4377,10 +4377,9 @@ def get_create_model_package_request( if domain is not None: request_dict["Domain"] = domain if containers is not None: - if not all([content_types, response_types, inference_instances, transform_instances]): + if not all([content_types, response_types]): raise ValueError( - "content_types, response_types, inference_inferences and transform_instances " - "must be provided if containers is present." + "content_types and response_types " "must be provided if containers is present." ) inference_specification = { "Containers": containers, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index e0ae5e2c3d..71fe048bf1 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -141,8 +141,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -161,9 +161,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -187,7 +187,7 @@ def register( Returns: A `sagemaker.model.ModelPackage` instance. """ - instance_type = inference_instances[0] + instance_type = inference_instances[0] if inference_instances else None self._init_sagemaker_session_if_does_not_exist(instance_type) if image_uri: diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index a66912fc00..3c4bf3343a 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -193,8 +193,8 @@ def register( self, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, model_package_name=None, model_package_group_name=None, image_uri=None, @@ -213,9 +213,9 @@ def register( content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to - generate inferences in real-time. + generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation - job can be run or on which an endpoint can be deployed. + job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to @@ -239,7 +239,7 @@ def register( Returns: A `sagemaker.model.ModelPackage` instance. """ - instance_type = inference_instances[0] + instance_type = inference_instances[0] if inference_instances else None self._init_sagemaker_session_if_does_not_exist(instance_type) if image_uri: diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 8a88b7d39e..e4491f5492 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -63,8 +63,8 @@ def __init__( name: str, content_types, response_types, - inference_instances, - transform_instances, + inference_instances=None, + transform_instances=None, estimator: EstimatorBase = None, model_data=None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, @@ -220,9 +220,15 @@ def __init__( kwargs.pop("output_kms_key", None) if isinstance(model, PipelineModel): - self.container_def_list = model.pipeline_container_def(inference_instances[0]) + self.container_def_list = model.pipeline_container_def( + inference_instances[0] if inference_instances else None + ) elif isinstance(model, Model): - self.container_def_list = [model.prepare_container_def(inference_instances[0])] + self.container_def_list = [ + model.prepare_container_def( + inference_instances[0] if inference_instances else None + ) + ] register_model_step = _RegisterModelStep( name=name, diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index cdfa489af2..d2954ede7b 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -122,3 +122,52 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock): assert not register_step_args.create_model_request assert register_step_args.create_model_package_request assert len(register_step_args.need_runtime_repack) == 0 + + +def test_pipeline_session_context_for_model_step_without_instance_types( + pipeline_session_mock, +): + model = Model( + name="MyModel", + image_uri="fakeimage", + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session_mock, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=_ROLE, + ) + + register_step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + model_package_group_name="MyModelPackageGroup", + ) + + expected_output = { + "ModelPackageGroupName": "MyModelPackageGroup", + "InferenceSpecification": { + "Containers": [ + { + "Image": "fakeimage", + "Environment": { + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + }, + "ModelDataUrl": ParameterString( + name="ModelData", + default_value="s3://my-bucket/file", + ), + } + ], + "SupportedContentTypes": ["text/csv"], + "SupportedResponseMIMETypes": ["text/csv"], + "SupportedRealtimeInferenceInstanceTypes": None, + "SupportedTransformInstanceTypes": None, + }, + "CertifyForMarketplace": False, + "ModelApprovalStatus": "PendingManualApproval", + } + + assert register_step_args.create_model_package_request == expected_output diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 9ac4fce39a..7e2e985845 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3138,6 +3138,48 @@ def test_register_default_image(sagemaker_session): ) +def test_register_default_image_without_instance_type_args(sagemaker_session): + estimator = Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) + estimator.set_hyperparameters(**HYPERPARAMS) + estimator.fit({"train": "s3://bucket/training-prefix"}) + + model_package_name = "test-estimator-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + + estimator.register( + content_types=content_types, + response_types=response_types, + model_package_name=model_package_name, + ) + sagemaker_session.create_model.assert_not_called() + + expected_create_model_package_request = { + "containers": [ + { + "Image": estimator.image_uri, + "ModelDataUrl": estimator.model_data, + } + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": None, + "transform_instances": None, + "model_package_name": model_package_name, + "marketplace_cert": False, + } + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) + + def test_register_inference_image(sagemaker_session): estimator = Estimator( IMAGE_URI, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4360802308..2040ed5d80 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2424,6 +2424,69 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) +def test_create_model_package_from_containers_without_instance_types(sagemaker_session): + model_package_name = "sagemaker-model-package" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + model_metrics = { + "Bias": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + drift_check_baselines = { + "Bias": { + "ConfigFile": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + } + + metadata_properties = { + "CommitId": "test-commit-id", + "Repository": "test-repository", + "GeneratedBy": "sagemaker-python-sdk", + "ProjectId": "unit-test", + } + marketplace_cert = (True,) + approval_status = ("Approved",) + description = "description" + customer_metadata_properties = {"key1": "value1"} + sagemaker_session.create_model_package_from_containers( + containers=containers, + content_types=content_types, + response_types=response_types, + model_package_name=model_package_name, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + ) + expected_args = { + "ModelPackageName": model_package_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": None, + "SupportedTransformInstanceTypes": None, + }, + "ModelPackageDescription": description, + "ModelMetrics": model_metrics, + "MetadataProperties": metadata_properties, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) + + @pytest.fixture def feature_group_dummy_definitions(): return [{"FeatureName": "feature1", "FeatureType": "String"}]