43
43
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH ,
44
44
load_sagemaker_config ,
45
45
)
46
+ from sagemaker .model_card .schema_constraints import ModelApprovalStatusEnum
46
47
from sagemaker .session import Session
47
48
from sagemaker .model_metrics import ModelMetrics
48
49
from sagemaker .deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374
375
self .dependencies = updates ["dependencies" ]
375
376
self .uploaded_code = None
376
377
self .repacked_model_data = None
378
+ self .content_types = None
379
+ self .response_types = None
377
380
378
381
@runnable_by_pipeline
379
382
def register (
380
383
self ,
381
- content_types : List [Union [str , PipelineVariable ]],
382
- response_types : List [Union [str , PipelineVariable ]],
384
+ content_types : List [Union [str , PipelineVariable ]] = None ,
385
+ response_types : List [Union [str , PipelineVariable ]] = None ,
383
386
inference_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
384
387
transform_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
385
388
model_package_name : Optional [Union [str , PipelineVariable ]] = None ,
@@ -456,16 +459,33 @@ def register(
456
459
in case the Model instance is built with
457
460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458
461
"""
459
- if self .model_data is None :
460
- raise ValueError ("SageMaker Model Package cannot be created without model data." )
461
462
if isinstance (self .model_data , dict ):
462
463
raise ValueError (
463
464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464
465
)
465
466
467
+ if content_types is not None :
468
+ self .content_types = content_types
469
+
470
+ if response_types is not None :
471
+ self .response_types = response_types
472
+
473
+ if self .content_types is None :
474
+ raise ValueError ("The supported MIME types for the input data is not set" )
475
+
476
+ if self .response_types is None :
477
+ raise ValueError ("The supported MIME types for the output data is not set" )
478
+
466
479
if image_uri is not None :
467
480
self .image_uri = image_uri
468
481
482
+ if model_package_group_name is None and model_package_name is None :
483
+ # If model package group and model package name is not set
484
+ # then register to auto-generated model package group
485
+ model_package_group_name = utils .base_name_from_image (
486
+ self .image_uri , default_base_name = ModelPackage .__name__
487
+ )
488
+
469
489
if model_package_group_name is not None :
470
490
container_def = self .prepare_container_def ()
471
491
container_def = update_container_with_inference_params (
@@ -478,12 +498,14 @@ def register(
478
498
else :
479
499
container_def = {
480
500
"Image" : self .image_uri ,
481
- "ModelDataUrl" : self .model_data ,
482
501
}
483
502
503
+ if self .model_data is not None :
504
+ container_def ["ModelDataUrl" ] = self .model_data
505
+
484
506
model_pkg_args = sagemaker .get_model_package_args (
485
- content_types ,
486
- response_types ,
507
+ self . content_types ,
508
+ self . response_types ,
487
509
inference_instances = inference_instances ,
488
510
transform_instances = transform_instances ,
489
511
model_package_name = model_package_name ,
@@ -511,6 +533,7 @@ def register(
511
533
role = self .role ,
512
534
model_data = self .model_data ,
513
535
model_package_arn = model_package .get ("ModelPackageArn" ),
536
+ sagemaker_session = self .sagemaker_session ,
514
537
)
515
538
516
539
@runnable_by_pipeline
@@ -1751,6 +1774,7 @@ def __init__(
1751
1774
1752
1775
# works for MODEL_PACKAGE_ARN with or without version info.
1753
1776
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1777
+ MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
1754
1778
1755
1779
1756
1780
class ModelPackage (Model ):
@@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
1885
1909
self ._ensure_base_name_if_needed (model_package_name )
1886
1910
self ._set_model_name_if_needed ()
1887
1911
1912
+ # Quering the approval status for the model package
1913
+ # Approving the versioned model package in case it is not approved
1914
+ model_package_desc = self .sagemaker_session .sagemaker_client .describe_model_package (
1915
+ ModelPackageName = self .model_package_arn or model_package_name
1916
+ )
1917
+ if self .model_package_arn is None :
1918
+ self .model_package_arn = model_package_desc ["ModelPackageArn" ]
1919
+ if re .match (MODEL_PACKAGE_VERSIONED_ARN_PATTERN , self .model_package_arn ):
1920
+ approval_status = model_package_desc .get ("ModelApprovalStatus" , "" )
1921
+ if approval_status != ModelApprovalStatusEnum .APPROVED :
1922
+ self .update_approval_status (approval_status = ModelApprovalStatusEnum .APPROVED )
1923
+
1888
1924
self .sagemaker_session .create_model (
1889
1925
self .name ,
1890
1926
self .role ,
@@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
1898
1934
"""Set the base name if there is no model name provided."""
1899
1935
if self .name is None :
1900
1936
self ._base_name = base_name
1937
+
1938
+ def update_approval_status (self , approval_status , approval_description = None ):
1939
+ """Update the approval status for the model package
1940
+
1941
+ Args:
1942
+ approval_status (str or PipelineVariable): Model Approval Status, values can be
1943
+ "Approved", "Rejected", or "PendingManualApproval".
1944
+ approval_description (str): Optional. Description for the approval status of the model
1945
+ (default: None).
1946
+ """
1947
+
1948
+ # Models can lazy-init sagemaker_session until deploy() is called to support
1949
+ # LocalMode so we must make sure we have an actual session
1950
+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
1951
+ if self .model_package_arn is None :
1952
+ raise ValueError ("model_package_arn is required to update the status." )
1953
+
1954
+ update_approval_args = {
1955
+ "ModelPackageArn" : self .model_package_arn ,
1956
+ "ModelApprovalStatus" : approval_status ,
1957
+ }
1958
+
1959
+ if approval_description is not None :
1960
+ update_approval_args ["ApprovalDescription" ] = approval_description
1961
+
1962
+ sagemaker_session .sagemaker_client .update_model_package (** update_approval_args )
0 commit comments