Skip to content

feature: Support for multi-variant endpoint invocation with target-variant param #1571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def invoke_endpoint(
Accept=None,
CustomAttributes=None,
TargetModel=None,
TargetVariant=None,
):
"""

Expand Down Expand Up @@ -370,6 +371,9 @@ def invoke_endpoint(
if TargetModel is not None:
headers["X-Amzn-SageMaker-Target-Model"] = TargetModel

if TargetVariant is not None:
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant

r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)

return {"Body": r, "ContentType": Accept}
Expand Down
13 changes: 10 additions & 3 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self._endpoint_config_name = self._get_endpoint_config_name()
self._model_names = self._get_model_names()

def predict(self, data, initial_args=None, target_model=None):
def predict(self, data, initial_args=None, target_model=None, target_variant=None):
"""Return the inference from the specified endpoint.

Args:
Expand All @@ -98,6 +98,9 @@ def predict(self, data, initial_args=None, target_model=None):
target_model (str): S3 model artifact path to run an inference request on,
in case of a multi model endpoint. Does not apply to endpoints hosting
single model (Default: None)
target_variant (str): The name of the production variant to run an inference
request on (Default: None). Note that the ProductionVariant identifies the model
you want to host and the resources you want to deploy for hosting it.

Returns:
object: Inference for the given input. If a deserializer was specified when creating
Expand All @@ -106,7 +109,7 @@ def predict(self, data, initial_args=None, target_model=None):
as is.
"""

request_args = self._create_request_args(data, initial_args, target_model)
request_args = self._create_request_args(data, initial_args, target_model, target_variant)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

Expand All @@ -123,12 +126,13 @@ def _handle_response(self, response):
response_body.close()
return data

def _create_request_args(self, data, initial_args=None, target_model=None):
def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
"""
Args:
data:
initial_args:
target_model:
target_variant:
"""
args = dict(initial_args) if initial_args else {}

Expand All @@ -144,6 +148,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None):
if target_model:
args["TargetModel"] = target_model

if target_variant:
args["TargetVariant"] = target_variant

if self.serializer is not None:
data = self.serializer(data)

Expand Down
309 changes: 309 additions & 0 deletions tests/integ/test_multi_variant_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import math
import pytest
import scipy.stats as st

from sagemaker.s3 import S3Uploader
from sagemaker.session import production_variant
from sagemaker.sparkml import SparkMLModel
from sagemaker.utils import sagemaker_timestamp
from sagemaker.content_types import CONTENT_TYPE_CSV
from sagemaker.utils import unique_name_from_base
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer, RealTimePredictor


import tests.integ


# IAM role assumed by SageMaker for model creation/hosting in these tests.
ROLE = "SageMakerRole"
# Timestamped so concurrent test runs don't collide on the model name.
MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp())
ENDPOINT_NAME = unique_name_from_base("integ-test-multi-variant-endpoint")
DEFAULT_REGION = "us-west-2"
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_INSTANCE_COUNT = 1
# Pre-trained XGBoost artifact shipped with the integ-test data directory.
XG_BOOST_MODEL_LOCAL_PATH = os.path.join(tests.integ.DATA_DIR, "xgboost_model", "xgb_model.tar.gz")

# Two production variants sharing one model; weights must sum to 1.0 so the
# traffic-distribution test can compare observed shares against them directly.
TEST_VARIANT_1 = "Variant1"
TEST_VARIANT_1_WEIGHT = 0.3

TEST_VARIANT_2 = "Variant2"
TEST_VARIANT_2_WEIGHT = 0.7

# Number of unrouted invocations sampled, and the confidence level used to
# derive the acceptable margin of error for the observed traffic split.
VARIANT_TRAFFIC_SAMPLING_COUNT = 100
DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION = 0.999

# Seven comma-separated features, matching the XGBoost model's expected input.
TEST_CSV_DATA = "42,42,42,42,42,42,42"

SPARK_ML_MODEL_LOCAL_PATH = os.path.join(
    tests.integ.DATA_DIR, "sparkml_model", "mleap_model.tar.gz"
)
SPARK_ML_MODEL_ENDPOINT_NAME = unique_name_from_base("integ-test-target-variant-sparkml")
SPARK_ML_DEFAULT_VARIANT_NAME = (
    "AllTraffic"
)  # default defined in src/sagemaker/session.py def production_variant
# Deliberately nonexistent variant name; the service should reject it.
SPARK_ML_WRONG_VARIANT_NAME = "WRONG_VARIANT"
SPARK_ML_TEST_DATA = "1.0,C,38.0,71.5,1.0,female"
# Input/output schema consumed by the SparkML serving container via env var.
SPARK_ML_MODEL_SCHEMA = json.dumps(
    {
        "input": [
            {"name": "Pclass", "type": "float"},
            {"name": "Embarked", "type": "string"},
            {"name": "Age", "type": "float"},
            {"name": "Fare", "type": "float"},
            {"name": "SibSp", "type": "float"},
            {"name": "Sex", "type": "string"},
        ],
        "output": {"name": "features", "struct": "vector", "type": "double"},
    }
)


@pytest.fixture(scope="module")
def multi_variant_endpoint(sagemaker_session):
    """Module-scoped fixture: create, yield, and tear down the multi-variant endpoint.

    Sets up one XGBoost model hosted behind two weighted production variants,
    yields for the duration of the module's tests, then deletes the model and
    endpoint config and verifies both are actually gone.
    """

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        endpoint_name=ENDPOINT_NAME, sagemaker_session=sagemaker_session, hours=2
    ):

        # Upload the pre-trained model artifact so both variants can host it.
        bucket = sagemaker_session.default_bucket()
        prefix = "sagemaker/DEMO-VariantTargeting"
        model_url = S3Uploader.upload(
            local_path=XG_BOOST_MODEL_LOCAL_PATH,
            desired_s3_uri="s3://" + bucket + "/" + prefix,
            session=sagemaker_session,
        )

        image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")

        # NOTE: Session.create_model returns the model *name* (a string), not a
        # model object — the original fixture later accessed `.name` on it, which
        # raised AttributeError and broke the cleanup validation.
        model_name = sagemaker_session.create_model(
            name=MODEL_NAME,
            role=ROLE,
            container_defs={"Image": image_uri, "ModelDataUrl": model_url},
        )

        # Create a multi-variant endpoint: two variants, same model, 30/70 split.
        variant1 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_1,
            initial_weight=TEST_VARIANT_1_WEIGHT,
        )
        variant2 = production_variant(
            model_name=MODEL_NAME,
            instance_type=DEFAULT_INSTANCE_TYPE,
            initial_instance_count=DEFAULT_INSTANCE_COUNT,
            variant_name=TEST_VARIANT_2,
            initial_weight=TEST_VARIANT_2_WEIGHT,
        )
        sagemaker_session.endpoint_from_production_variants(
            name=ENDPOINT_NAME, production_variants=[variant1, variant2]
        )

        # Yield to run the integration tests
        yield multi_variant_endpoint

        # Cleanup resources
        sagemaker_session.delete_model(model_name)
        sagemaker_session.sagemaker_client.delete_endpoint_config(EndpointConfigName=ENDPOINT_NAME)

        # Validate resource cleanup. Each describe call must fail inside its own
        # pytest.raises block — the original checked the endpoint config outside
        # any raises block (so a raise would escape the fixture) and re-asserted
        # against the stale model exception. It also passed the wrong keyword
        # (`name=` instead of `EndpointConfigName=`) to describe_endpoint_config.
        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_model(ModelName=model_name)
        assert "Could not find model" in str(exception.value)

        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_endpoint_config(
                EndpointConfigName=ENDPOINT_NAME
            )
        assert "Could not find endpoint" in str(exception.value)


def test_target_variant_invocation(sagemaker_session, multi_variant_endpoint):
    """Invoking with TargetVariant must route the request to exactly that variant."""
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
            TargetVariant=variant_name,
        )
        # The service echoes which variant actually served the request.
        assert response["InvokedProductionVariant"] == variant_name


def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant_endpoint):
    """RealTimePredictor.predict must accept target_variant without raising."""
    predictor = RealTimePredictor(
        endpoint=ENDPOINT_NAME,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )

    # Success (no exception) is the assertion for both variants.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        predictor.predict(TEST_CSV_DATA, target_variant=variant_name)


def test_variant_traffic_distribution(sagemaker_session, multi_variant_endpoint):
    """Sample unrouted invocations and check each variant's observed share of
    traffic falls within the statistical margin of error of its weight."""
    invocation_counts = {TEST_VARIANT_1: 0, TEST_VARIANT_2: 0}

    for _ in range(VARIANT_TRAFFIC_SAMPLING_COUNT):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
        )
        invoked = response["InvokedProductionVariant"]
        if invoked in invocation_counts:
            invocation_counts[invoked] += 1

    # Every sampled request must have hit one of the two known variants.
    assert sum(invocation_counts.values()) == VARIANT_TRAFFIC_SAMPLING_COUNT

    for variant_name, expected_weight in (
        (TEST_VARIANT_1, TEST_VARIANT_1_WEIGHT),
        (TEST_VARIANT_2, TEST_VARIANT_2_WEIGHT),
    ):
        observed_share = float(invocation_counts[variant_name]) / float(
            VARIANT_TRAFFIC_SAMPLING_COUNT
        )
        margin_of_error = _compute_and_retrieve_margin_of_error(expected_weight)
        # Observed share must lie strictly inside (weight - moe, weight + moe).
        assert expected_weight - margin_of_error < observed_share < expected_weight + margin_of_error


def test_spark_ml_predict_invocation_with_target_variant(sagemaker_session):
    """Deploy a single-variant SparkML endpoint; predict must succeed against the
    default variant name and fail with a ValidationError for an unknown one."""
    model_data = sagemaker_session.upload_data(
        path=SPARK_ML_MODEL_LOCAL_PATH, key_prefix="integ-test-data/sparkml/model"
    )

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        SPARK_ML_MODEL_ENDPOINT_NAME, sagemaker_session
    ):
        spark_ml_model = SparkMLModel(
            model_data=model_data,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            env={"SAGEMAKER_SPARKML_SCHEMA": SPARK_ML_MODEL_SCHEMA},
        )

        predictor = spark_ml_model.deploy(
            DEFAULT_INSTANCE_COUNT,
            DEFAULT_INSTANCE_TYPE,
            endpoint_name=SPARK_ML_MODEL_ENDPOINT_NAME,
        )

        # Validate that no exception is raised when the target_variant is specified.
        predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_DEFAULT_VARIANT_NAME)

        # An unknown variant must be rejected by the service with a ValidationError
        # that names the offending variant.
        with pytest.raises(Exception) as exception_info:
            predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_WRONG_VARIANT_NAME)

        assert "ValidationError" in str(exception_info.value)
        assert SPARK_ML_WRONG_VARIANT_NAME in str(exception_info.value)

        # cleanup resources
        spark_ml_model.delete_model()
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=SPARK_ML_MODEL_ENDPOINT_NAME
        )

        # Validate resource cleanup. Each describe call gets its own pytest.raises
        # block — the original called describe_endpoint_config outside any raises
        # block (an escaping exception would fail the test for the wrong reason),
        # re-asserted against the stale model exception, and passed the wrong
        # keyword (`name=` instead of `EndpointConfigName=`).
        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_model(ModelName=spark_ml_model.name)
        assert "Could not find model" in str(exception.value)

        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_endpoint_config(
                EndpointConfigName=SPARK_ML_MODEL_ENDPOINT_NAME
            )
        assert "Could not find endpoint" in str(exception.value)


@pytest.mark.local_mode
def test_target_variant_invocation_local_mode(sagemaker_session, multi_variant_endpoint):
    """Local-mode variant of test_target_variant_invocation: TargetVariant must
    route each request to the requested variant."""
    # Local sessions may not carry a region; fall back to the test default.
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,
            Body=TEST_CSV_DATA,
            ContentType=CONTENT_TYPE_CSV,
            Accept=CONTENT_TYPE_CSV,
            TargetVariant=variant_name,
        )
        assert response["InvokedProductionVariant"] == variant_name


@pytest.mark.local_mode
def test_predict_invocation_with_target_variant_local_mode(
    sagemaker_session, multi_variant_endpoint
):
    """Local-mode variant: predict with target_variant must not raise."""
    # Local sessions may not carry a region; fall back to the test default.
    if sagemaker_session._region_name is None:
        sagemaker_session._region_name = DEFAULT_REGION

    predictor = RealTimePredictor(
        endpoint=ENDPOINT_NAME,
        sagemaker_session=sagemaker_session,
        serializer=csv_serializer,
        content_type=CONTENT_TYPE_CSV,
        accept=CONTENT_TYPE_CSV,
    )

    # Success (no exception) is the assertion for both variants.
    for variant_name in (TEST_VARIANT_1, TEST_VARIANT_2):
        predictor.predict(TEST_CSV_DATA, target_variant=variant_name)


def _compute_and_retrieve_margin_of_error(variant_weight):
"""
Computes the margin of error using the Wald method for computing the confidence
intervals of a binomial distribution.
"""
z_value = st.norm.ppf(DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION)
margin_of_error = (variant_weight * (1 - variant_weight)) / VARIANT_TRAFFIC_SAMPLING_COUNT
margin_of_error = z_value * math.sqrt(margin_of_error)
return margin_of_error
Loading