Skip to content

feature: Add ModelDataSource and SourceUri support for model package. #4492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -223,6 +224,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
str: A string of SageMaker Model Package ARN.
Expand Down Expand Up @@ -262,6 +265,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

def prepare_container_def(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,7 @@ def register(
nearest_model_name=None,
data_input_configuration=None,
skip_model_validation=None,
source_uri=None,
**kwargs,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -1765,6 +1766,7 @@ def register(
data_input_configuration (str): Input object for the model (default: None).
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1809,6 +1811,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -410,6 +411,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -457,6 +460,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

def prepare_container_def(
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ def get_register_kwargs(
nearest_model_name: Optional[str] = None,
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -629,6 +630,7 @@ def get_register_kwargs(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

model_specs = verify_model_region_and_return_specs(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -676,6 +677,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -709,6 +712,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"nearest_model_name",
"data_input_configuration",
"skip_model_validation",
"source_uri",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1699,6 +1700,7 @@ def __init__(
nearest_model_name: Optional[str] = None,
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelRegisterKwargs object."""

Expand Down Expand Up @@ -1730,3 +1732,4 @@ def __init__(
self.nearest_model_name = nearest_model_name
self.data_input_configuration = data_input_configuration
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
98 changes: 85 additions & 13 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@
)
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.enums import EndpointType
from sagemaker.session import get_add_model_package_inference_args
from sagemaker.session import (
get_add_model_package_inference_args,
get_update_model_package_inference_args,
)

# Setting LOGGER for backward compatibility, in case users import it...
logger = LOGGER = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -423,6 +426,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -472,17 +476,14 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
in case the Model instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
if isinstance(self.model_data, dict):
raise ValueError(
"SageMaker Model Package currently cannot be created with ModelDataSource."
)

if content_types is not None:
self.content_types = content_types

Expand Down Expand Up @@ -513,6 +514,12 @@ def register(
"Image": self.image_uri,
}

if isinstance(self.model_data, dict):
raise ValueError(
"Un-versioned SageMaker Model Package currently cannot be "
"created with ModelDataSource."
)

if self.model_data is not None:
container_def["ModelDataUrl"] = self.model_data

Expand All @@ -536,6 +543,7 @@ def register(
sample_payload_url=sample_payload_url,
task=task,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down Expand Up @@ -2040,8 +2048,9 @@ def __init__(
endpoints use this role to access training data and model
artifacts. After the endpoint is created, the inference code
might use the IAM role, if it needs to access an AWS resource.
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file. Must be provided if algorithm_arn is provided.
model_data (str or dict[str, Any]): The S3 location of a SageMaker model data
``.tar.gz`` file or a dictionary representing a ``ModelDataSource``
object. Must be provided if algorithm_arn is provided.
algorithm_arn (str): algorithm arn used to train the model, can be
just the name if your account owns the algorithm. Must also
provide ``model_data``.
Expand All @@ -2050,11 +2059,6 @@ def __init__(
``model_data`` is not required.
**kwargs: Additional kwargs passed to the Model constructor.
"""
if isinstance(model_data, dict):
raise ValueError(
"Creating ModelPackage with ModelDataSource is currently not supported"
)

super(ModelPackage, self).__init__(
role=role, model_data=model_data, image_uri=None, **kwargs
)
Expand Down Expand Up @@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str])
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args)

def update_inference_specification(
self,
containers: Dict = None,
image_uris: List[str] = None,
content_types: List[str] = None,
response_types: List[str] = None,
inference_instances: List[str] = None,
transform_instances: List[str] = None,
):
"""Inference specification to be set for the model package

Args:
containers (dict): The Amazon ECR registry path of the Docker image
that contains the inference code.
image_uris (List[str]): The ECR path where inference code is stored.
content_types (list[str]): The supported MIME types
for the input data.
response_types (list[str]): The supported MIME types
for the output data.
inference_instances (list[str]): A list of the instance
types that are used to generate inferences in real-time (default: None).
transform_instances (list[str]): A list of the instance
types on which a transformation job can be run or on which an endpoint can be
deployed (default: None).

"""
sagemaker_session = self.sagemaker_session or sagemaker.Session()
if (containers is not None) ^ (image_uris is None):
raise ValueError("Should have either containers or image_uris for inference.")
container_def = []
if image_uris:
for uri in image_uris:
container_def.append(
{
"Image": uri,
}
)
else:
container_def = containers

model_package_update_args = get_update_model_package_inference_args(
model_package_arn=self.model_package_arn,
containers=container_def,
content_types=content_types,
response_types=response_types,
inference_instances=inference_instances,
transform_instances=transform_instances,
)

sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)

def update_source_uri(
self,
source_uri: str,
):
"""Source uri to be set for the model package

Args:
source_uri (str): The URI of the source for the model package.

"""
update_source_uri_args = {
"ModelPackageArn": self.model_package_arn,
"SourceUri": source_uri,
}
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)

def remove_customer_metadata_properties(
self, customer_metadata_properties_to_remove: List[str]
):
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -225,6 +226,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -264,6 +267,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -409,6 +410,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step
Expand Down Expand Up @@ -456,6 +459,7 @@ def register(
sample_payload_url=sample_payload_url,
task=task,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def register(
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -227,6 +228,8 @@ def register(
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -266,6 +269,7 @@ def register(
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
)

def prepare_container_def(
Expand Down
Loading