|
14 | 14 |
|
15 | 15 | import json
|
16 | 16 | import os
|
| 17 | +import time |
17 | 18 |
|
18 | 19 | import pytest
|
19 | 20 | from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
|
@@ -148,3 +149,68 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
|
148 | 149 | with pytest.raises(Exception) as exception:
|
149 | 150 | sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
|
150 | 151 | assert "Could not find model" in str(exception.value)
|
| 152 | + |
| 153 | + |
| 154 | +def test_inference_pipeline_model_deploy_with_update_endpoint(sagemaker_session): |
| 155 | + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") |
| 156 | + xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") |
| 157 | + endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp()) |
| 158 | + sparkml_model_data = sagemaker_session.upload_data( |
| 159 | + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), |
| 160 | + key_prefix="integ-test-data/sparkml/model", |
| 161 | + ) |
| 162 | + xgb_model_data = sagemaker_session.upload_data( |
| 163 | + path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"), |
| 164 | + key_prefix="integ-test-data/xgboost/model", |
| 165 | + ) |
| 166 | + |
| 167 | + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): |
| 168 | + sparkml_model = SparkMLModel( |
| 169 | + model_data=sparkml_model_data, |
| 170 | + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, |
| 171 | + sagemaker_session=sagemaker_session, |
| 172 | + ) |
| 173 | + xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost") |
| 174 | + xgb_model = Model( |
| 175 | + model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session |
| 176 | + ) |
| 177 | + model = PipelineModel( |
| 178 | + models=[sparkml_model, xgb_model], |
| 179 | + role="SageMakerRole", |
| 180 | + sagemaker_session=sagemaker_session, |
| 181 | + ) |
| 182 | + model.deploy(1, "ml.t2.medium", endpoint_name=endpoint_name) |
| 183 | + old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint( |
| 184 | + EndpointName=endpoint_name |
| 185 | + ) |
| 186 | + old_config_name = old_endpoint["EndpointConfigName"] |
| 187 | + |
| 188 | + model.deploy(1, "ml.m4.xlarge", update_endpoint=True, endpoint_name=endpoint_name) |
| 189 | + |
| 190 | + # Wait for endpoint to finish updating |
| 191 | + max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout |
| 192 | + current_retry_count = 0 |
| 193 | + while current_retry_count <= max_retry_count: |
| 194 | + if current_retry_count >= max_retry_count: |
| 195 | + raise Exception("Endpoint status not 'InService' within expected timeout.") |
| 196 | + time.sleep(30) |
| 197 | + new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint( |
| 198 | + EndpointName=endpoint_name |
| 199 | + ) |
| 200 | + current_retry_count += 1 |
| 201 | + if new_endpoint["EndpointStatus"] == "InService": |
| 202 | + break |
| 203 | + |
| 204 | + new_config_name = new_endpoint["EndpointConfigName"] |
| 205 | + new_config = sagemaker_session.sagemaker_client.describe_endpoint_config( |
| 206 | + EndpointConfigName=new_config_name |
| 207 | + ) |
| 208 | + |
| 209 | + assert old_config_name != new_config_name |
| 210 | + assert new_config["ProductionVariants"][0]["InstanceType"] == "ml.m4.xlarge" |
| 211 | + assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1 |
| 212 | + |
| 213 | + model.delete_model() |
| 214 | + with pytest.raises(Exception) as exception: |
| 215 | + sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) |
| 216 | + assert "Could not find model" in str(exception.value) |
0 commit comments