
Commit fc320f6

Merge pull request #254 from awslabs/feature/visual-qa
feat(visualqa): question answer on uploaded image
2 parents a4e4089 + 8a0bd43 commit fc320f6

File tree

26 files changed: +1795 −511 lines changed


apidocs/classes/QaAppsyncOpensearch.md (+9 lines)

@@ -34,6 +34,7 @@ The QaAppsyncOpensearch class.
 - [securityGroup](QaAppsyncOpensearch.md#securitygroup)
 - [stage](QaAppsyncOpensearch.md#stage)
 - [vpc](QaAppsyncOpensearch.md#vpc)
+- [CONSTRUCT\_SCHEMA\_UPDATE\_WARNING](QaAppsyncOpensearch.md#construct_schema_update_warning)
 - [usageMetricMap](QaAppsyncOpensearch.md#usagemetricmap)

 ### Methods

@@ -240,6 +241,14 @@ Returns the instance of ec2.IVpc used by the construct

 ___

+### CONSTRUCT\_SCHEMA\_UPDATE\_WARNING
+
+`Static` `Readonly` **CONSTRUCT\_SCHEMA\_UPDATE\_WARNING**: ``"\n Attention QaAppsyncOpensearch users, an update has been made to \n the GraphQL schema.To ensure continued functionality, please review \n and update your GraphQL mutations and subscriptions to align with \n the new schema.This schema update enables enhanced capabilities \n and optimizations,so adopting the changes is recommended. \n Please refer to the construct documentation for details \n on the schema changes and examples of updated GraphQL statements.\n Reach out to the support team if you need assistance \n updating your integration codebase. \n "``
+
+Construct warning
+
+___
+
 ### usageMetricMap

 `Static` `Protected` **usageMetricMap**: `Record`\<`string`, `number`\>
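
For consumers of the construct, the new static property makes the schema change visible in code as well as in the docs. A minimal sketch of surfacing it at synth time, assuming the construct is consumed through the project's jsii-generated Python bindings (the import path and attribute exposure below are assumptions):

# Hedged sketch: log the schema-update warning so integrators notice it.
# Assumes the Python package name and that the static property keeps its
# SCREAMING_SNAKE_CASE name in the generated bindings.
from cdklabs.generative_ai_cdk_constructs import QaAppsyncOpensearch

def print_schema_update_warning() -> None:
    # Reminds integrators to update their GraphQL mutations/subscriptions.
    print(QaAppsyncOpensearch.CONSTRUCT_SCHEMA_UPDATE_WARNING)

if __name__ == "__main__":
    print_schema_update_warning()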

lambda/aws-qa-appsync-opensearch/question_answering/src/lambda.py (+1 line)

@@ -41,3 +41,4 @@ def handler(event, context: LambdaContext) -> dict:

     print(f"llm_response is {llm_response}")
     return llm_response
+
lambda/aws-qa-appsync-opensearch/question_answering/src/llms/__init__.py (+1 −1 line)

@@ -1 +1 @@
-from .text_generation_llm_selector import get_llm, get_max_tokens, get_embeddings_llm
+from .text_generation_llm_selector import get_llm, get_max_tokens, get_embeddings_llm,get_bedrock_fm

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py (+64 −8 lines)

@@ -10,6 +10,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
+from aiohttp import ClientError
 from langchain.llms.bedrock import Bedrock
 from langchain_community.embeddings import BedrockEmbeddings
 import os

@@ -26,7 +27,8 @@
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")


-def get_llm(callbacks=None):
+
+def get_llm(callbacks=None,model_id="anthropic.claude-v2:1"):
     bedrock = boto3.client('bedrock-runtime')

     params = {

@@ -39,7 +41,7 @@ def get_llm(callbacks=None):

     kwargs = {
         "client": bedrock,
-        "model_id": "anthropic.claude-v2:1",
+        "model_id": model_id,
         "model_kwargs": params,
         "streaming": False
     }
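
get_llm now takes an optional model_id and falls back to anthropic.claude-v2:1 when it is omitted. A minimal usage sketch, assuming the function is imported through the llms package __init__ shown above and that Bedrock access is configured in the environment:

# Hedged sketch: pin the text-generation model explicitly instead of relying on the default.
from llms import get_llm

llm = get_llm(model_id="anthropic.claude-v2:1")  # same value as the default, shown for clarity
answer = llm.predict("Summarize the retrieved passages in two sentences.")  # prompt is illustrative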
@@ -50,10 +52,64 @@ def get_llm(callbacks=None):

     return Bedrock(**kwargs)

-def get_embeddings_llm():
+def get_embeddings_llm(model_id,modality):
     bedrock = boto3.client('bedrock-runtime')
-    return BedrockEmbeddings(client=bedrock, model_id="amazon.titan-embed-text-v1")
-
-def get_max_tokens():
-    return 200000
-
+    validation_status=validate_model_id_in_bedrock(model_id,modality)
+    if(validation_status['status']):
+        return BedrockEmbeddings(client=bedrock, model_id=model_id)
+    else:
+        return None
+
+
+def get_bedrock_fm(model_id,modality):
+    bedrock_client = boto3.client('bedrock-runtime')
+    validation_status= validate_model_id_in_bedrock(model_id,modality)
+    logger.info(f' validation_status :: {validation_status}')
+    if(validation_status['status']):
+        return bedrock_client
+    else:
+        logger.error(f"reason ::{validation_status['message']} ")
+        return None
+
+
+
+#TODO -add max token based on model id
+def get_max_tokens(model_id):
+    match model_id:
+        case "anthropic.claude-v2:1":
+            return 200000
+        case "anthropic.claude-3-sonnet-20240229-v1:0":
+            return 200000
+        case _:
+            return 4096
+
+
+def validate_model_id_in_bedrock(model_id,modality):
+    """
+    Validate if the listed model id is supported with given modality
+    in bedrock or not.
+    """
+    response={
+        "status":False,
+        "message":f"model {model_id} is not supported in bedrock."
+    }
+    try:
+        bedrock_client = boto3.client(service_name="bedrock")
+        bedrock_model_list = bedrock_client.list_foundation_models()
+        models = bedrock_model_list["modelSummaries"]
+        for model in models:
+            if model["modelId"].lower() == model_id.lower():
+                response["message"]=f"model {model_id} does not support modality {modality} "
+                for inputModality in model["inputModalities"]:
+                    if inputModality.lower() == modality.lower():
+                        response["message"]=f"model {model_id} with modality {modality} is supported with bedrock "
+                        response["status"] = True
+
+        logger.info(f' response :: {response}')
+        return response
+    except ClientError as ce:
+        message=f"error occured while validating model in bedrock {ce}"
+        logger.error(message)
+        response["status"] = False
+        response["message"] = message
+        return response
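
Taken together, the new helpers gate every model reference on a ListFoundationModels lookup before handing back a runtime client or embeddings object. A minimal sketch of how a caller might combine them; the import paths follow the package layout shown above, the model ids and modality strings are illustrative, and AWS credentials with Bedrock access are assumed:

# Hedged sketch: validate a multimodal model before using it for image Q&A.
from llms import get_bedrock_fm, get_embeddings_llm, get_max_tokens
from llms.text_generation_llm_selector import validate_model_id_in_bedrock

model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

status = validate_model_id_in_bedrock(model_id, "IMAGE")
print(status["message"])                                  # explains why validation passed or failed

if status["status"]:
    bedrock_runtime = get_bedrock_fm(model_id, "IMAGE")   # bedrock-runtime client, or None if unsupported
    context_window = get_max_tokens(model_id)             # 200000 for the Claude models listed above

# Embeddings are validated the same way; None signals an unsupported model/modality pair.
embeddings = get_embeddings_llm("amazon.titan-embed-text-v1", "TEXT")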
New file (+74 lines):

@@ -0,0 +1,74 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+from .helper import send_job_status, JobStatus
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult
+import base64
+from typing import Any, Dict, List, Union
+
+from aws_lambda_powertools import Logger, Tracer, Metrics
+
+logger = Logger(service="QUESTION_ANSWERING")
+tracer = Tracer(service="QUESTION_ANSWERING")
+metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
+
+class StreamingCallbackHandler(BaseCallbackHandler):
+    def __init__(self, status_variables: Dict):
+        self.status_variables = status_variables
+        logger.info("[StreamingCallbackHandler::__init__] Initialized")
+
+    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        """Runs when streaming is started."""
+        logger.info(f"[StreamingCallbackHandler::on_llm_start] Streaming started!")
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        try:
+            logger.info(f'[StreamingCallbackHandler::on_llm_new_token] token is: {token}')
+            llm_answer_bytes = token.encode("utf-8")
+            base64_bytes = base64.b64encode(llm_answer_bytes)
+            llm_answer_base64_string = base64_bytes.decode("utf-8")
+
+            self.status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
+            self.status_variables['answer'] = llm_answer_base64_string
+            send_job_status(self.status_variables)
+
+        except Exception as err:
+            logger.exception(err)
+            self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
+            error = JobStatus.ERROR_PREDICTION.get_message()
+            self.status_variables['answer'] = error.decode("utf-8")
+            send_job_status(self.status_variables)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+        logger.info(f"[StreamingCallbackHandler::on_llm_end] Streaming ended. Response: {response}")
+        try:
+            self.status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
+            self.status_variables['answer'] = ""
+            send_job_status(self.status_variables)
+
+        except Exception as err:
+            logger.exception(err)
+            self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
+            error = JobStatus.ERROR_PREDICTION.get_message()
+            self.status_variables['answer'] = error.decode("utf-8")
+            send_job_status(self.status_variables)
+
+    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
+        """Run when LLM errors."""
+        logger.exception(error)
+        self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
+        error = JobStatus.ERROR_PREDICTION.get_message()
+        self.status_variables['answer'] = error.decode("utf-8")
+        send_job_status(self.status_variables)
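
The handler base64-encodes each streamed token and pushes it to the job-status channel via send_job_status, so subscribers receive partial answers as they are generated. A minimal wiring sketch follows; the handler's import path is hypothetical since the new file's name is not captured above, the status dictionary is reduced to the two fields the handler writes, and send_job_status only does something useful inside the Lambda's AppSync environment:

# Hedged sketch: stream Bedrock tokens through the new callback handler.
import boto3
from langchain.llms.bedrock import Bedrock
# Hypothetical module path; adjust to wherever the handler file actually lives.
from streaming_callback_handler import StreamingCallbackHandler

status_variables = {"jobstatus": "", "answer": ""}  # the real dict also carries job/question identifiers
handler = StreamingCallbackHandler(status_variables)

llm = Bedrock(
    client=boto3.client("bedrock-runtime"),
    model_id="anthropic.claude-v2:1",
    streaming=True,            # on_llm_new_token only fires when streaming is enabled
    callbacks=[handler],
)
llm.predict("Describe the uploaded image in one paragraph.")  # prompt is illustrative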
