Skip to content

Commit b481c36

Browse files
authored
feature: Support for multi variant endpoint invocation with target variant param (#1571)
1 parent 2416254 commit b481c36

File tree

4 files changed

+349
-3
lines changed

4 files changed

+349
-3
lines changed

src/sagemaker/local/local_session.py

+4
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def invoke_endpoint(
343343
Accept=None,
344344
CustomAttributes=None,
345345
TargetModel=None,
346+
TargetVariant=None,
346347
):
347348
"""
348349
@@ -370,6 +371,9 @@ def invoke_endpoint(
370371
if TargetModel is not None:
371372
headers["X-Amzn-SageMaker-Target-Model"] = TargetModel
372373

374+
if TargetVariant is not None:
375+
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant
376+
373377
r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)
374378

375379
return {"Body": r, "ContentType": Accept}

src/sagemaker/predictor.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self._endpoint_config_name = self._get_endpoint_config_name()
8484
self._model_names = self._get_model_names()
8585

86-
def predict(self, data, initial_args=None, target_model=None):
86+
def predict(self, data, initial_args=None, target_model=None, target_variant=None):
8787
"""Return the inference from the specified endpoint.
8888
8989
Args:
@@ -98,6 +98,9 @@ def predict(self, data, initial_args=None, target_model=None):
9898
target_model (str): S3 model artifact path to run an inference request on,
9999
in case of a multi model endpoint. Does not apply to endpoints hosting
100100
single model (Default: None)
101+
target_variant (str): The name of the production variant to run an inference
102+
request on (Default: None). Note that the ProductionVariant identifies the model
103+
you want to host and the resources you want to deploy for hosting it.
101104
102105
Returns:
103106
object: Inference for the given input. If a deserializer was specified when creating
@@ -106,7 +109,7 @@ def predict(self, data, initial_args=None, target_model=None):
106109
as is.
107110
"""
108111

109-
request_args = self._create_request_args(data, initial_args, target_model)
112+
request_args = self._create_request_args(data, initial_args, target_model, target_variant)
110113
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
111114
return self._handle_response(response)
112115

@@ -123,12 +126,13 @@ def _handle_response(self, response):
123126
response_body.close()
124127
return data
125128

126-
def _create_request_args(self, data, initial_args=None, target_model=None):
129+
def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
127130
"""
128131
Args:
129132
data:
130133
initial_args:
131134
target_model:
135+
target_variant:
132136
"""
133137
args = dict(initial_args) if initial_args else {}
134138

@@ -144,6 +148,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None):
144148
if target_model:
145149
args["TargetModel"] = target_model
146150

151+
if target_variant:
152+
args["TargetVariant"] = target_variant
153+
147154
if self.serializer is not None:
148155
data = self.serializer(data)
149156

+309
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
import math
18+
import pytest
19+
import scipy.stats as st
20+
21+
from sagemaker.s3 import S3Uploader
22+
from sagemaker.session import production_variant
23+
from sagemaker.sparkml import SparkMLModel
24+
from sagemaker.utils import sagemaker_timestamp
25+
from sagemaker.content_types import CONTENT_TYPE_CSV
26+
from sagemaker.utils import unique_name_from_base
27+
from sagemaker.amazon.amazon_estimator import get_image_uri
28+
from sagemaker.predictor import csv_serializer, RealTimePredictor
29+
30+
31+
import tests.integ
32+
33+
34+
# ---------------------------------------------------------------------------
# Settings shared by every test in this module.
# ---------------------------------------------------------------------------
ROLE = "SageMakerRole"
MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp())
ENDPOINT_NAME = unique_name_from_base("integ-test-multi-variant-endpoint")
DEFAULT_REGION = "us-west-2"
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_INSTANCE_COUNT = 1
XG_BOOST_MODEL_LOCAL_PATH = os.path.join(
    tests.integ.DATA_DIR, "xgboost_model", "xgb_model.tar.gz"
)

# ---------------------------------------------------------------------------
# The two production variants behind ENDPOINT_NAME and their traffic weights.
# ---------------------------------------------------------------------------
TEST_VARIANT_1 = "Variant1"
TEST_VARIANT_1_WEIGHT = 0.3

TEST_VARIANT_2 = "Variant2"
TEST_VARIANT_2_WEIGHT = 0.7

# Number of untargeted requests sampled by the traffic-distribution test and
# the confidence level used when computing its margin of error.
VARIANT_TRAFFIC_SAMPLING_COUNT = 100
DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION = 0.999

# Payload sent to the XGBoost endpoint.
TEST_CSV_DATA = "42,42,42,42,42,42,42"

# ---------------------------------------------------------------------------
# SparkML single-variant endpoint used by the target-variant predict test.
# ---------------------------------------------------------------------------
SPARK_ML_MODEL_LOCAL_PATH = os.path.join(
    tests.integ.DATA_DIR, "sparkml_model", "mleap_model.tar.gz"
)
SPARK_ML_MODEL_ENDPOINT_NAME = unique_name_from_base("integ-test-target-variant-sparkml")
# Default defined in src/sagemaker/session.py def production_variant.
SPARK_ML_DEFAULT_VARIANT_NAME = "AllTraffic"
SPARK_ML_WRONG_VARIANT_NAME = "WRONG_VARIANT"
SPARK_ML_TEST_DATA = "1.0,C,38.0,71.5,1.0,female"
# JSON schema handed to the SparkML container through its environment.
SPARK_ML_MODEL_SCHEMA = json.dumps(
    {
        "input": [
            {"name": "Pclass", "type": "float"},
            {"name": "Embarked", "type": "string"},
            {"name": "Age", "type": "float"},
            {"name": "Fare", "type": "float"},
            {"name": "SibSp", "type": "float"},
            {"name": "Sex", "type": "string"},
        ],
        "output": {"name": "features", "struct": "vector", "type": "double"},
    }
)
75+
76+
77+
@pytest.fixture(scope="module")
def multi_variant_endpoint(sagemaker_session):
    """Provision the two-variant endpoint used by this module, then tear it down.

    Setup uploads the XGBoost model archive to the session's default bucket,
    creates a model named ``MODEL_NAME`` and creates the endpoint
    ``ENDPOINT_NAME`` with two production variants (``TEST_VARIANT_1`` and
    ``TEST_VARIANT_2``) that both serve that model with different weights.

    Teardown deletes the model and the endpoint config and then verifies that
    both lookups fail.

    Args:
        sagemaker_session: Session fixture used for every AWS call.

    Yields:
        The fixture function object itself (unchanged from the original
        behaviour); tests address the endpoint through ``ENDPOINT_NAME``.
    """
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        endpoint_name=ENDPOINT_NAME, sagemaker_session=sagemaker_session, hours=2
    ):
        # Creating a model
        bucket = sagemaker_session.default_bucket()
        prefix = "sagemaker/DEMO-VariantTargeting"
        model_url = S3Uploader.upload(
            local_path=XG_BOOST_MODEL_LOCAL_PATH,
            desired_s3_uri="s3://" + bucket + "/" + prefix,
            session=sagemaker_session,
        )

        image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")

        multi_variant_endpoint_model = sagemaker_session.create_model(
            name=MODEL_NAME,
            role=ROLE,
            container_defs={"Image": image_uri, "ModelDataUrl": model_url},
        )

        # Creating a multi variant endpoint
        variant1 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_1,
            initial_weight=TEST_VARIANT_1_WEIGHT,
        )
        variant2 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_2,
            initial_weight=TEST_VARIANT_2_WEIGHT,
        )
        sagemaker_session.endpoint_from_production_variants(
            name=ENDPOINT_NAME, production_variants=[variant1, variant2]
        )

        # Yield to run the integration tests
        yield multi_variant_endpoint

    # Cleanup resources
    sagemaker_session.delete_model(multi_variant_endpoint_model)
    sagemaker_session.sagemaker_client.delete_endpoint_config(EndpointConfigName=ENDPOINT_NAME)

    # Validate resource cleanup.
    #
    # Each lookup gets its own ``pytest.raises`` block and the message checks
    # sit *outside* those blocks: statements placed after the raising call
    # inside a ``pytest.raises`` body are never executed, so the previous
    # single-block form verified nothing beyond "some exception was raised".
    #
    # The model is looked up by MODEL_NAME (the name it was created under)
    # rather than through an attribute of ``create_model``'s return value, so
    # the expected failure can only come from the service call itself.
    with pytest.raises(Exception) as model_lookup_error:
        sagemaker_session.sagemaker_client.describe_model(ModelName=MODEL_NAME)
    assert "Could not find model" in str(model_lookup_error.value)

    # ``EndpointConfigName`` mirrors the keyword used by the delete call above.
    with pytest.raises(Exception) as endpoint_config_lookup_error:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=ENDPOINT_NAME
        )
    assert "Could not find endpoint" in str(endpoint_config_lookup_error.value)
139+
140+
141+
def test_target_variant_invocation(sagemaker_session, multi_variant_endpoint):
    """Target each production variant in turn and check that it served the request."""
    for requested_variant in (TEST_VARIANT_1, TEST_VARIANT_2):
        invocation_result = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
            TargetVariant=requested_variant,
        )
        # The response names the variant that actually handled the request.
        assert invocation_result["InvokedProductionVariant"] == requested_variant
160+
161+
162+
def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant_endpoint):
    """Call ``RealTimePredictor.predict`` with an explicit ``target_variant`` for both variants."""
    predictor_settings = dict(
        endpoint=ENDPOINT_NAME,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )
    csv_predictor = RealTimePredictor(**predictor_settings)

    # The test passes as long as neither call raises; responses are not inspected.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        csv_predictor.predict(TEST_CSV_DATA, target_variant=variant_name)
174+
175+
176+
def test_variant_traffic_distribution(sagemaker_session, multi_variant_endpoint):
    """Check that untargeted traffic is split between the variants according to their weights.

    Sends ``VARIANT_TRAFFIC_SAMPLING_COUNT`` requests without a target variant,
    tallies which variant served each one, and asserts that every variant's
    observed share lies strictly within the computed margin of error around
    its configured weight.
    """
    served_counts = {TEST_VARIANT_1: 0, TEST_VARIANT_2: 0}

    for _ in range(VARIANT_TRAFFIC_SAMPLING_COUNT):
        invocation_result = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
        )
        serving_variant = invocation_result["InvokedProductionVariant"]
        # Responses from any other variant are left uncounted so that the
        # total check below fails.
        if serving_variant in served_counts:
            served_counts[serving_variant] += 1

    total_served = served_counts[TEST_VARIANT_1] + served_counts[TEST_VARIANT_2]
    assert total_served == VARIANT_TRAFFIC_SAMPLING_COUNT

    expected_weights = (
        (TEST_VARIANT_1, TEST_VARIANT_1_WEIGHT),
        (TEST_VARIANT_2, TEST_VARIANT_2_WEIGHT),
    )
    for variant_name, configured_weight in expected_weights:
        observed_share = float(served_counts[variant_name]) / float(
            VARIANT_TRAFFIC_SAMPLING_COUNT
        )
        tolerance = _compute_and_retrieve_margin_of_error(configured_weight)
        assert observed_share < configured_weight + tolerance
        assert observed_share > configured_weight - tolerance
207+
208+
209+
def test_spark_ml_predict_invocation_with_target_variant(sagemaker_session):
    """Deploy a SparkML model and exercise ``predict`` with valid and invalid target variants.

    The model is deployed to ``SPARK_ML_MODEL_ENDPOINT_NAME``. Predicting with
    ``SPARK_ML_DEFAULT_VARIANT_NAME`` must succeed, while predicting with
    ``SPARK_ML_WRONG_VARIANT_NAME`` must raise an error that mentions both
    "ValidationError" and the wrong variant name. Afterwards the model and the
    endpoint config are deleted and their removal is verified.

    Args:
        sagemaker_session: Session fixture used for every AWS call.
    """
    model_data = sagemaker_session.upload_data(
        path=SPARK_ML_MODEL_LOCAL_PATH, key_prefix="integ-test-data/sparkml/model"
    )

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        SPARK_ML_MODEL_ENDPOINT_NAME, sagemaker_session
    ):
        spark_ml_model = SparkMLModel(
            model_data=model_data,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            env={"SAGEMAKER_SPARKML_SCHEMA": SPARK_ML_MODEL_SCHEMA},
        )

        predictor = spark_ml_model.deploy(
            DEFAULT_INSTANCE_COUNT,
            DEFAULT_INSTANCE_TYPE,
            endpoint_name=SPARK_ML_MODEL_ENDPOINT_NAME,
        )

        # Validate that no exception is raised when the target_variant is specified.
        predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_DEFAULT_VARIANT_NAME)

        # Only the call expected to fail lives inside ``pytest.raises``; the
        # message checks follow the block so that they are actually executed.
        with pytest.raises(Exception) as exception_info:
            predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_WRONG_VARIANT_NAME)

        assert "ValidationError" in str(exception_info.value)
        assert SPARK_ML_WRONG_VARIANT_NAME in str(exception_info.value)

    # Cleanup resources
    spark_ml_model.delete_model()
    sagemaker_session.sagemaker_client.delete_endpoint_config(
        EndpointConfigName=SPARK_ML_MODEL_ENDPOINT_NAME
    )

    # Validate resource cleanup.
    #
    # Each lookup gets its own ``pytest.raises`` block: with both lookups and
    # both asserts in a single block, everything after the first raising call
    # was dead code and the endpoint config was never checked.
    # Resolve the name up front so an attribute problem cannot be mistaken for
    # the expected "not found" error.
    deleted_model_name = spark_ml_model.name
    with pytest.raises(Exception) as model_lookup_error:
        sagemaker_session.sagemaker_client.describe_model(ModelName=deleted_model_name)
    assert "Could not find model" in str(model_lookup_error.value)

    # ``EndpointConfigName`` mirrors the keyword used by the delete call above.
    with pytest.raises(Exception) as endpoint_config_lookup_error:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=SPARK_ML_MODEL_ENDPOINT_NAME
        )
    assert "Could not find endpoint" in str(endpoint_config_lookup_error.value)
253+
254+
255+
@pytest.mark.local_mode
def test_target_variant_invocation_local_mode(sagemaker_session, multi_variant_endpoint):
    """Local-mode-marked variant of the targeted invocation test.

    Falls back to ``DEFAULT_REGION`` when the session has no region set, then
    targets each production variant and checks that it served the request.
    """
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    for requested_variant in (TEST_VARIANT_1, TEST_VARIANT_2):
        invocation_result = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
            TargetVariant=requested_variant,
        )
        assert invocation_result["InvokedProductionVariant"] == requested_variant
278+
279+
280+
@pytest.mark.local_mode
def test_predict_invocation_with_target_variant_local_mode(
    sagemaker_session, multi_variant_endpoint
):
    """Local-mode-marked variant of the ``predict`` target-variant smoke test.

    Falls back to ``DEFAULT_REGION`` when the session has no region set, then
    calls ``predict`` once per production variant.
    """
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    predictor_settings = dict(
        endpoint=ENDPOINT_NAME,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )
    csv_predictor = RealTimePredictor(**predictor_settings)

    # The test passes as long as neither call raises; responses are not inspected.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        csv_predictor.predict(TEST_CSV_DATA, target_variant=variant_name)
299+
300+
301+
def _compute_and_retrieve_margin_of_error(variant_weight, sample_count=None, confidence=None):
    """Return the two-sided Wald margin of error for a binomial proportion.

    The margin is ``z * sqrt(p * (1 - p) / n)`` where ``p`` is the expected
    proportion, ``n`` the number of samples and ``z`` the standard-normal
    critical value for the requested confidence level.

    Args:
        variant_weight (float): Expected proportion ``p`` of traffic for the variant.
        sample_count (int): Number of samples ``n``. Defaults to
            ``VARIANT_TRAFFIC_SAMPLING_COUNT``.
        confidence (float): Two-sided confidence level, e.g. ``0.999``. Defaults to
            ``DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION``.

    Returns:
        float: Half-width of the confidence interval around ``variant_weight``.
    """
    if sample_count is None:
        sample_count = VARIANT_TRAFFIC_SAMPLING_COUNT
    if confidence is None:
        confidence = DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION

    # Callers use the result as a symmetric +/- bound, so the tail probability
    # (1 - confidence) has to be split across both tails. ``ppf(confidence)``
    # would be the one-sided quantile and would make the interval too narrow
    # for the stated confidence level.
    upper_tail_quantile = 1.0 - (1.0 - confidence) / 2.0
    z_value = st.norm.ppf(upper_tail_quantile)

    standard_error = math.sqrt((variant_weight * (1 - variant_weight)) / sample_count)
    return z_value * standard_error

0 commit comments

Comments
 (0)