@@ -144,3 +144,53 @@ def test_inference_pipeline_model_deploy(sagemaker_session):
144
144
with pytest .raises (Exception ) as exception :
145
145
sagemaker_session .sagemaker_client .describe_model (ModelName = model .name )
146
146
assert "Could not find model" in str (exception .value )
147
+
148
+
149
+ def test_inference_pipeline_model_deploy_with_update_endpoint (sagemaker_session ):
150
+ sparkml_data_path = os .path .join (DATA_DIR , "sparkml_model" )
151
+ xgboost_data_path = os .path .join (DATA_DIR , "xgboost_model" )
152
+ endpoint_name = "test-inference-pipeline-deploy-{}" .format (sagemaker_timestamp ())
153
+ sparkml_model_data = sagemaker_session .upload_data (
154
+ path = os .path .join (sparkml_data_path , "mleap_model.tar.gz" ),
155
+ key_prefix = "integ-test-data/sparkml/model" ,
156
+ )
157
+ xgb_model_data = sagemaker_session .upload_data (
158
+ path = os .path .join (xgboost_data_path , "xgb_model.tar.gz" ),
159
+ key_prefix = "integ-test-data/xgboost/model" ,
160
+ )
161
+
162
+ with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
163
+ sparkml_model = SparkMLModel (
164
+ model_data = sparkml_model_data ,
165
+ env = {"SAGEMAKER_SPARKML_SCHEMA" : SCHEMA },
166
+ sagemaker_session = sagemaker_session ,
167
+ )
168
+ xgb_image = get_image_uri (sagemaker_session .boto_region_name , "xgboost" )
169
+ xgb_model = Model (
170
+ model_data = xgb_model_data , image = xgb_image , sagemaker_session = sagemaker_session
171
+ )
172
+ model = PipelineModel (
173
+ models = [sparkml_model , xgb_model ],
174
+ role = "SageMakerRole" ,
175
+ sagemaker_session = sagemaker_session ,
176
+ name = endpoint_name ,
177
+ )
178
+ model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name )
179
+ old_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
180
+ old_config_name = old_endpoint ["EndpointConfigName" ]
181
+
182
+ model .deploy (1 , "ml.m4.xlarge" , update_endpoint = True , endpoint_name = endpoint_name )
183
+ new_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )[
184
+ "ProductionVariants"
185
+ ]
186
+ new_production_variants = new_endpoint ["ProductionVariants" ]
187
+ new_config_name = new_endpoint ["EndpointConfigName" ]
188
+
189
+ assert old_config_name != new_config_name
190
+ assert new_production_variants ["InstanceType" ] == "ml.m4.xlarge"
191
+ assert new_production_variants ["InitialInstanceCount" ] == 1
192
+
193
+ model .delete_model ()
194
+ with pytest .raises (Exception ) as exception :
195
+ sagemaker_session .sagemaker_client .describe_model (ModelName = model .name )
196
+ assert "Could not find model" in str (exception .value )
0 commit comments