@@ -710,7 +710,7 @@ def wait_for_model_package(self, model_package_name, poll=5):
710
710
return desc
711
711
712
712
def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type ,
713
- accelerator_type = None , tags = None ):
713
+ accelerator_type = None , tags = None , kms_key = None ):
714
714
"""Create an Amazon SageMaker endpoint configuration.
715
715
716
716
The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -738,12 +738,21 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
738
738
739
739
tags = tags or []
740
740
741
- self .sagemaker_client .create_endpoint_config (
742
- EndpointConfigName = name ,
743
- ProductionVariants = [production_variant (model_name , instance_type , initial_instance_count ,
744
- accelerator_type = accelerator_type )],
745
- Tags = tags
746
- )
741
+ request = {
742
+ 'EndpointConfigName' : name ,
743
+ 'ProductionVariants' : [
744
+ production_variant (model_name , instance_type , initial_instance_count ,
745
+ accelerator_type = accelerator_type )
746
+ ],
747
+ }
748
+
749
+ if tags is not None :
750
+ request ['Tags' ] = tags
751
+
752
+ if kms_key is not None :
753
+ request ['KmsKeyId' ] = kms_key
754
+
755
+ self .sagemaker_client .create_endpoint_config (** request )
747
756
return name
748
757
749
758
def create_endpoint (self , endpoint_name , config_name , tags = None , wait = True ):
@@ -1032,13 +1041,15 @@ def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_
1032
1041
self .create_endpoint (endpoint_name = name , config_name = name , wait = wait )
1033
1042
return name
1034
1043
1035
- def endpoint_from_production_variants (self , name , production_variants , tags = None , wait = True ):
1044
+ def endpoint_from_production_variants (self , name , production_variants , tags = None , kms_key = None , wait = True ):
1036
1045
"""Create an SageMaker ``Endpoint`` from a list of production variants.
1037
1046
1038
1047
Args:
1039
1048
name (str): The name of the ``Endpoint`` to create.
1040
1049
production_variants (list[dict[str, str]]): The list of production variants to deploy.
1041
1050
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None).
1051
+ kms_key (str): The KMS key that is used to encrypt the data on the storage volume attached
1052
+ to the instance hosting the endpoint.
1042
1053
wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True).
1043
1054
1044
1055
Returns:
@@ -1050,6 +1061,8 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
1050
1061
config_options = {'EndpointConfigName' : name , 'ProductionVariants' : production_variants }
1051
1062
if tags :
1052
1063
config_options ['Tags' ] = tags
1064
+ if kms_key :
1065
+ config_options ['KmsKeyId' ] = kms_key
1053
1066
1054
1067
self .sagemaker_client .create_endpoint_config (** config_options )
1055
1068
return self .create_endpoint (endpoint_name = name , config_name = name , tags = tags , wait = wait )
0 commit comments