Skip to content

Commit b9a8ef1

Browse files
committed
chore: remove deprecated generated_text_response_key from JumpStartSerializablePayload
1 parent cc99a57 commit b9a8ef1

File tree

5 files changed

+1
-150
lines changed

5 files changed

+1
-150
lines changed

src/sagemaker/jumpstart/payload_utils.py

-74
Original file line numberDiff line numberDiff line change
@@ -115,80 +115,6 @@ def _construct_payload(
115115
return payload_to_use
116116

117117

118-
def _extract_generated_text_from_response(
119-
response: dict,
120-
model_id: str,
121-
model_version: str,
122-
region: Optional[str] = None,
123-
tolerate_vulnerable_model: bool = False,
124-
tolerate_deprecated_model: bool = False,
125-
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126-
accept_type: Optional[str] = None,
127-
) -> str:
128-
"""Returns generated text extracted from full response payload.
129-
130-
Args:
131-
response (dict): Dictionary-valued response from which to extract
132-
generated text.
133-
model_id (str): JumpStart model ID of the JumpStart model from which to extract
134-
generated text.
135-
model_version (str): Version of the JumpStart model for which to extract generated
136-
text.
137-
region (Optional[str]): Region for which to extract generated
138-
text. (Default: None).
139-
tolerate_vulnerable_model (bool): True if vulnerable versions of model
140-
specifications should be tolerated (exception not raised). If False, raises an
141-
exception if the script used by this version of the model has dependencies with known
142-
security vulnerabilities. (Default: False).
143-
tolerate_deprecated_model (bool): True if deprecated versions of model
144-
specifications should be tolerated (exception not raised). If False, raises
145-
an exception if the version of the model is deprecated. (Default: False).
146-
sagemaker_session (sagemaker.session.Session): A SageMaker Session
147-
object, used for SageMaker interactions. If not
148-
specified, one is created using the default AWS configuration
149-
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150-
accept_type (Optional[str]): The accept type to optionally specify for the response.
151-
(Default: None).
152-
153-
Returns:
154-
str: extracted generated text from the endpoint response payload.
155-
156-
Raises:
157-
ValueError: If the model is invalid, the model does not support generated text extraction,
158-
or if the response is malformed.
159-
"""
160-
161-
if not isinstance(response, dict):
162-
raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")
163-
164-
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
165-
model_id=model_id,
166-
model_version=model_version,
167-
region=region,
168-
tolerate_vulnerable_model=tolerate_vulnerable_model,
169-
tolerate_deprecated_model=tolerate_deprecated_model,
170-
sagemaker_session=sagemaker_session,
171-
)
172-
if payloads is None or len(payloads) == 0:
173-
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
174-
175-
for payload in payloads.values():
176-
if accept_type is None or payload.accept == accept_type:
177-
generated_text_response_key: Optional[str] = payload.generated_text_response_key
178-
if generated_text_response_key is None:
179-
raise ValueError(
180-
f"Model ID '{model_id}' does not support generated text extraction."
181-
)
182-
183-
generated_text_response_key_split = generated_text_response_key.split(".")
184-
try:
185-
return _extract_field_from_json(response, generated_text_response_key_split)
186-
except KeyError:
187-
raise ValueError(f"Response is malformed: {response}")
188-
189-
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
190-
191-
192118
class PayloadSerializer:
193119
"""Utility class for serializing payloads associated with JumpStart models.
194120

src/sagemaker/jumpstart/types.py

-2
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,6 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
339339
"content_type",
340340
"accept",
341341
"body",
342-
"generated_text_response_key",
343342
"prompt_key",
344343
]
345344

@@ -371,7 +370,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
371370
self.content_type = json_obj["content_type"]
372371
self.body = json_obj["body"]
373372
accept = json_obj.get("accept")
374-
self.generated_text_response_key = json_obj.get("generated_text_response_key")
375373
self.prompt_key = json_obj.get("prompt_key")
376374
if accept:
377375
self.accept = accept

tests/unit/sagemaker/jumpstart/constants.py

-1
Original file line numberDiff line numberDiff line change
@@ -4129,7 +4129,6 @@
41294129
"Dog": {
41304130
"content_type": "application/json",
41314131
"prompt_key": "hello.prompt",
4132-
"generated_text_response_key": "key1.key2.generated_text",
41334132
"body": {
41344133
"hello": {"prompt": "a dog"},
41354134
"seed": 43,

tests/unit/sagemaker/jumpstart/test_payload_utils.py

-72
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from sagemaker.jumpstart.payload_utils import (
2020
PayloadSerializer,
21-
_extract_generated_text_from_response,
2221
_construct_payload,
2322
)
2423
from sagemaker.jumpstart.types import JumpStartSerializablePayload
@@ -59,77 +58,6 @@ def test_construct_payload(self, patched_get_model_specs):
5958
)
6059

6160

62-
class TestResponseExtraction(TestCase):
63-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
64-
def test_extract_generated_text(self, patched_get_model_specs):
65-
patched_get_model_specs.side_effect = get_special_model_spec
66-
67-
model_id = "response-keys"
68-
region = "us-west-2"
69-
generated_text = _extract_generated_text_from_response(
70-
response={"key1": {"key2": {"generated_text": "top secret"}}},
71-
model_id=model_id,
72-
model_version="*",
73-
region=region,
74-
)
75-
76-
self.assertEqual(
77-
_extract_generated_text_from_response(
78-
response={"key1": {"key2": {"generated_text": "top secret"}}},
79-
model_id=model_id,
80-
model_version="*",
81-
region=region,
82-
accept_type="application/json",
83-
),
84-
generated_text,
85-
)
86-
87-
self.assertEqual(
88-
generated_text,
89-
"top secret",
90-
)
91-
92-
with pytest.raises(ValueError):
93-
_extract_generated_text_from_response(
94-
response={"key1": {"key2": {"generated_texts": "top secret"}}},
95-
model_id=model_id,
96-
model_version="*",
97-
region=region,
98-
)
99-
100-
with pytest.raises(ValueError):
101-
_extract_generated_text_from_response(
102-
response={"key1": {"key2": {"generated_text": "top secret"}}},
103-
model_id=model_id,
104-
model_version="*",
105-
region=region,
106-
accept_type="blah/blah",
107-
)
108-
109-
with pytest.raises(ValueError):
110-
_extract_generated_text_from_response(
111-
response={"key1": {"key2": {"generated_text": "top secret"}}},
112-
model_id="env-var-variant-model", # some model without the required metadata
113-
model_version="*",
114-
region=region,
115-
)
116-
with pytest.raises(ValueError):
117-
_extract_generated_text_from_response(
118-
response={"key1": {"generated_texts": "top secret"}},
119-
model_id=model_id,
120-
model_version="*",
121-
region=region,
122-
)
123-
124-
with pytest.raises(ValueError):
125-
_extract_generated_text_from_response(
126-
response="blah",
127-
model_id=model_id,
128-
model_version="*",
129-
region=region,
130-
)
131-
132-
13361
class TestPayloadSerializer(TestCase):
13462

13563
payload_serializer = PayloadSerializer()

tests/unit/sagemaker/jumpstart/test_predictor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_jumpstart_serializable_payload_with_predictor(
7676
"JumpStartSerializablePayload: {'content_type': 'application/json', 'accept': 'application/json'"
7777
", 'body': {'prompt': 'a dog', 'num_images_per_prompt': 2, 'num_inference_steps':"
7878
" 20, 'guidance_scale': 7.5, 'seed': 43, 'eta': 0.7, 'image':"
79-
" '$s3_b64<inference-notebook-assets/inpainting_cow.jpg>'}, 'generated_text_response_key': None}"
79+
" '$s3_b64<inference-notebook-assets/inpainting_cow.jpg>'}}"
8080
)
8181

8282
js_predictor.predict(default_payload)

0 commit comments

Comments
 (0)