diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index bafcfde3a8..9fce051454 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -174,6 +174,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -223,6 +224,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -262,6 +265,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7ef367e485..501c826f82 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1718,6 +1718,7 @@ def register( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1765,6 +1766,7 @@ def register( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1809,6 +1811,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index efe6a85288..f71dca0ac8 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -410,6 +411,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -457,6 +460,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1b41cad714..40759d0f0b 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -598,6 +598,7 @@ def get_register_kwargs( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -629,6 +630,7 @@ def get_register_kwargs( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..8d007aed24 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -631,6 +631,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -676,6 +677,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -709,6 +712,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 810d1c4cd3..43a25f3c12 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1659,6 +1659,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "nearest_model_name", "data_input_configuration", "skip_model_validation", + "source_uri", ] SERIALIZATION_EXCLUSION_SET = { @@ -1699,6 +1700,7 @@ def __init__( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" @@ -1730,3 +1732,4 @@ def __init__( self.nearest_model_name = nearest_model_name self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation + self.source_uri = source_uri diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5a2b27c54d..af08d1203f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -77,7 +77,10 @@ ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType -from sagemaker.session import get_add_model_package_inference_args +from sagemaker.session import ( + get_add_model_package_inference_args, + get_update_model_package_inference_args, +) # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -423,6 +426,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -472,17 +476,14 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments in case the Model instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ - if isinstance(self.model_data, dict): - raise ValueError( - "SageMaker Model Package currently cannot be created with ModelDataSource." - ) - if content_types is not None: self.content_types = content_types @@ -513,6 +514,12 @@ def register( "Image": self.image_uri, } + if isinstance(self.model_data, dict): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) + if self.model_data is not None: container_def["ModelDataUrl"] = self.model_data @@ -536,6 +543,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -2040,8 +2048,9 @@ def __init__( endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file. Must be provided if algorithm_arn is provided. + model_data (str or dict[str, Any]): The S3 location of a SageMaker model data + ``.tar.gz`` file or a dictionary representing a ``ModelDataSource`` + object. Must be provided if algorithm_arn is provided. algorithm_arn (str): algorithm arn used to train the model, can be just the name if your account owns the algorithm. Must also provide ``model_data``. @@ -2050,11 +2059,6 @@ def __init__( ``model_data`` is not required. **kwargs: Additional kwargs passed to the Model constructor. """ - if isinstance(model_data, dict): - raise ValueError( - "Creating ModelPackage with ModelDataSource is currently not supported" - ) - super(ModelPackage, self).__init__( role=role, model_data=model_data, image_uri=None, **kwargs ) @@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]) sagemaker_session = self.sagemaker_session or sagemaker.Session() sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args) + def update_inference_specification( + self, + containers: Dict = None, + image_uris: List[str] = None, + content_types: List[str] = None, + response_types: List[str] = None, + inference_instances: List[str] = None, + transform_instances: List[str] = None, + ): + """Inference specification to be set for the model package + + Args: + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + image_uris (List[str]): The ECR path where inference code is stored. + content_types (list[str]): The supported MIME types + for the input data. + response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + + """ + sagemaker_session = self.sagemaker_session or sagemaker.Session() + if (containers is not None) ^ (image_uris is None): + raise ValueError("Should have either containers or image_uris for inference.") + container_def = [] + if image_uris: + for uri in image_uris: + container_def.append( + { + "Image": uri, + } + ) + else: + container_def = containers + + model_package_update_args = get_update_model_package_inference_args( + model_package_arn=self.model_package_arn, + containers=container_def, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + ) + + sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args) + + def update_source_uri( + self, + source_uri: str, + ): + """Source uri to be set for the model package + + Args: + source_uri (str): The URI of the source for the model package. + + """ + update_source_uri_args = { + "ModelPackageArn": self.model_package_arn, + "SourceUri": source_uri, + } + sagemaker_session = self.sagemaker_session or sagemaker.Session() + sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) + def remove_customer_metadata_properties( self, customer_metadata_properties_to_remove: List[str] ): diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8cd0ac6b65..714b0db945 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -176,6 +176,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -225,6 +226,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -264,6 +267,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index a4b7feac69..3bfdb1a594 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -409,6 +410,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -456,6 +459,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fb731cabf4..f490e49375 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -178,6 +178,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -227,6 +228,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -266,6 +269,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 325d3c7697..8d72051cc0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -140,6 +140,7 @@ ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings +from sagemaker.utils import can_model_package_source_uri_autopopulate # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -3969,14 +3970,19 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn, name (str): ModelPackage name description (str): Model Package description algorithm_arn (str): arn or name of the algorithm used for training. - model_data (str): s3 URI to the model artifacts produced by training + model_data (str or dict[str, Any]): s3 URI or a dictionary representing a + ``ModelDataSource`` to the model artifacts produced by training """ + sourceAlgorithm = {"AlgorithmName": algorithm_arn} + if isinstance(model_data, dict): + sourceAlgorithm["ModelDataSource"] = model_data + else: + sourceAlgorithm["ModelDataUrl"] = model_data + request = { "ModelPackageName": name, "ModelPackageDescription": description, - "SourceAlgorithmSpecification": { - "SourceAlgorithms": [{"AlgorithmName": algorithm_arn, "ModelDataUrl": model_data}] - }, + "SourceAlgorithmSpecification": {"SourceAlgorithms": [sourceAlgorithm]}, } try: logger.info("Creating model package with name: %s", name) @@ -4011,6 +4017,7 @@ def create_model_package_from_containers( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -4047,6 +4054,7 @@ def create_model_package_from_containers( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4103,6 +4111,7 @@ def create_model_package_from_containers( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def submit(request): @@ -4114,6 +4123,26 @@ def submit(request): ModelPackageGroupName=request["ModelPackageGroupName"] ) ) + if "SourceUri" in request and request["SourceUri"] is not None: + # Remove inference spec from request if the + # given source uri can lead to auto-population of it + if can_model_package_source_uri_autopopulate(request["SourceUri"]): + if "InferenceSpecification" in request: + del request["InferenceSpecification"] + return self.sagemaker_client.create_model_package(**request) + # If source uri can't autopopulate, + # first create model package with just the inference spec + # and then update model package with the source uri. + # Done this way because passing source uri and inference spec together + # in create/update model package is not allowed in the base sdk. + request_source_uri = request["SourceUri"] + del request["SourceUri"] + model_package = self.sagemaker_client.create_model_package(**request) + update_source_uri_args = { + "ModelPackageArn": model_package.get("ModelPackageArn"), + "SourceUri": request_source_uri, + } + return self.sagemaker_client.update_model_package(**update_source_uri_args) return self.sagemaker_client.create_model_package(**request) return self._intercept_create_request( @@ -6669,6 +6698,7 @@ def get_model_package_args( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, ): """Get arguments for create_model_package method. @@ -6707,6 +6737,7 @@ def get_model_package_args( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). Returns: dict: A dictionary of method argument names and values. @@ -6761,6 +6792,8 @@ def get_model_package_args( model_package_args["task"] = task if skip_model_validation is not None: model_package_args["skip_model_validation"] = skip_model_validation + if source_uri is not None: + model_package_args["source_uri"] = source_uri return model_package_args @@ -6785,6 +6818,7 @@ def get_create_model_package_request( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -6821,12 +6855,32 @@ def get_create_model_package_request( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if all([model_package_name, model_package_group_name]): raise ValueError( "model_package_name and model_package_group_name cannot be present at the " "same time." ) + if all([model_package_name, source_uri]): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " "created with source_uri." + ) + if (containers is not None) and all( + [ + model_package_name, + any( + [ + (("ModelDataSource" in c) and (c["ModelDataSource"] is not None)) + for c in containers + ] + ), + ] + ): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) request_dict = {} if model_package_name is not None: request_dict["ModelPackageName"] = model_package_name @@ -6852,6 +6906,8 @@ def get_create_model_package_request( request_dict["SamplePayloadUrl"] = sample_payload_url if task is not None: request_dict["Task"] = task + if source_uri is not None: + request_dict["SourceUri"] = source_uri if containers is not None: inference_specification = { "Containers": containers, @@ -6900,6 +6956,65 @@ def get_create_model_package_request( return request_dict +def get_update_model_package_inference_args( + model_package_arn, + containers=None, + content_types=None, + response_types=None, + inference_instances=None, + transform_instances=None, +): + """Get request dictionary for UpdateModelPackage API for inference specification. + + Args: + model_package_arn (str): Arn for the model package. + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + content_types (list[str]): The supported MIME types + for the input data. + response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + """ + + request_dict = {} + if containers is not None: + inference_specification = { + "Containers": containers, + } + if content_types is not None: + inference_specification.update( + { + "SupportedContentTypes": content_types, + } + ) + if response_types is not None: + inference_specification.update( + { + "SupportedResponseMIMETypes": response_types, + } + ) + if inference_instances is not None: + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + } + ) + if transform_instances is not None: + inference_specification.update( + { + "SupportedTransformInstanceTypes": transform_instances, + } + ) + request_dict["InferenceSpecification"] = inference_specification + request_dict.update({"ModelPackageArn": model_package_arn}) + return request_dict + + def get_add_model_package_inference_args( model_package_arn, name, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 195a6a3a57..27833c1d9c 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -171,6 +171,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -220,6 +221,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -259,6 +262,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 1b35afbe7c..77f162207c 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -233,6 +233,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -282,6 +283,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -321,6 +324,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def deploy( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 115b8b258d..7896aac150 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -48,6 +48,10 @@ from sagemaker.workflow.entities import PipelineVariable ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" +MODEL_PACKAGE_ARN_PATTERN = ( + r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)" +) +MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" HTTP_PREFIX = "http://" @@ -1581,3 +1585,17 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + + +def can_model_package_source_uri_autopopulate(source_uri: str): + """Checks if the source_uri can lead to auto-population of information in the Model registry. + + Args: + source_uri (str): The source uri. + + Returns: + bool: True if the source_uri can lead to auto-population, False otherwise. + """ + return bool( + re.match(MODEL_PACKAGE_ARN_PATTERN, source_uri) or re.match(MODEL_ARN_PATTERN, source_uri) + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 1fafa646bf..841cd68083 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -328,6 +328,7 @@ def __init__( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Constructor of a register model step. @@ -379,6 +380,7 @@ def __init__( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -415,6 +417,7 @@ def __init__( self.kwargs = kwargs self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation + self.source_uri = source_uri self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -489,6 +492,7 @@ def arguments(self) -> RequestType: sample_payload_url=self.sample_payload_url, task=self.task, skip_model_validation=self.skip_model_validation, + source_uri=self.source_uri, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index d48bf7c307..0eedf4aa96 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -96,6 +96,7 @@ def __init__( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -153,6 +154,7 @@ def __init__( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. """ @@ -291,6 +293,7 @@ def __init__( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 74776f8f72..8101f32721 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -159,6 +159,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -208,6 +209,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -247,6 +250,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index 1554825fc2..914c5db7ed 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -18,6 +18,8 @@ from tests.integ import DATA_DIR from sagemaker.xgboost import XGBoostModel from sagemaker import image_uris +from sagemaker.session import get_execution_role +from sagemaker.model import ModelPackage _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") @@ -104,3 +106,207 @@ def test_inference_specification_addition(sagemaker_session): sagemaker_session.sagemaker_client.delete_model_package_group( ModelPackageGroupName=model_group_name ) + + +def test_update_inference_specification(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + model_package = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_group_name, SourceUri=source_uri + ) + + mp = ModelPackage( + role=get_execution_role(sagemaker_session), + model_package_arn=model_package["ModelPackageArn"], + sagemaker_session=sagemaker_session, + ) + + xgb_image = image_uris.retrieve( + "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference" + ) + + mp.update_inference_specification(image_uris=[xgb_image]) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_model_package["InferenceSpecification"]["Containers"]) == 1 + assert desc_model_package["InferenceSpecification"]["Containers"][0]["Image"] == xgb_image + + +def test_update_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model_package.update_source_uri(source_uri=source_uri) + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + assert desc_model_package["SourceUri"] == source_uri + + +def test_clone_model_package_using_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri="dummy-source-uri", + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model2 = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + cloned_model_package = model2.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri=model_package.model_package_arn, + ) + + desc_cloned_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_cloned_model_package["InferenceSpecification"]["Containers"]) == len( + desc_model_package["InferenceSpecification"]["Containers"] + ) + assert len( + desc_cloned_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"]) + assert len( + desc_cloned_model_package["InferenceSpecification"][ + "SupportedRealtimeInferenceInstanceTypes" + ] + ) == len( + desc_model_package["InferenceSpecification"]["SupportedRealtimeInferenceInstanceTypes"] + ) + assert len(desc_cloned_model_package["InferenceSpecification"]["SupportedContentTypes"]) == len( + desc_model_package["InferenceSpecification"]["SupportedContentTypes"] + ) + assert len( + desc_cloned_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"]) + assert desc_cloned_model_package["SourceUri"] == model_package.model_package_arn + + +def test_register_model_using_source_uri(sagemaker_session): + model_name = unique_name_from_base("test-model") + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + + model.name = model_name + model.create() + desc_model = sagemaker_session.sagemaker_client.describe_model(ModelName=model_name) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + registered_model_package = model.register( + inference_instances=["ml.m5.xlarge"], + model_package_group_name=model_group_name, + source_uri=desc_model["ModelArn"], + ) + + desc_registered_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model(ModelName=model_name) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert desc_registered_model_package["SourceUri"] == desc_model["ModelArn"] + assert "InferenceSpecification" in desc_registered_model_package + assert desc_registered_model_package["InferenceSpecification"] is not None diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index de86fcf99a..c0b18a3eb3 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1118,7 +1118,46 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses get_model_package_args""" -def test_register_calls_model_data_source_not_supported(sagemaker_session): +@patch("sagemaker.get_model_package_args") +def test_register_passes_source_uri_to_model_package_args( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + source_uri = "dummy_source_uri" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_name=MODEL_NAME, + validation_specification=VALIDATION_SPECIFICATION, + source_uri=source_uri, + ) + + # check that the kwarg source_uri was passed to the internal method 'get_model_package_args' + assert ( + "source_uri" in get_model_package_args.call_args_list[0][1] + ), "source_uri kwarg was not passed to get_model_package_args" + + # check that the kwarg source_uri is identical to the one passed into the method 'register' + assert ( + source_uri == get_model_package_args.call_args_list[0][1]["source_uri"] + ), """source_uri from model.register method is not identical to source_uri from + get_model_package_args""" + + +def test_register_with_model_data_source_not_supported_for_unversioned_model(sagemaker_session): source_dir = "s3://blah/blah/blah" t = Model( entry_point=ENTRY_POINT_INFERENCE, @@ -1137,7 +1176,7 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): with pytest.raises( ValueError, - match="SageMaker Model Package currently cannot be created with ModelDataSource.", + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", ): t.register( SUPPORTED_CONTENT_TYPES, @@ -1151,6 +1190,51 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): ) +@patch("sagemaker.get_model_package_args") +def test_register_with_model_data_source_supported_for_versioned_model( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=model_data_source, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_group_name="dummy_group", + validation_specification=VALIDATION_SPECIFICATION, + ) + + # check that the kwarg container_def_list was set for the internal method 'get_model_package_args' + assert ( + "container_def_list" in get_model_package_args.call_args_list[0][1] + ), "container_def_list kwarg was not set to get_model_package_args" + + # check that the kwarg container in container_def_list contains the model data source + assert ( + model_data_source + == get_model_package_args.call_args_list[0][1]["container_def_list"][0]["ModelDataSource"] + ), """model_data_source from model.register method is not identical to ModelDataSource from + get_model_package_args""" + + @patch("sagemaker.utils.repack_model") def test_model_local_download_dir(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index def7ddf5e3..9bfc830a75 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -223,22 +223,21 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): ) -def test_model_package_model_data_source_not_supported(sagemaker_session): - with pytest.raises( - ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported" - ): - ModelPackage( - role="role", - model_package_arn="my-model-package", - model_data={ - "S3DataSource": { - "S3Uri": "s3://bucket/model/prefix/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } - }, - sagemaker_session=sagemaker_session, - ) +def test_model_package_model_data_source_supported(sagemaker_session): + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + model_package = ModelPackage( + role="role", + model_package_arn="my-model-package", + model_data=model_data_source, + sagemaker_session=sagemaker_session, + ) + assert model_package.model_data == model_package.model_data @patch("sagemaker.utils.name_from_base") @@ -399,3 +398,47 @@ def test_add_inference_specification(sagemaker_session): } ], ) + + +def test_update_inference_specification(sagemaker_session): + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + image_uris = ["image_uri"] + + containers = [{"Image": "image_uri"}] + + try: + model_package.update_inference_specification(image_uris=image_uris, containers=containers) + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + try: + model_package.update_inference_specification() + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + model_package.update_inference_specification(image_uris=image_uris) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, + InferenceSpecification={ + "Containers": [{"Image": "image_uri"}], + }, + ) + + +def test_update_source_uri(sagemaker_session): + source_uri = "dummy_source_uri" + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + model_package.update_source_uri(source_uri=source_uri) + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 358cabd0f8..ee11f5a1f3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5119,6 +5119,233 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) +def test_create_model_package_from_containers_with_source_uri_and_inference_spec(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + }, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + expected_update_mp_args = { + "ModelPackageArn": created_versioned_mp_arn, + "SourceUri": source_uri, + } + sagemaker_session.sagemaker_client.update_model_package.assert_called_once_with( + **expected_update_mp_args + ) + + +def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with source_uri.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + + +def test_create_model_package_from_containers_with_source_uri_for_versioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + containers = [{"Image": "dummy-image", "ModelDataSource": model_data_source}] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + ) + + +def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "arn:aws:sagemaker:us-west-2:123456789123:model-package/existing-mp" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + "SourceUri": source_uri, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + sagemaker_session.sagemaker_client.update_model_package.assert_not_called() + + +def test_create_model_package_from_algorithm_with_model_data_source(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_source, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataSource": model_data_source, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + +def test_create_model_package_from_algorithm_with_model_data_url(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_url = "s3://bucket/key" + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_url, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataUrl": model_data_url, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 852ef8b153..a83f1b995d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -46,6 +46,7 @@ _is_bad_path, _is_bad_link, custom_extractall_tarfile, + can_model_package_source_uri_autopopulate, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1796,3 +1797,15 @@ def test_is_bad_link(link_name, base, expected): def test_custom_extractall_tarfile(mock_custom_tarfile, data_filter, expected_extract_path): tar = mock_custom_tarfile(data_filter) custom_extractall_tarfile(tar, "/extract/path") + + +def test_can_model_package_source_uri_autopopulate(): + test_data = [ + ("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mpg/1", True), + ("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mp", True), + ("arn:aws:sagemaker:us-west-2:012345678912:model/dummy-model", True), + ("https://path/to/model", False), + ("/home/path/to/model", False), + ] + for source_uri, expected in test_data: + assert can_model_package_source_uri_autopopulate(source_uri) == expected