@@ -83,7 +83,7 @@ def __init__(
83
83
self ._endpoint_config_name = self ._get_endpoint_config_name ()
84
84
self ._model_names = self ._get_model_names ()
85
85
86
- def predict (self , data , initial_args = None , target_model = None , target_variant = None ):
86
+ def predict (self , data , initial_args = None , target_model = None ):
87
87
"""Return the inference from the specified endpoint.
88
88
89
89
Args:
@@ -98,9 +98,6 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
98
98
target_model (str): S3 model artifact path to run an inference request on,
99
99
in case of a multi model endpoint. Does not apply to endpoints hosting
100
100
single model (Default: None)
101
- target_variant (str): The name of the production variant to run an inference
102
- request on (Default: None). Note that the ProductionVariant identifies the model
103
- you want to host and the resources you want to deploy for hosting it.
104
101
105
102
Returns:
106
103
object: Inference for the given input. If a deserializer was specified when creating
@@ -109,7 +106,7 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
109
106
as is.
110
107
"""
111
108
112
- request_args = self ._create_request_args (data , initial_args , target_model , target_variant )
109
+ request_args = self ._create_request_args (data , initial_args , target_model )
113
110
response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
114
111
return self ._handle_response (response )
115
112
@@ -126,13 +123,12 @@ def _handle_response(self, response):
126
123
response_body .close ()
127
124
return data
128
125
129
- def _create_request_args (self , data , initial_args = None , target_model = None , target_variant = None ):
126
+ def _create_request_args (self , data , initial_args = None , target_model = None ):
130
127
"""
131
128
Args:
132
129
data:
133
130
initial_args:
134
131
target_model:
135
- target_variant:
136
132
"""
137
133
args = dict (initial_args ) if initial_args else {}
138
134
@@ -148,9 +144,6 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
148
144
if target_model :
149
145
args ["TargetModel" ] = target_model
150
146
151
- if target_variant :
152
- args ["TargetVariant" ] = target_variant
153
-
154
147
if self .serializer is not None :
155
148
data = self .serializer (data )
156
149
0 commit comments