Skip to content

Commit 4befd93

Browse files
authored
feat: jumpstart extract generated text from response (#4210)
1 parent b8e3e05 commit 4befd93

File tree

5 files changed

+339
-4
lines changed

5 files changed

+339
-4
lines changed

src/sagemaker/jumpstart/payload_utils.py

+96-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import base64
1616
import json
17-
from typing import Dict, Optional, Union
17+
from typing import Any, Dict, List, Optional, Union
1818
import re
1919
import boto3
2020

@@ -26,13 +26,33 @@
2626
)
2727
from sagemaker.jumpstart.enums import MIMEType
2828
from sagemaker.jumpstart.types import JumpStartSerializablePayload
29-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
29+
from sagemaker.jumpstart.utils import (
30+
get_jumpstart_content_bucket,
31+
)
3032
from sagemaker.session import Session
3133

34+
3235
S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
3336
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"
3437

3538

39+
def _extract_field_from_json(
40+
json_input: dict,
41+
keys: List[str],
42+
) -> Any:
43+
"""Given a dictionary, returns value at specified keys.
44+
45+
Raises:
46+
KeyError: If a key cannot be found in the json input.
47+
"""
48+
curr_json = json_input
49+
for idx, key in enumerate(keys):
50+
if idx < len(keys) - 1:
51+
curr_json = curr_json[key]
52+
continue
53+
return curr_json[key]
54+
55+
3656
def _construct_payload(
3757
prompt: str,
3858
model_id: str,
@@ -95,6 +115,80 @@ def _construct_payload(
95115
return payload_to_use
96116

97117

118+
def _extract_generated_text_from_response(
119+
response: dict,
120+
model_id: str,
121+
model_version: str,
122+
region: Optional[str] = None,
123+
tolerate_vulnerable_model: bool = False,
124+
tolerate_deprecated_model: bool = False,
125+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126+
accept_type: Optional[str] = None,
127+
) -> str:
128+
"""Returns generated text extracted from full response payload.
129+
130+
Args:
131+
response (dict): Dictionary-valued response from which to extract
132+
generated text.
133+
model_id (str): JumpStart model ID of the JumpStart model from which to extract
134+
generated text.
135+
model_version (str): Version of the JumpStart model for which to extract generated
136+
text.
137+
region (Optional[str]): Region for which to extract generated
138+
text. (Default: None).
139+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
140+
specifications should be tolerated (exception not raised). If False, raises an
141+
exception if the script used by this version of the model has dependencies with known
142+
security vulnerabilities. (Default: False).
143+
tolerate_deprecated_model (bool): True if deprecated versions of model
144+
specifications should be tolerated (exception not raised). If False, raises
145+
an exception if the version of the model is deprecated. (Default: False).
146+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
147+
object, used for SageMaker interactions. If not
148+
specified, one is created using the default AWS configuration
149+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150+
accept_type (Optional[str]): The accept type to optionally specify for the response.
151+
(Default: None).
152+
153+
Returns:
154+
str: extracted generated text from the endpoint response payload.
155+
156+
Raises:
157+
ValueError: If the model is invalid, the model does not support generated text extraction,
158+
or if the response is malformed.
159+
"""
160+
161+
if not isinstance(response, dict):
162+
raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")
163+
164+
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
165+
model_id=model_id,
166+
model_version=model_version,
167+
region=region,
168+
tolerate_vulnerable_model=tolerate_vulnerable_model,
169+
tolerate_deprecated_model=tolerate_deprecated_model,
170+
sagemaker_session=sagemaker_session,
171+
)
172+
if payloads is None or len(payloads) == 0:
173+
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
174+
175+
for payload in payloads.values():
176+
if accept_type is None or payload.accept == accept_type:
177+
generated_text_response_key: Optional[str] = payload.generated_text_response_key
178+
if generated_text_response_key is None:
179+
raise ValueError(
180+
f"Model ID '{model_id}' does not support generated text extraction."
181+
)
182+
183+
generated_text_response_key_split = generated_text_response_key.split(".")
184+
try:
185+
return _extract_field_from_json(response, generated_text_response_key_split)
186+
except KeyError:
187+
raise ValueError(f"Response is malformed: {response}")
188+
189+
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
190+
191+
98192
class PayloadSerializer:
99193
"""Utility class for serializing payloads associated with JumpStart models.
100194

src/sagemaker/jumpstart/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
334334
"content_type",
335335
"accept",
336336
"body",
337+
"generated_text_response_key",
337338
"prompt_key",
338339
]
339340

@@ -365,6 +366,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
365366
self.content_type = json_obj["content_type"]
366367
self.body = json_obj["body"]
367368
accept = json_obj.get("accept")
369+
self.generated_text_response_key = json_obj.get("generated_text_response_key")
368370
self.prompt_key = json_obj.get("prompt_key")
369371
if accept:
370372
self.accept = accept

tests/unit/sagemaker/jumpstart/constants.py

+163
Original file line numberDiff line numberDiff line change
@@ -2054,6 +2054,169 @@
20542054
},
20552055
},
20562056
},
2057+
"response-keys": {
2058+
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
2059+
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
2060+
"version": "1.0.0",
2061+
"min_sdk_version": "2.144.0",
2062+
"training_supported": False,
2063+
"incremental_training_supported": False,
2064+
"hosting_ecr_specs": {
2065+
"framework": "djl-deepspeed",
2066+
"framework_version": "0.21.0",
2067+
"py_version": "py38",
2068+
"huggingface_transformers_version": "4.17",
2069+
},
2070+
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
2071+
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2072+
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
2073+
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
2074+
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2075+
"hosting_prepacked_artifact_version": "1.0.0",
2076+
"inference_vulnerable": False,
2077+
"inference_dependencies": [
2078+
"accelerate==0.18.0",
2079+
"diffusers==0.14.0",
2080+
"fsspec==2023.4.0",
2081+
"huggingface-hub==0.14.1",
2082+
"transformers==4.26.1",
2083+
],
2084+
"inference_vulnerabilities": [],
2085+
"training_vulnerable": False,
2086+
"training_dependencies": [],
2087+
"training_vulnerabilities": [],
2088+
"deprecated": False,
2089+
"inference_environment_variables": [
2090+
{
2091+
"name": "SAGEMAKER_PROGRAM",
2092+
"type": "text",
2093+
"default": "inference.py",
2094+
"scope": "container",
2095+
"required_for_model_class": True,
2096+
},
2097+
{
2098+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
2099+
"type": "text",
2100+
"default": "/opt/ml/model/code",
2101+
"scope": "container",
2102+
"required_for_model_class": False,
2103+
},
2104+
{
2105+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
2106+
"type": "text",
2107+
"default": "20",
2108+
"scope": "container",
2109+
"required_for_model_class": False,
2110+
},
2111+
{
2112+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
2113+
"type": "text",
2114+
"default": "3600",
2115+
"scope": "container",
2116+
"required_for_model_class": False,
2117+
},
2118+
{
2119+
"name": "ENDPOINT_SERVER_TIMEOUT",
2120+
"type": "int",
2121+
"default": 3600,
2122+
"scope": "container",
2123+
"required_for_model_class": True,
2124+
},
2125+
{
2126+
"name": "MODEL_CACHE_ROOT",
2127+
"type": "text",
2128+
"default": "/opt/ml/model",
2129+
"scope": "container",
2130+
"required_for_model_class": True,
2131+
},
2132+
{
2133+
"name": "SAGEMAKER_ENV",
2134+
"type": "text",
2135+
"default": "1",
2136+
"scope": "container",
2137+
"required_for_model_class": True,
2138+
},
2139+
{
2140+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
2141+
"type": "int",
2142+
"default": 1,
2143+
"scope": "container",
2144+
"required_for_model_class": True,
2145+
},
2146+
],
2147+
"metrics": [],
2148+
"default_inference_instance_type": "ml.g5.8xlarge",
2149+
"supported_inference_instance_types": [
2150+
"ml.g5.8xlarge",
2151+
"ml.g5.xlarge",
2152+
"ml.g5.2xlarge",
2153+
"ml.g5.4xlarge",
2154+
"ml.g5.16xlarge",
2155+
"ml.p3.2xlarge",
2156+
"ml.g4dn.xlarge",
2157+
"ml.g4dn.2xlarge",
2158+
"ml.g4dn.4xlarge",
2159+
"ml.g4dn.8xlarge",
2160+
"ml.g4dn.16xlarge",
2161+
],
2162+
"model_kwargs": {},
2163+
"deploy_kwargs": {},
2164+
"predictor_specs": {
2165+
"supported_content_types": ["application/json"],
2166+
"supported_accept_types": ["application/json"],
2167+
"default_content_type": "application/json",
2168+
"default_accept_type": "application/json",
2169+
},
2170+
"inference_enable_network_isolation": True,
2171+
"validation_supported": False,
2172+
"fine_tuning_supported": False,
2173+
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
2174+
"default_payloads": {
2175+
"Dog": {
2176+
"content_type": "application/json",
2177+
"prompt_key": "hello.prompt",
2178+
"generated_text_response_key": "key1.key2.generated_text",
2179+
"body": {
2180+
"hello": {"prompt": "a dog"},
2181+
"seed": 43,
2182+
},
2183+
}
2184+
},
2185+
"hosting_instance_type_variants": {
2186+
"regional_aliases": {
2187+
"af-south-1": {
2188+
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
2189+
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
2190+
},
2191+
},
2192+
"variants": {
2193+
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2194+
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2195+
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2196+
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2197+
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2198+
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2199+
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2200+
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2201+
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2202+
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2203+
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2204+
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2205+
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2206+
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2207+
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2208+
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2209+
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2210+
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2211+
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2212+
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2213+
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2214+
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2215+
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2216+
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2217+
},
2218+
},
2219+
},
20572220
"default_payloads": {
20582221
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
20592222
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",

0 commit comments

Comments
 (0)