@@ -2786,6 +2786,121 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection(
2786
2786
)
2787
2787
2788
2788
2789
+ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_partial_kms_support (
2790
+ sagemaker_session ,
2791
+ ):
2792
+ sagemaker_session .sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2793
+
2794
+ pvs = [
2795
+ sagemaker .production_variant ("A" , "ml.g5.2xlarge" ),
2796
+ sagemaker .production_variant ("B" , "ml.p2.xlarge" ),
2797
+ sagemaker .production_variant ("C" , "ml.p2.xlarge" ),
2798
+ ]
2799
+ # Add DestinationS3Uri to only one production variant
2800
+ pvs [0 ]["CoreDumpConfig" ] = {"DestinationS3Uri" : "s3://test" }
2801
+ existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo"
2802
+ existing_endpoint_name = "foo"
2803
+ new_endpoint_name = "new-foo"
2804
+ sagemaker_session .sagemaker_client .describe_endpoint_config .return_value = {
2805
+ "ProductionVariants" : [sagemaker .production_variant ("A" , "ml.m4.xlarge" )],
2806
+ "EndpointConfigArn" : existing_endpoint_arn ,
2807
+ "AsyncInferenceConfig" : {},
2808
+ }
2809
+ sagemaker_session .sagemaker_client .list_tags .return_value = {"Tags" : []}
2810
+
2811
+ sagemaker_session .create_endpoint_config_from_existing (
2812
+ existing_endpoint_name , new_endpoint_name , new_production_variants = pvs
2813
+ )
2814
+
2815
+ expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
2816
+ "EndpointConfig"
2817
+ ]["ProductionVariants" ][0 ]["CoreDumpConfig" ]["KmsKeyId" ]
2818
+ expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
2819
+ "AsyncInferenceConfig"
2820
+ ]["OutputConfig" ]["KmsKeyId" ]
2821
+ expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
2822
+ "KmsKeyId"
2823
+ ]
2824
+ expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ]["Tags" ]
2825
+
2826
+ sagemaker_session .sagemaker_client .create_endpoint_config .assert_called_with (
2827
+ EndpointConfigName = new_endpoint_name ,
2828
+ ProductionVariants = [
2829
+ {
2830
+ "CoreDumpConfig" : {
2831
+ "KmsKeyId" : expected_production_variant_0_kms_key_id ,
2832
+ "DestinationS3Uri" : pvs [0 ]["CoreDumpConfig" ]["DestinationS3Uri" ],
2833
+ },
2834
+ ** sagemaker .production_variant ("A" , "ml.g5.2xlarge" ),
2835
+ },
2836
+ {
2837
+ # Merge shouldn't happen because input for this index doesn't have DestinationS3Uri
2838
+ ** sagemaker .production_variant ("B" , "ml.p2.xlarge" ),
2839
+ },
2840
+ sagemaker .production_variant ("C" , "ml.p2.xlarge" ),
2841
+ ],
2842
+ KmsKeyId = expected_kms_key_id , # from config
2843
+ Tags = expected_tags , # from config
2844
+ AsyncInferenceConfig = {"OutputConfig" : {"KmsKeyId" : expected_inference_kms_key_id }},
2845
+ )
2846
+
2847
+
2848
+ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_no_kms_support (
2849
+ sagemaker_session ,
2850
+ ):
2851
+ sagemaker_session .sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2852
+
2853
+ pvs = [
2854
+ sagemaker .production_variant ("A" , "ml.g5.2xlarge" ),
2855
+ sagemaker .production_variant ("B" , "ml.g5.xlarge" ),
2856
+ sagemaker .production_variant ("C" , "ml.g5.xlarge" ),
2857
+ ]
2858
+ # Add DestinationS3Uri to only one production variant
2859
+ pvs [0 ]["CoreDumpConfig" ] = {"DestinationS3Uri" : "s3://test" }
2860
+ existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo"
2861
+ existing_endpoint_name = "foo"
2862
+ new_endpoint_name = "new-foo"
2863
+ sagemaker_session .sagemaker_client .describe_endpoint_config .return_value = {
2864
+ "ProductionVariants" : [sagemaker .production_variant ("A" , "ml.m4.xlarge" )],
2865
+ "EndpointConfigArn" : existing_endpoint_arn ,
2866
+ "AsyncInferenceConfig" : {},
2867
+ }
2868
+ sagemaker_session .sagemaker_client .list_tags .return_value = {"Tags" : []}
2869
+
2870
+ sagemaker_session .create_endpoint_config_from_existing (
2871
+ existing_endpoint_name , new_endpoint_name , new_production_variants = pvs
2872
+ )
2873
+
2874
+ expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
2875
+ "EndpointConfig"
2876
+ ]["ProductionVariants" ][0 ]["CoreDumpConfig" ]["KmsKeyId" ]
2877
+ expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
2878
+ "AsyncInferenceConfig"
2879
+ ]["OutputConfig" ]["KmsKeyId" ]
2880
+
2881
+ expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ]["Tags" ]
2882
+
2883
+ sagemaker_session .sagemaker_client .create_endpoint_config .assert_called_with (
2884
+ EndpointConfigName = new_endpoint_name ,
2885
+ ProductionVariants = [
2886
+ {
2887
+ "CoreDumpConfig" : {
2888
+ "KmsKeyId" : expected_production_variant_0_kms_key_id ,
2889
+ "DestinationS3Uri" : pvs [0 ]["CoreDumpConfig" ]["DestinationS3Uri" ],
2890
+ },
2891
+ ** sagemaker .production_variant ("A" , "ml.g5.2xlarge" ),
2892
+ },
2893
+ {
2894
+ # Merge shouldn't happen because input for this index doesn't have DestinationS3Uri
2895
+ ** sagemaker .production_variant ("B" , "ml.g5.xlarge" ),
2896
+ },
2897
+ sagemaker .production_variant ("C" , "ml.g5.xlarge" ),
2898
+ ],
2899
+ Tags = expected_tags , # from config
2900
+ AsyncInferenceConfig = {"OutputConfig" : {"KmsKeyId" : expected_inference_kms_key_id }},
2901
+ )
2902
+
2903
+
2789
2904
def test_endpoint_from_production_variants_with_sagemaker_config_injection (
2790
2905
sagemaker_session ,
2791
2906
):
@@ -2848,6 +2963,127 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection(
2848
2963
)
2849
2964
2850
2965
2966
+ def test_endpoint_from_production_variants_with_sagemaker_config_injection_partial_kms_support (
2967
+ sagemaker_session ,
2968
+ ):
2969
+ sagemaker_session .sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2970
+
2971
+ sagemaker_session .sagemaker_client .describe_endpoint = Mock (
2972
+ return_value = {"EndpointStatus" : "InService" }
2973
+ )
2974
+ pvs = [
2975
+ sagemaker .production_variant ("A" , "ml.g5.xlarge" ),
2976
+ sagemaker .production_variant ("B" , "ml.p2.xlarge" ),
2977
+ sagemaker .production_variant ("C" , "ml.p2.xlarge" ),
2978
+ ]
2979
+ # Add DestinationS3Uri to only one production variant
2980
+ pvs [0 ]["CoreDumpConfig" ] = {"DestinationS3Uri" : "s3://test" }
2981
+ sagemaker_session .endpoint_from_production_variants (
2982
+ "some-endpoint" ,
2983
+ pvs ,
2984
+ data_capture_config_dict = {},
2985
+ async_inference_config_dict = AsyncInferenceConfig ()._to_request_dict (),
2986
+ )
2987
+ expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
2988
+ "EndpointConfig"
2989
+ ]["DataCaptureConfig" ]["KmsKeyId" ]
2990
+ expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
2991
+ "AsyncInferenceConfig"
2992
+ ]["OutputConfig" ]["KmsKeyId" ]
2993
+ expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
2994
+ "KmsKeyId"
2995
+ ]
2996
+ expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ]["Tags" ]
2997
+
2998
+ expected_async_inference_config_dict = AsyncInferenceConfig ()._to_request_dict ()
2999
+ expected_async_inference_config_dict ["OutputConfig" ]["KmsKeyId" ] = expected_inference_kms_key_id
3000
+ expected_pvs = [
3001
+ sagemaker .production_variant ("A" , "ml.g5.xlarge" ),
3002
+ sagemaker .production_variant ("B" , "ml.p2.xlarge" ),
3003
+ sagemaker .production_variant ("C" , "ml.p2.xlarge" ),
3004
+ ]
3005
+ # Add DestinationS3Uri, KmsKeyId to only one production variant
3006
+ expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
3007
+ "EndpointConfig"
3008
+ ]["ProductionVariants" ][0 ]["CoreDumpConfig" ]["KmsKeyId" ]
3009
+ expected_pvs [0 ]["CoreDumpConfig" ] = {
3010
+ "DestinationS3Uri" : "s3://test" ,
3011
+ "KmsKeyId" : expected_production_variant_0_kms_key_id ,
3012
+ }
3013
+ sagemaker_session .sagemaker_client .create_endpoint_config .assert_called_with (
3014
+ EndpointConfigName = "some-endpoint" ,
3015
+ ProductionVariants = expected_pvs ,
3016
+ Tags = expected_tags , # from config
3017
+ KmsKeyId = expected_kms_key_id , # from config
3018
+ AsyncInferenceConfig = expected_async_inference_config_dict ,
3019
+ DataCaptureConfig = {"KmsKeyId" : expected_data_capture_kms_key_id },
3020
+ )
3021
+ sagemaker_session .sagemaker_client .create_endpoint .assert_called_with (
3022
+ EndpointConfigName = "some-endpoint" ,
3023
+ EndpointName = "some-endpoint" ,
3024
+ Tags = expected_tags , # from config
3025
+ )
3026
+
3027
+
3028
+ def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_kms_support (
3029
+ sagemaker_session ,
3030
+ ):
3031
+ sagemaker_session .sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
3032
+
3033
+ sagemaker_session .sagemaker_client .describe_endpoint = Mock (
3034
+ return_value = {"EndpointStatus" : "InService" }
3035
+ )
3036
+ pvs = [
3037
+ sagemaker .production_variant ("A" , "ml.g5.xlarge" ),
3038
+ sagemaker .production_variant ("B" , "ml.g5.xlarge" ),
3039
+ sagemaker .production_variant ("C" , "ml.g5.xlarge" ),
3040
+ ]
3041
+ # Add DestinationS3Uri to only one production variant
3042
+ pvs [0 ]["CoreDumpConfig" ] = {"DestinationS3Uri" : "s3://test" }
3043
+ sagemaker_session .endpoint_from_production_variants (
3044
+ "some-endpoint" ,
3045
+ pvs ,
3046
+ data_capture_config_dict = {},
3047
+ async_inference_config_dict = AsyncInferenceConfig ()._to_request_dict (),
3048
+ )
3049
+ expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
3050
+ "EndpointConfig"
3051
+ ]["DataCaptureConfig" ]["KmsKeyId" ]
3052
+ expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ][
3053
+ "AsyncInferenceConfig"
3054
+ ]["OutputConfig" ]["KmsKeyId" ]
3055
+
3056
+ expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ]["EndpointConfig" ]["Tags" ]
3057
+
3058
+ expected_async_inference_config_dict = AsyncInferenceConfig ()._to_request_dict ()
3059
+ expected_async_inference_config_dict ["OutputConfig" ]["KmsKeyId" ] = expected_inference_kms_key_id
3060
+ expected_pvs = [
3061
+ sagemaker .production_variant ("A" , "ml.g5.xlarge" ),
3062
+ sagemaker .production_variant ("B" , "ml.g5.xlarge" ),
3063
+ sagemaker .production_variant ("C" , "ml.g5.xlarge" ),
3064
+ ]
3065
+ # Add DestinationS3Uri, KmsKeyId to only one production variant
3066
+ expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG ["SageMaker" ][
3067
+ "EndpointConfig"
3068
+ ]["ProductionVariants" ][0 ]["CoreDumpConfig" ]["KmsKeyId" ]
3069
+ expected_pvs [0 ]["CoreDumpConfig" ] = {
3070
+ "DestinationS3Uri" : "s3://test" ,
3071
+ "KmsKeyId" : expected_production_variant_0_kms_key_id ,
3072
+ }
3073
+ sagemaker_session .sagemaker_client .create_endpoint_config .assert_called_with (
3074
+ EndpointConfigName = "some-endpoint" ,
3075
+ ProductionVariants = expected_pvs ,
3076
+ Tags = expected_tags , # from config
3077
+ AsyncInferenceConfig = expected_async_inference_config_dict ,
3078
+ DataCaptureConfig = {"KmsKeyId" : expected_data_capture_kms_key_id },
3079
+ )
3080
+ sagemaker_session .sagemaker_client .create_endpoint .assert_called_with (
3081
+ EndpointConfigName = "some-endpoint" ,
3082
+ EndpointName = "some-endpoint" ,
3083
+ Tags = expected_tags , # from config
3084
+ )
3085
+
3086
+
2851
3087
def test_create_endpoint_config_with_tags (sagemaker_session ):
2852
3088
tags = [{"Key" : "TagtestKey" , "Value" : "TagtestValue" }]
2853
3089
0 commit comments