Skip to content

Commit 95f1292

Browse files
mrudulmnbenieric
authored and
root
committed
feature: Add ModelDataSource and SourceUri support for model package and while registering (aws#4492)
Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 3e92cb8 commit 95f1292

File tree

22 files changed

+864
-35
lines changed

22 files changed

+864
-35
lines changed

src/sagemaker/chainer/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def register(
174174
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
175175
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
176176
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
177+
source_uri: Optional[Union[str, PipelineVariable]] = None,
177178
):
178179
"""Creates a model package for creating SageMaker models or listing on Marketplace.
179180
@@ -223,6 +224,8 @@ def register(
223224
(default: None).
224225
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
225226
validation. Values can be "All" or "None" (default: None).
227+
source_uri (str or PipelineVariable): The URI of the source for the model package
228+
(default: None).
226229
227230
Returns:
228231
str: A string of SageMaker Model Package ARN.
@@ -262,6 +265,7 @@ def register(
262265
nearest_model_name=nearest_model_name,
263266
data_input_configuration=data_input_configuration,
264267
skip_model_validation=skip_model_validation,
268+
source_uri=source_uri,
265269
)
266270

267271
def prepare_container_def(

src/sagemaker/estimator.py

+3
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,7 @@ def register(
17181718
nearest_model_name=None,
17191719
data_input_configuration=None,
17201720
skip_model_validation=None,
1721+
source_uri=None,
17211722
**kwargs,
17221723
):
17231724
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1765,6 +1766,7 @@ def register(
17651766
data_input_configuration (str): Input object for the model (default: None).
17661767
skip_model_validation (str): Indicates if you want to skip model validation.
17671768
Values can be "All" or "None" (default: None).
1769+
source_uri (str): The URI of the source for the model package (default: None).
17681770
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
17691771
``create_model()`` to accept ``**kwargs`` to customize model creation during
17701772
deploy. For more, see the implementation docs.
@@ -1809,6 +1811,7 @@ def register(
18091811
nearest_model_name=nearest_model_name,
18101812
data_input_configuration=data_input_configuration,
18111813
skip_model_validation=skip_model_validation,
1814+
source_uri=source_uri,
18121815
)
18131816

18141817
@property

src/sagemaker/huggingface/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def register(
360360
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
361361
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
362362
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
363+
source_uri: Optional[Union[str, PipelineVariable]] = None,
363364
):
364365
"""Creates a model package for creating SageMaker models or listing on Marketplace.
365366
@@ -410,6 +411,8 @@ def register(
410411
(default: None).
411412
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
412413
validation. Values can be "All" or "None" (default: None).
414+
source_uri (str or PipelineVariable): The URI of the source for the model package
415+
(default: None).
413416
414417
Returns:
415418
A `sagemaker.model.ModelPackage` instance.
@@ -457,6 +460,7 @@ def register(
457460
nearest_model_name=nearest_model_name,
458461
data_input_configuration=data_input_configuration,
459462
skip_model_validation=skip_model_validation,
463+
source_uri=source_uri,
460464
)
461465

462466
def prepare_container_def(

src/sagemaker/jumpstart/factory/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ def get_register_kwargs(
598598
nearest_model_name: Optional[str] = None,
599599
data_input_configuration: Optional[str] = None,
600600
skip_model_validation: Optional[str] = None,
601+
source_uri: Optional[str] = None,
601602
) -> JumpStartModelRegisterKwargs:
602603
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
603604

@@ -629,6 +630,7 @@ def get_register_kwargs(
629630
nearest_model_name=nearest_model_name,
630631
data_input_configuration=data_input_configuration,
631632
skip_model_validation=skip_model_validation,
633+
source_uri=source_uri,
632634
)
633635

634636
model_specs = verify_model_region_and_return_specs(

src/sagemaker/jumpstart/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def register(
631631
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
632632
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
633633
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
634+
source_uri: Optional[Union[str, PipelineVariable]] = None,
634635
):
635636
"""Creates a model package for creating SageMaker models or listing on Marketplace.
636637
@@ -676,6 +677,8 @@ def register(
676677
(default: None).
677678
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
678679
validation. Values can be "All" or "None" (default: None).
680+
source_uri (str or PipelineVariable): The URI of the source for the model package
681+
(default: None).
679682
680683
Returns:
681684
A `sagemaker.model.ModelPackage` instance.
@@ -709,6 +712,7 @@ def register(
709712
nearest_model_name=nearest_model_name,
710713
data_input_configuration=data_input_configuration,
711714
skip_model_validation=skip_model_validation,
715+
source_uri=source_uri,
712716
)
713717

714718
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
16591659
"nearest_model_name",
16601660
"data_input_configuration",
16611661
"skip_model_validation",
1662+
"source_uri",
16621663
]
16631664

16641665
SERIALIZATION_EXCLUSION_SET = {
@@ -1699,6 +1700,7 @@ def __init__(
16991700
nearest_model_name: Optional[str] = None,
17001701
data_input_configuration: Optional[str] = None,
17011702
skip_model_validation: Optional[str] = None,
1703+
source_uri: Optional[str] = None,
17021704
) -> None:
17031705
"""Instantiates JumpStartModelRegisterKwargs object."""
17041706

@@ -1730,3 +1732,4 @@ def __init__(
17301732
self.nearest_model_name = nearest_model_name
17311733
self.data_input_configuration = data_input_configuration
17321734
self.skip_model_validation = skip_model_validation
1735+
self.source_uri = source_uri

src/sagemaker/model.py

+85-13
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@
7777
)
7878
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7979
from sagemaker.enums import EndpointType
80-
from sagemaker.session import get_add_model_package_inference_args
80+
from sagemaker.session import (
81+
get_add_model_package_inference_args,
82+
get_update_model_package_inference_args,
83+
)
8184

8285
# Setting LOGGER for backward compatibility, in case users import it...
8386
logger = LOGGER = logging.getLogger("sagemaker")
@@ -423,6 +426,7 @@ def register(
423426
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
424427
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
425428
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
429+
source_uri: Optional[Union[str, PipelineVariable]] = None,
426430
):
427431
"""Creates a model package for creating SageMaker models or listing on Marketplace.
428432
@@ -472,17 +476,14 @@ def register(
472476
(default: None).
473477
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
474478
validation. Values can be "All" or "None" (default: None).
479+
source_uri (str or PipelineVariable): The URI of the source for the model package
480+
(default: None).
475481
476482
Returns:
477483
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
478484
in case the Model instance is built with
479485
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
480486
"""
481-
if isinstance(self.model_data, dict):
482-
raise ValueError(
483-
"SageMaker Model Package currently cannot be created with ModelDataSource."
484-
)
485-
486487
if content_types is not None:
487488
self.content_types = content_types
488489

@@ -513,6 +514,12 @@ def register(
513514
"Image": self.image_uri,
514515
}
515516

517+
if isinstance(self.model_data, dict):
518+
raise ValueError(
519+
"Un-versioned SageMaker Model Package currently cannot be "
520+
"created with ModelDataSource."
521+
)
522+
516523
if self.model_data is not None:
517524
container_def["ModelDataUrl"] = self.model_data
518525

@@ -536,6 +543,7 @@ def register(
536543
sample_payload_url=sample_payload_url,
537544
task=task,
538545
skip_model_validation=skip_model_validation,
546+
source_uri=source_uri,
539547
)
540548
model_package = self.sagemaker_session.create_model_package_from_containers(
541549
**model_pkg_args
@@ -2040,8 +2048,9 @@ def __init__(
20402048
endpoints use this role to access training data and model
20412049
artifacts. After the endpoint is created, the inference code
20422050
might use the IAM role, if it needs to access an AWS resource.
2043-
model_data (str): The S3 location of a SageMaker model data
2044-
``.tar.gz`` file. Must be provided if algorithm_arn is provided.
2051+
model_data (str or dict[str, Any]): The S3 location of a SageMaker model data
2052+
``.tar.gz`` file or a dictionary representing a ``ModelDataSource``
2053+
object. Must be provided if algorithm_arn is provided.
20452054
algorithm_arn (str): algorithm arn used to train the model, can be
20462055
just the name if your account owns the algorithm. Must also
20472056
provide ``model_data``.
@@ -2050,11 +2059,6 @@ def __init__(
20502059
``model_data`` is not required.
20512060
**kwargs: Additional kwargs passed to the Model constructor.
20522061
"""
2053-
if isinstance(model_data, dict):
2054-
raise ValueError(
2055-
"Creating ModelPackage with ModelDataSource is currently not supported"
2056-
)
2057-
20582062
super(ModelPackage, self).__init__(
20592063
role=role, model_data=model_data, image_uri=None, **kwargs
20602064
)
@@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str])
22222226
sagemaker_session = self.sagemaker_session or sagemaker.Session()
22232227
sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args)
22242228

2229+
def update_inference_specification(
2230+
self,
2231+
containers: Dict = None,
2232+
image_uris: List[str] = None,
2233+
content_types: List[str] = None,
2234+
response_types: List[str] = None,
2235+
inference_instances: List[str] = None,
2236+
transform_instances: List[str] = None,
2237+
):
2238+
"""Inference specification to be set for the model package
2239+
2240+
Args:
2241+
containers (dict): The Amazon ECR registry path of the Docker image
2242+
that contains the inference code.
2243+
image_uris (List[str]): The ECR path where inference code is stored.
2244+
content_types (list[str]): The supported MIME types
2245+
for the input data.
2246+
response_types (list[str]): The supported MIME types
2247+
for the output data.
2248+
inference_instances (list[str]): A list of the instance
2249+
types that are used to generate inferences in real-time (default: None).
2250+
transform_instances (list[str]): A list of the instance
2251+
types on which a transformation job can be run or on which an endpoint can be
2252+
deployed (default: None).
2253+
2254+
"""
2255+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2256+
if (containers is not None) ^ (image_uris is None):
2257+
raise ValueError("Should have either containers or image_uris for inference.")
2258+
container_def = []
2259+
if image_uris:
2260+
for uri in image_uris:
2261+
container_def.append(
2262+
{
2263+
"Image": uri,
2264+
}
2265+
)
2266+
else:
2267+
container_def = containers
2268+
2269+
model_package_update_args = get_update_model_package_inference_args(
2270+
model_package_arn=self.model_package_arn,
2271+
containers=container_def,
2272+
content_types=content_types,
2273+
response_types=response_types,
2274+
inference_instances=inference_instances,
2275+
transform_instances=transform_instances,
2276+
)
2277+
2278+
sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)
2279+
2280+
def update_source_uri(
2281+
self,
2282+
source_uri: str,
2283+
):
2284+
"""Source uri to be set for the model package
2285+
2286+
Args:
2287+
source_uri (str): The URI of the source for the model package.
2288+
2289+
"""
2290+
update_source_uri_args = {
2291+
"ModelPackageArn": self.model_package_arn,
2292+
"SourceUri": source_uri,
2293+
}
2294+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2295+
sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)
2296+
22252297
def remove_customer_metadata_properties(
22262298
self, customer_metadata_properties_to_remove: List[str]
22272299
):

