Skip to content

Revert "feature: Support for multi variant endpoint invocation with target variant param" #1574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
4 changes: 0 additions & 4 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ def invoke_endpoint(
Accept=None,
CustomAttributes=None,
TargetModel=None,
TargetVariant=None,
):
"""

Expand Down Expand Up @@ -371,9 +370,6 @@ def invoke_endpoint(
if TargetModel is not None:
headers["X-Amzn-SageMaker-Target-Model"] = TargetModel

if TargetVariant is not None:
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant

r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)

return {"Body": r, "ContentType": Accept}
Expand Down
13 changes: 3 additions & 10 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self._endpoint_config_name = self._get_endpoint_config_name()
self._model_names = self._get_model_names()

def predict(self, data, initial_args=None, target_model=None, target_variant=None):
def predict(self, data, initial_args=None, target_model=None):
"""Return the inference from the specified endpoint.

Args:
Expand All @@ -98,9 +98,6 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
target_model (str): S3 model artifact path to run an inference request on,
in case of a multi model endpoint. Does not apply to endpoints hosting
single model (Default: None)
target_variant (str): The name of the production variant to run an inference
request on (Default: None). Note that the ProductionVariant identifies the model
you want to host and the resources you want to deploy for hosting it.

Returns:
object: Inference for the given input. If a deserializer was specified when creating
Expand All @@ -109,7 +106,7 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
as is.
"""

request_args = self._create_request_args(data, initial_args, target_model, target_variant)
request_args = self._create_request_args(data, initial_args, target_model)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

Expand All @@ -126,13 +123,12 @@ def _handle_response(self, response):
response_body.close()
return data

def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
def _create_request_args(self, data, initial_args=None, target_model=None):
"""
Args:
data:
initial_args:
target_model:
target_variant:
"""
args = dict(initial_args) if initial_args else {}

Expand All @@ -148,9 +144,6 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
if target_model:
args["TargetModel"] = target_model

if target_variant:
args["TargetVariant"] = target_variant

if self.serializer is not None:
data = self.serializer(data)

Expand Down
309 changes: 0 additions & 309 deletions tests/integ/test_multi_variant_endpoint.py

This file was deleted.

Loading