77
77
)
78
78
from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
79
79
from sagemaker .enums import EndpointType
80
- from sagemaker .session import get_add_model_package_inference_args
80
+ from sagemaker .session import (
81
+ get_add_model_package_inference_args ,
82
+ get_update_model_package_inference_args ,
83
+ )
81
84
82
85
# Setting LOGGER for backward compatibility, in case users import it...
83
86
logger = LOGGER = logging .getLogger ("sagemaker" )
@@ -423,6 +426,7 @@ def register(
423
426
nearest_model_name : Optional [Union [str , PipelineVariable ]] = None ,
424
427
data_input_configuration : Optional [Union [str , PipelineVariable ]] = None ,
425
428
skip_model_validation : Optional [Union [str , PipelineVariable ]] = None ,
429
+ source_uri : Optional [Union [str , PipelineVariable ]] = None ,
426
430
):
427
431
"""Creates a model package for creating SageMaker models or listing on Marketplace.
428
432
@@ -472,17 +476,14 @@ def register(
472
476
(default: None).
473
477
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
474
478
validation. Values can be "All" or "None" (default: None).
479
+ source_uri (str or PipelineVariable): The URI of the source for the model package
480
+ (default: None).
475
481
476
482
Returns:
477
483
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
478
484
in case the Model instance is built with
479
485
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
480
486
"""
481
- if isinstance (self .model_data , dict ):
482
- raise ValueError (
483
- "SageMaker Model Package currently cannot be created with ModelDataSource."
484
- )
485
-
486
487
if content_types is not None :
487
488
self .content_types = content_types
488
489
@@ -513,6 +514,12 @@ def register(
513
514
"Image" : self .image_uri ,
514
515
}
515
516
517
+ if isinstance (self .model_data , dict ):
518
+ raise ValueError (
519
+ "Un-versioned SageMaker Model Package currently cannot be "
520
+ "created with ModelDataSource."
521
+ )
522
+
516
523
if self .model_data is not None :
517
524
container_def ["ModelDataUrl" ] = self .model_data
518
525
@@ -536,6 +543,7 @@ def register(
536
543
sample_payload_url = sample_payload_url ,
537
544
task = task ,
538
545
skip_model_validation = skip_model_validation ,
546
+ source_uri = source_uri ,
539
547
)
540
548
model_package = self .sagemaker_session .create_model_package_from_containers (
541
549
** model_pkg_args
@@ -2040,8 +2048,9 @@ def __init__(
2040
2048
endpoints use this role to access training data and model
2041
2049
artifacts. After the endpoint is created, the inference code
2042
2050
might use the IAM role, if it needs to access an AWS resource.
2043
- model_data (str): The S3 location of a SageMaker model data
2044
- ``.tar.gz`` file. Must be provided if algorithm_arn is provided.
2051
+ model_data (str or dict[str, Any]): The S3 location of a SageMaker model data
2052
+ ``.tar.gz`` file or a dictionary representing a ``ModelDataSource``
2053
+ object. Must be provided if algorithm_arn is provided.
2045
2054
algorithm_arn (str): algorithm arn used to train the model, can be
2046
2055
just the name if your account owns the algorithm. Must also
2047
2056
provide ``model_data``.
@@ -2050,11 +2059,6 @@ def __init__(
2050
2059
``model_data`` is not required.
2051
2060
**kwargs: Additional kwargs passed to the Model constructor.
2052
2061
"""
2053
- if isinstance (model_data , dict ):
2054
- raise ValueError (
2055
- "Creating ModelPackage with ModelDataSource is currently not supported"
2056
- )
2057
-
2058
2062
super (ModelPackage , self ).__init__ (
2059
2063
role = role , model_data = model_data , image_uri = None , ** kwargs
2060
2064
)
@@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str])
2222
2226
sagemaker_session = self .sagemaker_session or sagemaker .Session ()
2223
2227
sagemaker_session .sagemaker_client .update_model_package (** update_metadata_args )
2224
2228
2229
+ def update_inference_specification (
2230
+ self ,
2231
+ containers : Dict = None ,
2232
+ image_uris : List [str ] = None ,
2233
+ content_types : List [str ] = None ,
2234
+ response_types : List [str ] = None ,
2235
+ inference_instances : List [str ] = None ,
2236
+ transform_instances : List [str ] = None ,
2237
+ ):
2238
+ """Inference specification to be set for the model package
2239
+
2240
+ Args:
2241
+ containers (dict): The Amazon ECR registry path of the Docker image
2242
+ that contains the inference code.
2243
+ image_uris (List[str]): The ECR path where inference code is stored.
2244
+ content_types (list[str]): The supported MIME types
2245
+ for the input data.
2246
+ response_types (list[str]): The supported MIME types
2247
+ for the output data.
2248
+ inference_instances (list[str]): A list of the instance
2249
+ types that are used to generate inferences in real-time (default: None).
2250
+ transform_instances (list[str]): A list of the instance
2251
+ types on which a transformation job can be run or on which an endpoint can be
2252
+ deployed (default: None).
2253
+
2254
+ """
2255
+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
2256
+ if (containers is not None ) ^ (image_uris is None ):
2257
+ raise ValueError ("Should have either containers or image_uris for inference." )
2258
+ container_def = []
2259
+ if image_uris :
2260
+ for uri in image_uris :
2261
+ container_def .append (
2262
+ {
2263
+ "Image" : uri ,
2264
+ }
2265
+ )
2266
+ else :
2267
+ container_def = containers
2268
+
2269
+ model_package_update_args = get_update_model_package_inference_args (
2270
+ model_package_arn = self .model_package_arn ,
2271
+ containers = container_def ,
2272
+ content_types = content_types ,
2273
+ response_types = response_types ,
2274
+ inference_instances = inference_instances ,
2275
+ transform_instances = transform_instances ,
2276
+ )
2277
+
2278
+ sagemaker_session .sagemaker_client .update_model_package (** model_package_update_args )
2279
+
2280
+ def update_source_uri (
2281
+ self ,
2282
+ source_uri : str ,
2283
+ ):
2284
+ """Source uri to be set for the model package
2285
+
2286
+ Args:
2287
+ source_uri (str): The URI of the source for the model package.
2288
+
2289
+ """
2290
+ update_source_uri_args = {
2291
+ "ModelPackageArn" : self .model_package_arn ,
2292
+ "SourceUri" : source_uri ,
2293
+ }
2294
+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
2295
+ sagemaker_session .sagemaker_client .update_model_package (** update_source_uri_args )
2296
+
2225
2297
def remove_customer_metadata_properties (
2226
2298
self , customer_metadata_properties_to_remove : List [str ]
2227
2299
):
0 commit comments