@@ -115,80 +115,6 @@ def _construct_payload(
115
115
return payload_to_use
116
116
117
117
118
- def _extract_generated_text_from_response (
119
- response : dict ,
120
- model_id : str ,
121
- model_version : str ,
122
- region : Optional [str ] = None ,
123
- tolerate_vulnerable_model : bool = False ,
124
- tolerate_deprecated_model : bool = False ,
125
- sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
126
- accept_type : Optional [str ] = None ,
127
- ) -> str :
128
- """Returns generated text extracted from full response payload.
129
-
130
- Args:
131
- response (dict): Dictionary-valued response from which to extract
132
- generated text.
133
- model_id (str): JumpStart model ID of the JumpStart model from which to extract
134
- generated text.
135
- model_version (str): Version of the JumpStart model for which to extract generated
136
- text.
137
- region (Optional[str]): Region for which to extract generated
138
- text. (Default: None).
139
- tolerate_vulnerable_model (bool): True if vulnerable versions of model
140
- specifications should be tolerated (exception not raised). If False, raises an
141
- exception if the script used by this version of the model has dependencies with known
142
- security vulnerabilities. (Default: False).
143
- tolerate_deprecated_model (bool): True if deprecated versions of model
144
- specifications should be tolerated (exception not raised). If False, raises
145
- an exception if the version of the model is deprecated. (Default: False).
146
- sagemaker_session (sagemaker.session.Session): A SageMaker Session
147
- object, used for SageMaker interactions. If not
148
- specified, one is created using the default AWS configuration
149
- chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150
- accept_type (Optional[str]): The accept type to optionally specify for the response.
151
- (Default: None).
152
-
153
- Returns:
154
- str: extracted generated text from the endpoint response payload.
155
-
156
- Raises:
157
- ValueError: If the model is invalid, the model does not support generated text extraction,
158
- or if the response is malformed.
159
- """
160
-
161
- if not isinstance (response , dict ):
162
- raise ValueError (f"Response must be dictionary. Instead, got: { type (response )} " )
163
-
164
- payloads : Optional [Dict [str , JumpStartSerializablePayload ]] = _retrieve_example_payloads (
165
- model_id = model_id ,
166
- model_version = model_version ,
167
- region = region ,
168
- tolerate_vulnerable_model = tolerate_vulnerable_model ,
169
- tolerate_deprecated_model = tolerate_deprecated_model ,
170
- sagemaker_session = sagemaker_session ,
171
- )
172
- if payloads is None or len (payloads ) == 0 :
173
- raise ValueError (f"Model ID '{ model_id } ' does not support generated text extraction." )
174
-
175
- for payload in payloads .values ():
176
- if accept_type is None or payload .accept == accept_type :
177
- generated_text_response_key : Optional [str ] = payload .generated_text_response_key
178
- if generated_text_response_key is None :
179
- raise ValueError (
180
- f"Model ID '{ model_id } ' does not support generated text extraction."
181
- )
182
-
183
- generated_text_response_key_split = generated_text_response_key .split ("." )
184
- try :
185
- return _extract_field_from_json (response , generated_text_response_key_split )
186
- except KeyError :
187
- raise ValueError (f"Response is malformed: { response } " )
188
-
189
- raise ValueError (f"Model ID '{ model_id } ' does not support generated text extraction." )
190
-
191
-
192
118
class PayloadSerializer :
193
119
"""Utility class for serializing payloads associated with JumpStart models.
194
120
0 commit comments