Skip to content

feat: jumpstart extract generated text from response #4210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions src/sagemaker/jumpstart/payload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
import base64
import json
from typing import Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import re
import boto3

Expand All @@ -26,13 +26,33 @@
)
from sagemaker.jumpstart.enums import MIMEType
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
)
from sagemaker.session import Session


S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"


def _extract_field_from_json(
json_input: dict,
keys: List[str],
) -> Any:
"""Given a dictionary, returns value at specified keys.

Raises:
KeyError: If a key cannot be found in the json input.
"""
curr_json = json_input
for idx, key in enumerate(keys):
if idx < len(keys) - 1:
curr_json = curr_json[key]
continue
return curr_json[key]


def _construct_payload(
prompt: str,
model_id: str,
Expand Down Expand Up @@ -95,6 +115,80 @@ def _construct_payload(
return payload_to_use


def _extract_generated_text_from_response(
    response: dict,
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    accept_type: Optional[str] = None,
) -> str:
    """Returns generated text extracted from full response payload.

    Args:
        response (dict): Dictionary-valued response from which to extract
            generated text.
        model_id (str): JumpStart model ID of the JumpStart model from which to extract
            generated text.
        model_version (str): Version of the JumpStart model for which to extract generated
            text.
        region (Optional[str]): Region for which to extract generated
            text. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        accept_type (Optional[str]): The accept type to optionally specify for the response.
            (Default: None).

    Returns:
        str: extracted generated text from the endpoint response payload.

    Raises:
        ValueError: If the model is invalid, the model does not support generated text extraction,
            or if the response is malformed.
    """

    if not isinstance(response, dict):
        raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")

    payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
        model_id=model_id,
        model_version=model_version,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )
    # Covers both None and an empty payload dict.
    if not payloads:
        raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")

    # Use the first example payload whose accept type matches (or any, if
    # accept_type was not specified) to locate the generated-text key path.
    for payload in payloads.values():
        if accept_type is None or payload.accept == accept_type:
            generated_text_response_key: Optional[str] = payload.generated_text_response_key
            if generated_text_response_key is None:
                raise ValueError(
                    f"Model ID '{model_id}' does not support generated text extraction."
                )

            # The key is a dot-delimited path into the (possibly nested) response dict.
            generated_text_response_key_split = generated_text_response_key.split(".")
            try:
                return _extract_field_from_json(response, generated_text_response_key_split)
            except KeyError as e:
                # Chain the KeyError so the missing-key context is preserved in tracebacks.
                raise ValueError(f"Response is malformed: {response}") from e

    raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")


class PayloadSerializer:
"""Utility class for serializing payloads associated with JumpStart models.

Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
"content_type",
"accept",
"body",
"generated_text_response_key",
"prompt_key",
]

Expand Down Expand Up @@ -365,6 +366,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
self.content_type = json_obj["content_type"]
self.body = json_obj["body"]
accept = json_obj.get("accept")
self.generated_text_response_key = json_obj.get("generated_text_response_key")
self.prompt_key = json_obj.get("prompt_key")
if accept:
self.accept = accept
Expand Down
163 changes: 163 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,6 +2054,169 @@
},
},
},
"response-keys": {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above — please consider renaming this key to: response-keys-model

"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
"version": "1.0.0",
"min_sdk_version": "2.144.0",
"training_supported": False,
"incremental_training_supported": False,
"hosting_ecr_specs": {
"framework": "djl-deepspeed",
"framework_version": "0.21.0",
"py_version": "py38",
"huggingface_transformers_version": "4.17",
},
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_prepacked_artifact_version": "1.0.0",
"inference_vulnerable": False,
"inference_dependencies": [
"accelerate==0.18.0",
"diffusers==0.14.0",
"fsspec==2023.4.0",
"huggingface-hub==0.14.1",
"transformers==4.26.1",
],
"inference_vulnerabilities": [],
"training_vulnerable": False,
"training_dependencies": [],
"training_vulnerabilities": [],
"deprecated": False,
"inference_environment_variables": [
{
"name": "SAGEMAKER_PROGRAM",
"type": "text",
"default": "inference.py",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
"type": "text",
"default": "/opt/ml/model/code",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
"type": "text",
"default": "20",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
"type": "text",
"default": "3600",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "ENDPOINT_SERVER_TIMEOUT",
"type": "int",
"default": 3600,
"scope": "container",
"required_for_model_class": True,
},
{
"name": "MODEL_CACHE_ROOT",
"type": "text",
"default": "/opt/ml/model",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_ENV",
"type": "text",
"default": "1",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
"type": "int",
"default": 1,
"scope": "container",
"required_for_model_class": True,
},
],
"metrics": [],
"default_inference_instance_type": "ml.g5.8xlarge",
"supported_inference_instance_types": [
"ml.g5.8xlarge",
"ml.g5.xlarge",
"ml.g5.2xlarge",
"ml.g5.4xlarge",
"ml.g5.16xlarge",
"ml.p3.2xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.2xlarge",
"ml.g4dn.4xlarge",
"ml.g4dn.8xlarge",
"ml.g4dn.16xlarge",
],
"model_kwargs": {},
"deploy_kwargs": {},
"predictor_specs": {
"supported_content_types": ["application/json"],
"supported_accept_types": ["application/json"],
"default_content_type": "application/json",
"default_accept_type": "application/json",
},
"inference_enable_network_isolation": True,
"validation_supported": False,
"fine_tuning_supported": False,
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
"default_payloads": {
"Dog": {
"content_type": "application/json",
"prompt_key": "hello.prompt",
"generated_text_response_key": "key1.key2.generated_text",
"body": {
"hello": {"prompt": "a dog"},
"seed": 43,
},
}
},
"hosting_instance_type_variants": {
"regional_aliases": {
"af-south-1": {
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
},
},
"variants": {
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
},
},
},
"default_payloads": {
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
Expand Down
Loading