43
43
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH ,
44
44
load_sagemaker_config ,
45
45
)
46
+ from sagemaker .model_card .schema_constraints import ModelApprovalStatusEnum
46
47
from sagemaker .session import Session
47
48
from sagemaker .model_metrics import ModelMetrics
48
49
from sagemaker .deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374
375
self .dependencies = updates ["dependencies" ]
375
376
self .uploaded_code = None
376
377
self .repacked_model_data = None
378
+ self .content_types = None
379
+ self .response_types = None
377
380
378
381
@runnable_by_pipeline
379
382
def register (
380
383
self ,
381
- content_types : List [Union [str , PipelineVariable ]],
382
- response_types : List [Union [str , PipelineVariable ]],
384
+ content_types : List [Union [str , PipelineVariable ]] = None ,
385
+ response_types : List [Union [str , PipelineVariable ]] = None ,
383
386
inference_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
384
387
transform_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
385
388
model_package_name : Optional [Union [str , PipelineVariable ]] = None ,
@@ -456,16 +459,33 @@ def register(
456
459
in case the Model instance is built with
457
460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458
461
"""
459
- if self .model_data is None :
460
- raise ValueError ("SageMaker Model Package cannot be created without model data." )
461
462
if isinstance (self .model_data , dict ):
462
463
raise ValueError (
463
464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464
465
)
465
466
467
+ if content_types is not None :
468
+ self .content_types = content_types
469
+
470
+ if response_types is not None :
471
+ self .response_types = response_types
472
+
473
+ if self .content_types is None :
474
+ raise ValueError ("The supported MIME types for the input data is not set" )
475
+
476
+ if self .response_types is None :
477
+ raise ValueError ("The supported MIME types for the output data is not set" )
478
+
466
479
if image_uri is not None :
467
480
self .image_uri = image_uri
468
481
482
+ if model_package_group_name is None and model_package_name is None :
483
+ # If model package group and model package name is not set
484
+ # then register to auto-generated model package group
485
+ model_package_group_name = utils .base_name_from_image (
486
+ self .image_uri , default_base_name = ModelPackage .__name__
487
+ )
488
+
469
489
if model_package_group_name is not None :
470
490
container_def = self .prepare_container_def ()
471
491
container_def = update_container_with_inference_params (
@@ -478,12 +498,14 @@ def register(
478
498
else :
479
499
container_def = {
480
500
"Image" : self .image_uri ,
481
- "ModelDataUrl" : self .model_data ,
482
501
}
483
502
503
+ if self .model_data is not None :
504
+ container_def ["ModelDataUrl" ] = self .model_data
505
+
484
506
model_pkg_args = sagemaker .get_model_package_args (
485
- content_types ,
486
- response_types ,
507
+ self . content_types ,
508
+ self . response_types ,
487
509
inference_instances = inference_instances ,
488
510
transform_instances = transform_instances ,
489
511
model_package_name = model_package_name ,
@@ -1751,6 +1773,7 @@ def __init__(
1751
1773
1752
1774
# works for MODEL_PACKAGE_ARN with or without version info.
1753
1775
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1776
+ MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
1754
1777
1755
1778
1756
1779
class ModelPackage (Model ):
@@ -1885,6 +1908,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
1885
1908
self ._ensure_base_name_if_needed (model_package_name )
1886
1909
self ._set_model_name_if_needed ()
1887
1910
1911
+ # Quering the approval status for the model package
1912
+ # Approving the versioned model package in case it is not approved
1913
+ model_package_desc = self .sagemaker_session .sagemaker_client .describe_model_package (
1914
+ ModelPackageName = self .model_package_arn or model_package_name
1915
+ )
1916
+ if self .model_package_arn is None :
1917
+ self .model_package_arn = model_package_desc ["ModelPackageArn" ]
1918
+ if re .match (MODEL_PACKAGE_VERSIONED_ARN_PATTERN , self .model_package_arn ):
1919
+ approval_status = model_package_desc .get ("ModelApprovalStatus" , "" )
1920
+ if approval_status != ModelApprovalStatusEnum .APPROVED :
1921
+ self .update_approval_status (approval_status = ModelApprovalStatusEnum .APPROVED )
1922
+
1888
1923
self .sagemaker_session .create_model (
1889
1924
self .name ,
1890
1925
self .role ,
@@ -1898,3 +1933,25 @@ def _ensure_base_name_if_needed(self, base_name):
1898
1933
"""Set the base name if there is no model name provided."""
1899
1934
if self .name is None :
1900
1935
self ._base_name = base_name
1936
+
1937
+ def update_approval_status (self , approval_status , approval_description = None ):
1938
+ """Update the approval status for the model package
1939
+
1940
+ Args:
1941
+ approval_status (str or PipelineVariable): Model Approval Status, values can be
1942
+ "Approved", "Rejected", or "PendingManualApproval".
1943
+ approval_description (str): Optional. Description for the approval status of the model
1944
+ (default: None).
1945
+ """
1946
+ if self .model_package_arn is None :
1947
+ raise ValueError ("model_package_arn is required to update the status." )
1948
+
1949
+ update_approval_args = {
1950
+ "ModelPackageArn" : self .model_package_arn ,
1951
+ "ModelApprovalStatus" : approval_status ,
1952
+ }
1953
+
1954
+ if approval_description is not None :
1955
+ update_approval_args ["ApprovalDescription" ] = approval_description
1956
+
1957
+ self .sagemaker_session .sagemaker_client .update_model_package (** update_approval_args )
0 commit comments