feat: jumpstart extract generated text from response #4210

Merged
Changes from 2 commits
102 changes: 100 additions & 2 deletions src/sagemaker/jumpstart/payload_utils.py
@@ -14,20 +14,118 @@
from __future__ import absolute_import
import base64
import json
from typing import Optional, Union
from typing import Any, Dict, List, Optional, Union
import re
import boto3

from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.enums import MIMEType
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
)
from sagemaker.session import Session
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)

S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"


def _extract_field_from_json(
json_input: dict,
keys: List[str],
) -> Any:
"""Given a dictionary, returns value at specified keys.

Raises:
KeyError: If a key cannot be found in the json input.
"""
curr_json = json_input
for idx, key in enumerate(keys):
if idx < len(keys) - 1:
curr_json = curr_json[key]
continue
return curr_json[key]
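
For illustration, a minimal usage sketch of this helper, resolving a dotted response key after it has been split on "."; the response shape here is hypothetical and mirrors the test fixture later in this diff:

sample_response = {"key1": {"key2": {"generated_text": "a dog on a beach"}}}
key_path = "key1.key2.generated_text".split(".")
assert _extract_field_from_json(sample_response, key_path) == "a dog on a beach"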


def _extract_generated_text_from_response(
response: dict,
model_id: str,
model_version: str,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
accept_type: Optional[str] = None,
) -> str:
"""Returns generated text extracted from full response payload.

Args:
response (dict): Dictionary-valued response from which to extract
generated text.
model_id (str): JumpStart model ID of the JumpStart model from which to extract
generated text.
model_version (str): Version of the JumpStart model for which to extract generated
text.
region (Optional[str]): Region for which to extract generated
text. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated. (Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
accept_type (Optional[str]): The accept type to optionally specify for the response.
(Default: None).

Returns:
str: extracted generated text from the endpoint response payload.

Raises:
ValueError: If the model is invalid, the model does not support generated text extraction,
or if the response is malformed.
"""

if not isinstance(response, dict):
raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")

payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
model_id=model_id,
model_version=model_version,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
if payloads is None or len(payloads) == 0:
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")

for payload in payloads.values():
if accept_type is None or payload.accept == accept_type:
generated_text_response_key: Optional[str] = payload.generated_text_response_key
if generated_text_response_key is None:
raise ValueError(
f"Model ID '{model_id}' does not support generated text extraction."
)

generated_text_response_key_split = generated_text_response_key.split(".")
try:
return _extract_field_from_json(response, generated_text_response_key_split)
except KeyError:
raise ValueError(f"Response is malformed: {response}")

raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")


class PayloadSerializer:
"""Utility class for serializing payloads associated with JumpStart models.

2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/types.py
@@ -331,6 +331,7 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
"content_type",
"accept",
"body",
"generated_text_response_key",
]

_non_serializable_slots = ["raw_payload"]
@@ -361,6 +362,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
self.content_type = json_obj["content_type"]
self.body = json_obj["body"]
accept = json_obj.get("accept")
self.generated_text_response_key = json_obj.get("generated_text_response_key")
if accept:
self.accept = accept

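For reference, a hedged sketch of how the new field flows through from_json; the spec dict below is hypothetical, and the example assumes the usual JumpStartDataHolderType pattern in which the constructor forwards the dict to from_json:

payload_spec = {
    "content_type": "application/json",
    "accept": "application/json",
    "body": {"hello": {"prompt": "a dog"}, "seed": 43},
    "generated_text_response_key": "key1.key2.generated_text",
}
payload = JumpStartSerializablePayload(payload_spec)  # assumed constructor pattern
assert payload.generated_text_response_key == "key1.key2.generated_text"
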
163 changes: 163 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
@@ -1781,6 +1781,169 @@
},
},
},
"response-keys": {
Contributor comment: same comment please, consider: response-keys-model

"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
"version": "1.0.0",
"min_sdk_version": "2.144.0",
"training_supported": False,
"incremental_training_supported": False,
"hosting_ecr_specs": {
"framework": "djl-deepspeed",
"framework_version": "0.21.0",
"py_version": "py38",
"huggingface_transformers_version": "4.17",
},
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_prepacked_artifact_version": "1.0.0",
"inference_vulnerable": False,
"inference_dependencies": [
"accelerate==0.18.0",
"diffusers==0.14.0",
"fsspec==2023.4.0",
"huggingface-hub==0.14.1",
"transformers==4.26.1",
],
"inference_vulnerabilities": [],
"training_vulnerable": False,
"training_dependencies": [],
"training_vulnerabilities": [],
"deprecated": False,
"inference_environment_variables": [
{
"name": "SAGEMAKER_PROGRAM",
"type": "text",
"default": "inference.py",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
"type": "text",
"default": "/opt/ml/model/code",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
"type": "text",
"default": "20",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
"type": "text",
"default": "3600",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "ENDPOINT_SERVER_TIMEOUT",
"type": "int",
"default": 3600,
"scope": "container",
"required_for_model_class": True,
},
{
"name": "MODEL_CACHE_ROOT",
"type": "text",
"default": "/opt/ml/model",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_ENV",
"type": "text",
"default": "1",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
"type": "int",
"default": 1,
"scope": "container",
"required_for_model_class": True,
},
],
"metrics": [],
"default_inference_instance_type": "ml.g5.8xlarge",
"supported_inference_instance_types": [
"ml.g5.8xlarge",
"ml.g5.xlarge",
"ml.g5.2xlarge",
"ml.g5.4xlarge",
"ml.g5.16xlarge",
"ml.p3.2xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.2xlarge",
"ml.g4dn.4xlarge",
"ml.g4dn.8xlarge",
"ml.g4dn.16xlarge",
],
"model_kwargs": {},
"deploy_kwargs": {},
"predictor_specs": {
"supported_content_types": ["application/json"],
"supported_accept_types": ["application/json"],
"default_content_type": "application/json",
"default_accept_type": "application/json",
},
"inference_enable_network_isolation": True,
"validation_supported": False,
"fine_tuning_supported": False,
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
"default_payloads": {
"Dog": {
"content_type": "application/json",
"prompt_key": "hello.prompt",
"generated_text_response_key": "key1.key2.generated_text",
"body": {
"hello": {"prompt": "a dog"},
"seed": 43,
},
}
},
"hosting_instance_type_variants": {
"regional_aliases": {
"af-south-1": {
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
},
},
"variants": {
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
},
},
},
"default_payloads": {
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
80 changes: 79 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_payload_utils.py
@@ -14,11 +14,89 @@
import base64
from unittest import TestCase
from mock.mock import patch
import pytest

from sagemaker.jumpstart.payload_utils import PayloadSerializer
from sagemaker.jumpstart.payload_utils import (
PayloadSerializer,
_extract_generated_text_from_response,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload


from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec


class TestResponseExtraction(TestCase):
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_extract_generated_text(self, patched_get_model_specs):
patched_get_model_specs.side_effect = get_special_model_spec

model_id = "response-keys"
region = "us-west-2"
generated_text = _extract_generated_text_from_response(
response={"key1": {"key2": {"generated_text": "top secret"}}},
model_id=model_id,
model_version="*",
region=region,
)

self.assertEqual(
_extract_generated_text_from_response(
response={"key1": {"key2": {"generated_text": "top secret"}}},
model_id=model_id,
model_version="*",
region=region,
accept_type="application/json",
),
generated_text,
)

self.assertEqual(
generated_text,
"top secret",
)

with pytest.raises(ValueError):
_extract_generated_text_from_response(
response={"key1": {"key2": {"generated_texts": "top secret"}}},
model_id=model_id,
model_version="*",
region=region,
)

with pytest.raises(ValueError):
_extract_generated_text_from_response(
response={"key1": {"key2": {"generated_text": "top secret"}}},
model_id=model_id,
model_version="*",
region=region,
accept_type="blah/blah",
)

with pytest.raises(ValueError):
_extract_generated_text_from_response(
response={"key1": {"key2": {"generated_text": "top secret"}}},
model_id="env-var-variant-model", # some model without the required metadata
model_version="*",
region=region,
)
with pytest.raises(ValueError):
_extract_generated_text_from_response(
response={"key1": {"generated_texts": "top secret"}},
model_id=model_id,
model_version="*",
region=region,
)

with pytest.raises(ValueError):
_extract_generated_text_from_response(
response="blah",
model_id=model_id,
model_version="*",
region=region,
)
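
Beyond the mocked tests above, a hedged end-to-end sketch of how the helper might pair with a deployed JumpStart endpoint; the model ID is hypothetical, the payload and response shapes depend on the model's spec, and the predictor is assumed to return a deserialized JSON dict:

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="example-text-generation-model")  # hypothetical ID
predictor = model.deploy()
response = predictor.predict({"inputs": "a dog"})  # input shape depends on the model
text = _extract_generated_text_from_response(
    response=response,
    model_id="example-text-generation-model",
    model_version="*",
)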


class TestPayloadSerializer(TestCase):

payload_serializer = PayloadSerializer()