src/sagemaker/mxnet/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def register(
176176
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
177177
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
178178
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
179+
source_uri: Optional[Union[str, PipelineVariable]] = None,
179180
):
180181
"""Creates a model package for creating SageMaker models or listing on Marketplace.
181182
@@ -225,6 +226,8 @@ def register(
225226
(default: None).
226227
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
227228
validation. Values can be "All" or "None" (default: None).
229+
source_uri (str or PipelineVariable): The URI of the source for the model package
230+
(default: None).
228231
229232
Returns:
230233
A `sagemaker.model.ModelPackage` instance.
@@ -264,6 +267,7 @@ def register(
264267
nearest_model_name=nearest_model_name,
265268
data_input_configuration=data_input_configuration,
266269
skip_model_validation=skip_model_validation,
270+
source_uri=source_uri,
267271
)
268272

269273
def prepare_container_def(

src/sagemaker/pipeline.py

+4
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def register(
360360
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
361361
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
362362
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
363+
source_uri: Optional[Union[str, PipelineVariable]] = None,
363364
):
364365
"""Creates a model package for creating SageMaker models or listing on Marketplace.
365366
@@ -409,6 +410,8 @@ def register(
409410
(default: None).
410411
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
411412
validation. Values can be "All" or "None" (default: None).
413+
source_uri (str or PipelineVariable): The URI of the source for the model package
414+
(default: None).
412415
413416
Returns:
414417
If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step
@@ -456,6 +459,7 @@ def register(
456459
sample_payload_url=sample_payload_url,
457460
task=task,
458461
skip_model_validation=skip_model_validation,
462+
source_uri=source_uri,
459463
)
460464

461465
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

src/sagemaker/pytorch/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def register(
178178
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
179179
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
180180
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
181+
source_uri: Optional[Union[str, PipelineVariable]] = None,
181182
):
182183
"""Creates a model package for creating SageMaker models or listing on Marketplace.
183184
@@ -227,6 +228,8 @@ def register(
227228
(default: None).
228229
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
229230
validation. Values can be "All" or "None" (default: None).
231+
source_uri (str or PipelineVariable): The URI of the source for the model package
232+
(default: None).
230233
231234
Returns:
232235
A `sagemaker.model.ModelPackage` instance.
@@ -266,6 +269,7 @@ def register(
266269
nearest_model_name=nearest_model_name,
267270
data_input_configuration=data_input_configuration,
268271
skip_model_validation=skip_model_validation,
272+
source_uri=source_uri,
269273
)
270274

271275
def prepare_container_def(

0 commit comments

Comments
 (0)