
Commit 6cbb5e0

Author: Dinesh Sajwan (committed)
feat(visualqa): implemented review comments
1 parent bcc1ccd commit 6cbb5e0

10 files changed: +298 −164 lines

lambda/aws-qa-appsync-opensearch/question_answering/src/lambda.py (+30 −1)

@@ -23,7 +23,7 @@
 tracer = Tracer(service="QUESTION_ANSWERING")
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")

-@logger.inject_lambda_context(log_event=True)
+#@logger.inject_lambda_context(log_event=True)
 @tracer.capture_lambda_handler
 @metrics.log_metrics(capture_cold_start_metric=True)
 def handler(event, context: LambdaContext) -> dict:
@@ -41,3 +41,32 @@ def handler(event, context: LambdaContext) -> dict:

     print(f"llm_response is {llm_response}")
     return llm_response
+
+input = {"detail": {
+    "jobid": "111",
+    "jobstatus": "",
+    "qa_model": {
+        "provider": "Bedrock",
+        "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
+        "streaming": True,
+        "modality": "Image"
+    },
+    "embeddings_model": {
+        "provider": "Bedrock",
+        "modelId": "amazon.titan-embed-image-v1",
+        "streaming": True
+    },
+    "retrieval": {
+        "max_docs": 1,
+        "index_name": "",
+        "filter_filename": ""
+    },
+    "filename": "two_cats.jpeg",
+    "presignedurl": "",
+    "question": "d2hhdCBhcmUgdGhlIGNhdHMgZG9pbmc/",
+    "verbose": False,
+    "responseGenerationMethod": "LONG_CONTEXT"
+    }
+}
+
+handler(input, None)
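
For reference, the handler expects the `question` field base64-encoded, which is why the test payload above carries an encoded string rather than plain text. A minimal round-trip sketch using only the standard library:

import base64

# Decode the question used in the payload above.
encoded = "d2hhdCBhcmUgdGhlIGNhdHMgZG9pbmc/"
print(base64.b64decode(encoded).decode("utf-8"))  # -> what are the cats doing?

# Encoding a new question for the payload works the same way in reverse.
print(base64.b64encode("what are the cats doing?".encode("utf-8")).decode("utf-8"))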
lambda/aws-qa-appsync-opensearch/question_answering/src/llms/__init__.py (+1 −1)

@@ -1 +1 @@
-from .text_generation_llm_selector import get_llm, get_max_tokens, get_embeddings_llm
+from .text_generation_llm_selector import get_llm, get_max_tokens, get_embeddings_llm, get_bedrock_fm

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py (+57 −6)

@@ -10,6 +10,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
+from aiohttp import ClientError
 from langchain.llms.bedrock import Bedrock
 from langchain_community.embeddings import BedrockEmbeddings
 import os
@@ -26,7 +27,7 @@
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")


-def get_llm(callbacks=None):
+def get_llm(callbacks=None, model_id="anthropic.claude-v2:1"):
     bedrock = boto3.client('bedrock-runtime')

     params = {
@@ -39,7 +40,7 @@ def get_llm(callbacks=None):

     kwargs = {
         "client": bedrock,
-        "model_id": "anthropic.claude-v2:1",
+        "model_id": model_id,
         "model_kwargs": params,
         "streaming": False
     }
@@ -50,10 +51,60 @@ def get_llm(callbacks=None):

     return Bedrock(**kwargs)

-def get_embeddings_llm(model_id):
+def get_embeddings_llm(model_id, modality):
     bedrock = boto3.client('bedrock-runtime')
-    return BedrockEmbeddings(client=bedrock, model_id=model_id)
-
+    validation_status = validate_model_id_in_bedrock(model_id, modality)
+    if validation_status['status']:
+        return BedrockEmbeddings(client=bedrock, model_id=model_id)
+    else:
+        return None
+
+
+def get_bedrock_fm(model_id, modality):
+    bedrock_client = boto3.client('bedrock-runtime')
+    validation_status = validate_model_id_in_bedrock(model_id, modality)
+    print(f' validation_status :: {validation_status}')
+    if validation_status['status']:
+        return bedrock_client
+    else:
+        logger.error(f"reason ::{validation_status['message']} ")
+        return None
+
+
+# TODO: add max tokens based on model id
 def get_max_tokens():
     return 200000
-
+
+
+def validate_model_id_in_bedrock(model_id, modality):
+    """
+    Validate whether the given model id is supported in Bedrock
+    with the given modality.
+    """
+    response = {
+        "status": False,
+        "message": f"model {model_id} is not supported in bedrock."
+    }
+    try:
+        bedrock_client = boto3.client(service_name="bedrock")
+        bedrock_model_list = bedrock_client.list_foundation_models()
+        models = bedrock_model_list["modelSummaries"]
+        for model in models:
+            if model["modelId"].lower() == model_id.lower():
+                response["message"] = f"model {model_id} does not support modality {modality}"
+                print(f' modality :: {model["inputModalities"]}')
+                for inputModality in model["inputModalities"]:
+                    if inputModality.lower() == modality.lower():
+                        print(' modality supported')
+                        response["message"] = f"model {model_id} with modality {modality} is supported with bedrock"
+                        response["status"] = True
+
+        print(f' response :: {response}')
+        return response
+    except ClientError as ce:
+        message = f"error occurred while validating model in bedrock {ce}"
+        logger.error(message)
+        response["status"] = False
+        response["message"] = message
+        print(f' response :: {response}')
+        return response
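
The validator added above keys off the `bedrock` control-plane client rather than `bedrock-runtime`. A minimal sketch of the lookup it performs, using the embeddings model id from this commit's test payload (Bedrock reports modalities in upper case, e.g. TEXT/IMAGE, which is why the function lower-cases both sides before comparing):

import boto3

# List foundation models and inspect the input modalities Bedrock reports.
bedrock = boto3.client(service_name="bedrock")
for model in bedrock.list_foundation_models()["modelSummaries"]:
    if model["modelId"] == "amazon.titan-embed-image-v1":
        print(model["inputModalities"])  # e.g. ['TEXT', 'IMAGE']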

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/StreamingCallbackHandler.py (+12)

@@ -1,3 +1,15 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
 from .helper import send_job_status, JobStatus
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py (+3 −7)

@@ -18,12 +18,10 @@
 import base64

 from langchain.chains import LLMChain
-from llms import get_llm, get_max_tokens
+from llms import get_max_tokens
 from typing import Any, Dict, List, Union
-from langchain.prompts import PromptTemplate
 from .s3inmemoryloader import S3FileLoaderInMemory
-from .StreamingCallbackHandler import StreamingCallbackHandler
-from .helper import load_vector_db_opensearch, send_job_status, JobStatus
+from .helper import send_job_status, JobStatus
 from .image_qa import run_qa_agent_on_image_no_memory, run_qa_agent_rag_on_image_no_memory
 from .doc_qa import run_qa_agent_rag_no_memory, run_qa_agent_from_single_document_no_memory

@@ -45,8 +43,6 @@ def run_question_answering(arguments):
     filename = ''
     arguments['filename'] = ''

-    image_url = arguments['presignedurl']
-
     # set default modality to text
     qa_model = arguments['qa_model']
     modality = qa_model.get('modality', 'Text')
@@ -57,7 +53,7 @@ def run_question_answering(arguments):

     # user didn't provide an image url as input, we use the RAG source against the entire knowledge base
     if response_generation_method == 'LONG_CONTEXT':
-        if not image_url:
+        if not filename:
             warning = 'Error: Image presigned url is required for LONG_CONTEXT approach, defaulting to RAG.'
             logger.warning(warning)
             llm_response = run_qa_agent_rag_on_image_no_memory(arguments)

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/doc_qa.py (+10 −5)

@@ -18,7 +18,7 @@
 import base64

 from langchain.chains import LLMChain
-from llms import get_llm, get_max_tokens
+from llms import get_llm
 from typing import Any, Dict, List, Union
 from langchain.prompts import PromptTemplate
 from .s3inmemoryloader import S3FileLoaderInMemory
@@ -41,7 +41,10 @@ def run_qa_agent_rag_no_memory(input_params):
     logger.info("starting qa agent with rag approach without memory :: {input_params}")

     base64_bytes = input_params['question'].encode("utf-8")
-    embedding_model_id = input_params['embeddings_model']['modelId']
+    embedding_model = input_params['embeddings_model']
+    embedding_model_id = embedding_model['modelId']
+    modality = embedding_model.get("modality", "Text")
+
     qa_model_id = input_params['qa_model']['modelId']
     sample_string_bytes = base64.b64decode(base64_bytes)
     decoded_question = sample_string_bytes.decode("utf-8")
@@ -69,7 +72,8 @@ def run_qa_agent_rag_no_memory(input_params):
             os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
             os.environ.get('OPENSEARCH_INDEX'),
             os.environ.get('OPENSEARCH_SECRET_ID'),
-            embedding_model_id)
+            embedding_model_id,
+            modality)

     else:
         logger.info("_retriever already exists")
@@ -106,7 +110,7 @@ def run_qa_agent_rag_no_memory(input_params):
     # 2 : load llm using the selector
     streaming = input_params.get("streaming", False)
     callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
+    _qa_llm = get_llm(callback_manager, qa_model_id)

     if (_qa_llm is None):
         logger.error('llm is None, returning')
@@ -154,6 +158,7 @@ def run_qa_agent_from_single_document_no_memory(input_params):
     logger.info("starting qa agent without memory single document")

     base64_bytes = input_params['question'].encode("utf-8")
+    qa_model_id = input_params['qa_model']['modelId']

     sample_string_bytes = base64.b64decode(base64_bytes)
     decoded_question = sample_string_bytes.decode("utf-8")
@@ -200,7 +205,7 @@ def run_qa_agent_from_single_document_no_memory(input_params):
     # 2 : run the question
     streaming = input_params.get("streaming", False)
     callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
+    _qa_llm = get_llm(callback_manager, qa_model_id)

     if (_qa_llm is None):
         logger.info('llm is None, returning')
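
Since `model_id` was added to `get_llm` with the old hardcoded value as its default, existing call sites remain valid while the qa model id from the request can now be passed through. A minimal sketch of both forms (callbacks omitted for brevity):

from llms import get_llm

# Old form: falls back to the default anthropic.claude-v2:1.
llm_default = get_llm(None)

# New form: the model id from the request payload is passed through.
llm_sonnet = get_llm(None, "anthropic.claude-3-sonnet-20240229-v1:0")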

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/helper.py (+35 −13)

@@ -10,8 +10,10 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
+from pathlib import Path
+from aiohttp import ClientError
 from langchain_community.vectorstores import OpenSearchVectorSearch
-from opensearchpy import RequestsHttpConnection
+#from opensearchpy import RequestsHttpConnection
 from llms import get_embeddings_llm
 import requests
 import os
@@ -57,7 +59,7 @@ class JobStatus(Enum):
         base64.b64encode("Sorry, it seems an issue happened on my end, and I'm not able to answer your question. Please contact an administrator to understand why !".encode('utf-8'))
     )
     ERROR_SEMANTIC_SEARCH = (
-        'Exception during simialirty search, Please verify model for the selected modality',
+        'Exception during similarity search, Please verify model for the selected modality',
         base64.b64encode("Sorry, it seems an issue happened on my end, and I'm not able to answer your question. Please contact an administrator to understand why !".encode('utf-8'))
     )

@@ -97,7 +99,8 @@ def load_vector_db_opensearch(region: str,
                               opensearch_domain_endpoint: str,
                               opensearch_index: str,
                               secret_id: str,
-                              model_id: str) -> OpenSearchVectorSearch:
+                              model_id: str,
+                              modality: str) -> OpenSearchVectorSearch:
     print(f"load_vector_db_opensearch, region={region}, "
           f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}")

@@ -114,16 +117,17 @@ def load_vector_db_opensearch(region: str,
         opensearch_api_name,
         session_token=credentials.token,
     )
-    embedding_function = get_embeddings_llm(model_id)
+    embedding_function = get_embeddings_llm(model_id, modality)

     opensearch_url = opensearch_domain_endpoint if opensearch_domain_endpoint.startswith("https://") else f"https://{opensearch_domain_endpoint}"
-    vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
-                                       embedding_function=embedding_function,
-                                       opensearch_url=opensearch_url,
-                                       http_auth=http_auth,
-                                       use_ssl = True,
-                                       verify_certs = True,
-                                       connection_class = RequestsHttpConnection)
+    # vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
+    #                                    embedding_function=embedding_function,
+    #                                    opensearch_url=opensearch_url,
+    #                                    http_auth=http_auth,
+    #                                    use_ssl = True,
+    #                                    verify_certs = True,
+    #                                    connection_class = RequestsHttpConnection)
+    vector_db = ""
     print(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
     return vector_db

@@ -155,7 +159,8 @@ def send_job_status(variables):

     print(request)

-    GRAPHQL_URL = os.environ['GRAPHQL_URL']
+    #GRAPHQL_URL = os.environ['GRAPHQL_URL']
+    GRAPHQL_URL = "https://j2uzmlvujbhbzoduvpctgkpu2e.appsync-api.us-east-1.amazonaws.com/graphql"
     HEADERS = {
         "Content-Type": "application/json",
     }
@@ -180,4 +185,21 @@ def get_presigned_url(bucket,key) -> str:
         return url
     except Exception as exception:
         print(f"Reason: {exception}")
-        return ""
+        return ""
+
+def download_file(bucket, key) -> str:
+    try:
+        file_path = "/tmp/" + os.path.basename(key)
+        s3.download_file(bucket, key, file_path)
+        print(f"file downloaded {file_path}")
+        return file_path
+    except ClientError as client_err:
+        print(f"Couldn't download file {client_err.response['Error']['Message']}")
+
+    except Exception as exp:
+        print(f"Couldn't download file : {exp}")
+
+def encode_image_to_base64(image_file_path, image_file) -> str:
+    with open(image_file_path, "rb") as image_file:
+        b64_image = base64.b64encode(image_file.read()).decode('utf8')
+    return b64_image
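
A minimal sketch of how the two new helpers compose; the bucket, key, and import path are hypothetical placeholders, and the result of `download_file` is checked because it returns None when the S3 call fails:

from qa_agent.helper import download_file, encode_image_to_base64  # assumed import path

# Hypothetical bucket/key: fetch the image to /tmp, then base64-encode it for the model.
file_path = download_file("my-image-bucket", "uploads/two_cats.jpeg")
if file_path:
    b64_image = encode_image_to_base64(file_path, "two_cats.jpeg")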
