Skip to content

Commit e3aa83a

Browse files
committed
add unit and integration tests
1 parent a2dd355 commit e3aa83a

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

tests/integ/test_inference_pipeline.py

+50
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,53 @@ def test_inference_pipeline_model_deploy(sagemaker_session):
148148
with pytest.raises(Exception) as exception:
149149
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
150150
assert "Could not find model" in str(exception.value)
151+
152+
153+
def test_inference_pipeline_model_deploy_with_update_endpoint(sagemaker_session):
154+
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
155+
xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model")
156+
endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp())
157+
sparkml_model_data = sagemaker_session.upload_data(
158+
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
159+
key_prefix="integ-test-data/sparkml/model",
160+
)
161+
xgb_model_data = sagemaker_session.upload_data(
162+
path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"),
163+
key_prefix="integ-test-data/xgboost/model",
164+
)
165+
166+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
167+
sparkml_model = SparkMLModel(
168+
model_data=sparkml_model_data,
169+
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
170+
sagemaker_session=sagemaker_session,
171+
)
172+
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
173+
xgb_model = Model(
174+
model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
175+
)
176+
model = PipelineModel(
177+
models=[sparkml_model, xgb_model],
178+
role="SageMakerRole",
179+
sagemaker_session=sagemaker_session,
180+
name=endpoint_name,
181+
)
182+
model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
183+
old_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
184+
old_config_name = old_endpoint["EndpointConfigName"]
185+
186+
model.deploy(1, "ml.m4.xlarge", update_endpoint=True, endpoint_name=endpoint_name)
187+
new_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)[
188+
"ProductionVariants"
189+
]
190+
new_production_variants = new_endpoint["ProductionVariants"]
191+
new_config_name = new_endpoint["EndpointConfigName"]
192+
193+
assert old_config_name != new_config_name
194+
assert new_production_variants["InstanceType"] == "ml.m4.xlarge"
195+
assert new_production_variants["InitialInstanceCount"] == 1
196+
197+
model.delete_model()
198+
with pytest.raises(Exception) as exception:
199+
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
200+
assert "Could not find model" in str(exception.value)

tests/unit/test_pipeline_model.py

+35
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,41 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
159159
)
160160

161161

162+
@patch("tarfile.open")
163+
@patch("time.strftime", return_value=TIMESTAMP)
164+
def test_deploy_update_endpoint(tfo, time, sagemaker_session):
165+
framework_model = DummyFrameworkModel(sagemaker_session)
166+
endpoint_name = "endpoint-name"
167+
sparkml_model = SparkMLModel(
168+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
169+
)
170+
model = PipelineModel(
171+
models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session
172+
)
173+
model.deploy(
174+
instance_type=INSTANCE_TYPE,
175+
initial_instance_count=1,
176+
endpoint_name=endpoint_name,
177+
update_endpoint=True,
178+
)
179+
180+
sagemaker_session.create_endpoint_config.assert_called_with(
181+
name=model.name,
182+
model_name=model.name,
183+
initial_instance_count=INSTANCE_COUNT,
184+
instance_type=INSTANCE_TYPE,
185+
tags=None,
186+
)
187+
config_name = sagemaker_session.create_endpoint_config(
188+
name=model.name,
189+
model_name=model.name,
190+
initial_instance_count=INSTANCE_COUNT,
191+
instance_type=INSTANCE_TYPE,
192+
)
193+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name)
194+
sagemaker_session.create_endpoint.assert_not_called()
195+
196+
162197
@patch("tarfile.open")
163198
@patch("time.strftime", return_value=TIMESTAMP)
164199
def test_transformer(tfo, time, sagemaker_session):

0 commit comments

Comments
 (0)