@@ -148,3 +148,53 @@ def test_inference_pipeline_model_deploy(sagemaker_session):
148
148
with pytest .raises (Exception ) as exception :
149
149
sagemaker_session .sagemaker_client .describe_model (ModelName = model .name )
150
150
assert "Could not find model" in str (exception .value )
151
+
152
+
153
+ def test_inference_pipeline_model_deploy_with_update_endpoint (sagemaker_session ):
154
+ sparkml_data_path = os .path .join (DATA_DIR , "sparkml_model" )
155
+ xgboost_data_path = os .path .join (DATA_DIR , "xgboost_model" )
156
+ endpoint_name = "test-inference-pipeline-deploy-{}" .format (sagemaker_timestamp ())
157
+ sparkml_model_data = sagemaker_session .upload_data (
158
+ path = os .path .join (sparkml_data_path , "mleap_model.tar.gz" ),
159
+ key_prefix = "integ-test-data/sparkml/model" ,
160
+ )
161
+ xgb_model_data = sagemaker_session .upload_data (
162
+ path = os .path .join (xgboost_data_path , "xgb_model.tar.gz" ),
163
+ key_prefix = "integ-test-data/xgboost/model" ,
164
+ )
165
+
166
+ with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
167
+ sparkml_model = SparkMLModel (
168
+ model_data = sparkml_model_data ,
169
+ env = {"SAGEMAKER_SPARKML_SCHEMA" : SCHEMA },
170
+ sagemaker_session = sagemaker_session ,
171
+ )
172
+ xgb_image = get_image_uri (sagemaker_session .boto_region_name , "xgboost" )
173
+ xgb_model = Model (
174
+ model_data = xgb_model_data , image = xgb_image , sagemaker_session = sagemaker_session
175
+ )
176
+ model = PipelineModel (
177
+ models = [sparkml_model , xgb_model ],
178
+ role = "SageMakerRole" ,
179
+ sagemaker_session = sagemaker_session ,
180
+ name = endpoint_name ,
181
+ )
182
+ model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name )
183
+ old_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
184
+ old_config_name = old_endpoint ["EndpointConfigName" ]
185
+
186
+ model .deploy (1 , "ml.m4.xlarge" , update_endpoint = True , endpoint_name = endpoint_name )
187
+ new_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )[
188
+ "ProductionVariants"
189
+ ]
190
+ new_production_variants = new_endpoint ["ProductionVariants" ]
191
+ new_config_name = new_endpoint ["EndpointConfigName" ]
192
+
193
+ assert old_config_name != new_config_name
194
+ assert new_production_variants ["InstanceType" ] == "ml.m4.xlarge"
195
+ assert new_production_variants ["InitialInstanceCount" ] == 1
196
+
197
+ model .delete_model ()
198
+ with pytest .raises (Exception ) as exception :
199
+ sagemaker_session .sagemaker_client .describe_model (ModelName = model .name )
200
+ assert "Could not find model" in str (exception .value )
0 commit comments