Skip to content

Commit 69d06ad

Browse files
feature: Support for multi variant endpoint invocation with target variant param (#1577)
Co-authored-by: Chuyang <[email protected]>
1 parent 6a8bb6d commit 69d06ad

File tree

5 files changed

+359
-4
lines changed

5 files changed

+359
-4
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read_version():
3434

3535
# Declare minimal set for installation
3636
required_packages = [
37-
"boto3>=1.13.6",
37+
"boto3>=1.13.24",
3838
"numpy>=1.9.0",
3939
"protobuf>=3.1",
4040
"scipy>=0.19.0",

src/sagemaker/local/local_session.py

+4
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def invoke_endpoint(
343343
Accept=None,
344344
CustomAttributes=None,
345345
TargetModel=None,
346+
TargetVariant=None,
346347
):
347348
"""
348349
@@ -370,6 +371,9 @@ def invoke_endpoint(
370371
if TargetModel is not None:
371372
headers["X-Amzn-SageMaker-Target-Model"] = TargetModel
372373

374+
if TargetVariant is not None:
375+
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant
376+
373377
r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)
374378

375379
return {"Body": r, "ContentType": Accept}

src/sagemaker/predictor.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self._endpoint_config_name = self._get_endpoint_config_name()
8484
self._model_names = self._get_model_names()
8585

86-
def predict(self, data, initial_args=None, target_model=None):
86+
def predict(self, data, initial_args=None, target_model=None, target_variant=None):
8787
"""Return the inference from the specified endpoint.
8888
8989
Args:
@@ -98,6 +98,9 @@ def predict(self, data, initial_args=None, target_model=None):
9898
target_model (str): S3 model artifact path to run an inference request on,
9999
in case of a multi model endpoint. Does not apply to endpoints hosting
100100
single model (Default: None)
101+
target_variant (str): The name of the production variant to run an inference
102+
request on (Default: None). Note that the ProductionVariant identifies the model
103+
you want to host and the resources you want to deploy for hosting it.
101104
102105
Returns:
103106
object: Inference for the given input. If a deserializer was specified when creating
@@ -106,7 +109,7 @@ def predict(self, data, initial_args=None, target_model=None):
106109
as is.
107110
"""
108111

109-
request_args = self._create_request_args(data, initial_args, target_model)
112+
request_args = self._create_request_args(data, initial_args, target_model, target_variant)
110113
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
111114
return self._handle_response(response)
112115

@@ -123,12 +126,13 @@ def _handle_response(self, response):
123126
response_body.close()
124127
return data
125128

126-
def _create_request_args(self, data, initial_args=None, target_model=None):
129+
def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
127130
"""
128131
Args:
129132
data:
130133
initial_args:
131134
target_model:
135+
target_variant:
132136
"""
133137
args = dict(initial_args) if initial_args else {}
134138

@@ -144,6 +148,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None):
144148
if target_model:
145149
args["TargetModel"] = target_model
146150

151+
if target_variant:
152+
args["TargetVariant"] = target_variant
153+
147154
if self.serializer is not None:
148155
data = self.serializer(data)
149156

+318
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
import math
18+
import pytest
19+
import scipy.stats as st
20+
21+
from sagemaker.s3 import S3Uploader
22+
from sagemaker.session import production_variant
23+
from sagemaker.sparkml import SparkMLModel
24+
from sagemaker.utils import sagemaker_timestamp
25+
from sagemaker.content_types import CONTENT_TYPE_CSV
26+
from sagemaker.utils import unique_name_from_base
27+
from sagemaker.amazon.amazon_estimator import get_image_uri
28+
from sagemaker.predictor import csv_serializer, RealTimePredictor
29+
30+
31+
import tests.integ
32+
33+
34+
# IAM role name passed to model creation in this module.
ROLE = "SageMakerRole"
# Name under which the shared XGBoost model is registered; the timestamp is
# evaluated once, when this module is imported.
MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp())
# Region assigned to the session by the local-mode tests when it has none.
DEFAULT_REGION = "us-west-2"
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_INSTANCE_COUNT = 1
XG_BOOST_MODEL_LOCAL_PATH = os.path.join(tests.integ.DATA_DIR, "xgboost_model", "xgb_model.tar.gz")

# Names and initial weights of the two production variants behind the shared
# endpoint. The traffic distribution test treats each weight as the expected
# fraction of untargeted invocations served by that variant.
TEST_VARIANT_1 = "Variant1"
TEST_VARIANT_1_WEIGHT = 0.3

TEST_VARIANT_2 = "Variant2"
TEST_VARIANT_2_WEIGHT = 0.7

# Number of untargeted invocations sampled by the traffic distribution test.
VARIANT_TRAFFIC_SAMPLING_COUNT = 100
# Confidence level used when computing the margin of error for that test.
DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION = 0.999

# CSV payload sent to the XGBoost endpoint.
TEST_CSV_DATA = "42,42,42,42,42,42,42"

SPARK_ML_MODEL_LOCAL_PATH = os.path.join(
    tests.integ.DATA_DIR, "sparkml_model", "mleap_model.tar.gz"
)
SPARK_ML_DEFAULT_VARIANT_NAME = (
    "AllTraffic"
)  # default defined in src/sagemaker/session.py def production_variant
# Variant name the SparkML test expects the endpoint to reject.
SPARK_ML_WRONG_VARIANT_NAME = "WRONG_VARIANT"
# One CSV row whose columns follow the "input" order of SPARK_ML_MODEL_SCHEMA.
SPARK_ML_TEST_DATA = "1.0,C,38.0,71.5,1.0,female"
# JSON schema handed to the SparkML container through the
# SAGEMAKER_SPARKML_SCHEMA environment variable.
SPARK_ML_MODEL_SCHEMA = json.dumps(
    {
        "input": [
            {"name": "Pclass", "type": "float"},
            {"name": "Embarked", "type": "string"},
            {"name": "Age", "type": "float"},
            {"name": "Fare", "type": "float"},
            {"name": "SibSp", "type": "float"},
            {"name": "Sex", "type": "string"},
        ],
        "output": {"name": "features", "struct": "vector", "type": "double"},
    }
)
73+
74+
75+
@pytest.fixture(scope="module")
def multi_variant_endpoint(sagemaker_session):
    """
    Sets up the multi variant endpoint before the integration tests run.
    Cleans up the multi variant endpoint after the integration tests run.

    The XGBoost model artifact is uploaded to the session's default bucket,
    registered as ``MODEL_NAME`` and served by an endpoint with two production
    variants (``TEST_VARIANT_1`` and ``TEST_VARIANT_2``). The endpoint name is
    published as the ``endpoint_name`` attribute of the yielded object.

    Args:
        sagemaker_session: Session used for every SageMaker and S3 call.

    Yields:
        The fixture object, carrying the ``endpoint_name`` attribute.
    """
    multi_variant_endpoint.endpoint_name = unique_name_from_base(
        "integ-test-multi-variant-endpoint"
    )
    # NOTE(review): the context manager presumably deletes the endpoint on
    # exit, as its name suggests; the model and endpoint config are removed
    # explicitly below.
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        endpoint_name=multi_variant_endpoint.endpoint_name,
        sagemaker_session=sagemaker_session,
        hours=2,
    ):

        # Creating a model
        bucket = sagemaker_session.default_bucket()
        prefix = "sagemaker/DEMO-VariantTargeting"
        model_url = S3Uploader.upload(
            local_path=XG_BOOST_MODEL_LOCAL_PATH,
            desired_s3_uri="s3://" + bucket + "/" + prefix,
            session=sagemaker_session,
        )

        image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")

        multi_variant_endpoint_model = sagemaker_session.create_model(
            name=MODEL_NAME,
            role=ROLE,
            container_defs={"Image": image_uri, "ModelDataUrl": model_url},
        )

        # Creating a multi variant endpoint
        variant1 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_1,
            initial_weight=TEST_VARIANT_1_WEIGHT,
        )
        variant2 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_2,
            initial_weight=TEST_VARIANT_2_WEIGHT,
        )
        sagemaker_session.endpoint_from_production_variants(
            name=multi_variant_endpoint.endpoint_name, production_variants=[variant1, variant2]
        )

        # Yield to run the integration tests
        yield multi_variant_endpoint

        # Cleanup resources
        sagemaker_session.delete_model(multi_variant_endpoint_model)
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=multi_variant_endpoint.endpoint_name
        )

        # Validate resource cleanup. Each lookup gets its own ``pytest.raises``
        # block and is checked after the block exits: statements placed after
        # the raising call inside the block would never execute.
        with pytest.raises(Exception) as model_lookup_error:
            # Look the model up by the name it was created under.
            sagemaker_session.sagemaker_client.describe_model(ModelName=MODEL_NAME)
        assert "Could not find model" in str(model_lookup_error.value)

        with pytest.raises(Exception) as config_lookup_error:
            # Same keyword as the delete_endpoint_config call above.
            sagemaker_session.sagemaker_client.describe_endpoint_config(
                EndpointConfigName=multi_variant_endpoint.endpoint_name
            )
        assert "Could not find endpoint" in str(config_lookup_error.value)
145+
146+
147+
def test_target_variant_invocation(sagemaker_session, multi_variant_endpoint):
    """
    Invokes the endpoint once per variant through the runtime client and
    checks that the response reports the targeted variant as the one invoked.
    """
    shared_request = {
        "EndpointName": multi_variant_endpoint.endpoint_name,
        "Body": TEST_CSV_DATA,
        "ContentType": CONTENT_TYPE_CSV,
        "Accept": CONTENT_TYPE_CSV,
    }

    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        result = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            TargetVariant=variant_name, **shared_request
        )
        assert result["InvokedProductionVariant"] == variant_name
166+
167+
168+
def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant_endpoint):
    """
    Calls ``RealTimePredictor.predict`` with ``target_variant`` set to each
    variant in turn; the test passes when neither call raises.
    """
    endpoint_predictor = RealTimePredictor(
        endpoint=multi_variant_endpoint.endpoint_name,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )

    # No assertion on the result: a raised exception is the failure signal.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        endpoint_predictor.predict(TEST_CSV_DATA, target_variant=variant_name)
180+
181+
182+
def test_variant_traffic_distribution(sagemaker_session, multi_variant_endpoint):
    """
    Sends ``VARIANT_TRAFFIC_SAMPLING_COUNT`` untargeted invocations and checks
    that the share served by each variant lies strictly within the margin of
    error around that variant's configured weight.
    """
    expected_weights = (
        (TEST_VARIANT_1, TEST_VARIANT_1_WEIGHT),
        (TEST_VARIANT_2, TEST_VARIANT_2_WEIGHT),
    )
    hits = {TEST_VARIANT_1: 0, TEST_VARIANT_2: 0}

    for _ in range(VARIANT_TRAFFIC_SAMPLING_COUNT):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=multi_variant_endpoint.endpoint_name,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
        )
        served_by = response["InvokedProductionVariant"]
        if served_by in hits:
            hits[served_by] += 1

    # Every sampled invocation must have been served by one of the two variants.
    assert sum(hits.values()) == VARIANT_TRAFFIC_SAMPLING_COUNT

    for variant_name, weight in expected_weights:
        observed_share = float(hits[variant_name]) / float(VARIANT_TRAFFIC_SAMPLING_COUNT)
        tolerance = _compute_and_retrieve_margin_of_error(weight)
        assert observed_share < weight + tolerance
        assert observed_share > weight - tolerance
213+
214+
215+
def test_spark_ml_predict_invocation_with_target_variant(sagemaker_session):
    """
    Deploys a SparkML model and exercises ``predict`` with ``target_variant``.

    Predicting against ``SPARK_ML_DEFAULT_VARIANT_NAME`` must succeed, while
    predicting against ``SPARK_ML_WRONG_VARIANT_NAME`` must raise an error that
    mentions both "ValidationError" and the wrong variant name. Afterwards the
    model and endpoint config are deleted and their removal is verified.
    """

    spark_ml_model_endpoint_name = unique_name_from_base("integ-test-target-variant-sparkml")

    model_data = sagemaker_session.upload_data(
        path=SPARK_ML_MODEL_LOCAL_PATH, key_prefix="integ-test-data/sparkml/model"
    )

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        spark_ml_model_endpoint_name, sagemaker_session
    ):
        spark_ml_model = SparkMLModel(
            model_data=model_data,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            env={"SAGEMAKER_SPARKML_SCHEMA": SPARK_ML_MODEL_SCHEMA},
        )

        predictor = spark_ml_model.deploy(
            DEFAULT_INSTANCE_COUNT,
            DEFAULT_INSTANCE_TYPE,
            endpoint_name=spark_ml_model_endpoint_name,
        )

        # Validate that no exception is raised when the target_variant is specified.
        predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_DEFAULT_VARIANT_NAME)

        with pytest.raises(Exception) as exception_info:
            predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_WRONG_VARIANT_NAME)

        assert "ValidationError" in str(exception_info.value)
        assert SPARK_ML_WRONG_VARIANT_NAME in str(exception_info.value)

        # cleanup resources
        # Capture the name first so the lookup below does not depend on the
        # model object's state after deletion.
        deleted_model_name = spark_ml_model.name
        spark_ml_model.delete_model()
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=spark_ml_model_endpoint_name
        )

        # Validate resource cleanup. Each lookup gets its own ``pytest.raises``
        # block and is checked after the block exits: statements placed after
        # the raising call inside the block would never execute.
        with pytest.raises(Exception) as model_lookup_error:
            sagemaker_session.sagemaker_client.describe_model(ModelName=deleted_model_name)
        assert "Could not find model" in str(model_lookup_error.value)

        with pytest.raises(Exception) as config_lookup_error:
            # Same keyword as the delete_endpoint_config call above.
            sagemaker_session.sagemaker_client.describe_endpoint_config(
                EndpointConfigName=spark_ml_model_endpoint_name
            )
        assert "Could not find endpoint" in str(config_lookup_error.value)
262+
263+
264+
@pytest.mark.local_mode
def test_target_variant_invocation_local_mode(sagemaker_session, multi_variant_endpoint):
    """
    Local-mode-marked counterpart of the targeted invocation test: invokes the
    endpoint once per variant and checks the reported invoked variant.
    """
    # Fall back to the default region when the session has none configured.
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    shared_request = {
        "EndpointName": multi_variant_endpoint.endpoint_name,
        "Body": TEST_CSV_DATA,
        "ContentType": CONTENT_TYPE_CSV,
        "Accept": CONTENT_TYPE_CSV,
    }

    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        result = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            TargetVariant=variant_name, **shared_request
        )
        assert result["InvokedProductionVariant"] == variant_name
287+
288+
289+
@pytest.mark.local_mode
def test_predict_invocation_with_target_variant_local_mode(
    sagemaker_session, multi_variant_endpoint
):
    """
    Local-mode-marked counterpart of the predictor test: calls ``predict`` with
    ``target_variant`` set to each variant; passes when neither call raises.
    """
    # Fall back to the default region when the session has none configured.
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    endpoint_predictor = RealTimePredictor(
        endpoint=multi_variant_endpoint.endpoint_name,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )

    # No assertion on the result: a raised exception is the failure signal.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        endpoint_predictor.predict(TEST_CSV_DATA, target_variant=variant_name)
308+
309+
310+
def _compute_and_retrieve_margin_of_error(variant_weight, sample_count=None, confidence=None):
    """
    Computes the margin of error using the Wald method for computing the confidence
    intervals of a binomial distribution.

    The margin is ``z * sqrt(p * (1 - p) / n)`` with ``p = variant_weight`` and
    ``n = sample_count``. ``z`` is the two-sided normal critical value for
    ``confidence``, because callers compare the observed share against both
    ``p + margin`` and ``p - margin``.

    Args:
        variant_weight (float): Expected proportion of traffic for the variant.
        sample_count (int): Number of sampled invocations
            (Default: ``VARIANT_TRAFFIC_SAMPLING_COUNT``).
        confidence (float): Two-sided confidence level strictly between 0 and 1
            (Default: ``DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION``).

    Returns:
        float: Half-width of the confidence interval around ``variant_weight``.
    """
    if sample_count is None:
        sample_count = VARIANT_TRAFFIC_SAMPLING_COUNT
    if confidence is None:
        confidence = DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION

    # Split the excluded probability evenly between both tails; using
    # ppf(confidence) directly would give a one-sided bound and understate
    # the margin for a two-sided check.
    z_value = st.norm.ppf(1 - (1 - confidence) / 2.0)
    standard_error = math.sqrt((variant_weight * (1 - variant_weight)) / float(sample_count))
    return z_value * standard_error

0 commit comments

Comments
 (0)