
Commit 16ac265

Author: Dinesh Sajwan (committed)

feat(imageqa): visual qa and doc qa fixes

1 parent b3d1e12 · commit 16ac265

File tree: 6 files changed (+554, -290 lines)


lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py (+78, -255)
@@ -11,19 +11,21 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
-from .helper import load_vector_db_opensearch, send_job_status, JobStatus
-from .s3inmemoryloader import S3FileLoaderInMemory
-from .StreamingCallbackHandler import StreamingCallbackHandler
-from langchain.prompts import PromptTemplate
-from llms import get_llm, get_max_tokens
-from langchain.chains import LLMChain
-from .image_qa import run_qa_agent_on_image_no_memory
 
 
 import boto3
 import os
 import base64
+
+from langchain.chains import LLMChain
+from llms import get_llm, get_max_tokens
 from typing import Any, Dict, List, Union
+from langchain.prompts import PromptTemplate
+from .s3inmemoryloader import S3FileLoaderInMemory
+from .StreamingCallbackHandler import StreamingCallbackHandler
+from .helper import load_vector_db_opensearch, send_job_status, JobStatus
+from .image_qa import run_qa_agent_on_image_no_memory,run_qa_agent_rag_on_image_no_memory
+from .doc_qa import run_qa_agent_rag_no_memory, run_qa_agent_from_single_document_no_memory
 
 from aws_lambda_powertools import Logger, Tracer, Metrics

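The import reshuffle above mirrors the split this commit makes around chain.py: the document QA agents (run_qa_agent_rag_no_memory, run_qa_agent_from_single_document_no_memory) now live in the .doc_qa module, and a RAG variant for images (run_qa_agent_rag_on_image_no_memory) joins the existing long-context handler in .image_qa. Only the router's diff is shown in this file; as a rough sketch, the interfaces chain.py now depends on look like this, with signatures inferred from the call sites in this diff rather than taken from the .doc_qa/.image_qa sources:

# Sketch only: handler interfaces inferred from the call sites in this diff.
# Each handler takes the resolver's arguments dict and returns the status_variables
# payload it last sent via send_job_status (answer base64-encoded on success).

def run_qa_agent_rag_no_memory(input_params: dict) -> dict: ...                   # .doc_qa
def run_qa_agent_from_single_document_no_memory(input_params: dict) -> dict: ...  # .doc_qa
def run_qa_agent_on_image_no_memory(input_params: dict) -> dict: ...              # .image_qa
def run_qa_agent_rag_on_image_no_memory(input_params: dict) -> dict: ...          # .image_qa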
@@ -38,258 +40,79 @@ def run_question_answering(arguments):
     response_generation_method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')
 
     try:
-        filename = arguments['filename']
-        image_url = arguments['presignedurl']
-
+        filename = arguments['filename']
     except:
-
         filename = ''
         arguments['filename'] = ''
-
-    if image_url: # if image presigned url is present then do a QA on image file
-        llm_response = run_qa_agent_on_image_no_memory(arguments)
-        return llm_response
 
-    if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
-        if response_generation_method == 'LONG_CONTEXT':
-            error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
-            logger.error(error)
+    image_url = arguments['presignedurl']
+
+    #set deafult modality to text
+    qa_model= arguments['qa_model']['modality']
+    modality=qa_model.get('modality','Text')
+
+    # Visual QA
+    if modality.lower()=='image':
+        logger.info("Running QA for Image modality")
+
+        # user didn't provide a image url as input, we use the RAG source against the entire knowledge base
+        if response_generation_method == 'LONG_CONTEXT':
+            if not image_url:
+                warning = 'Error: Image presigned url is required for LONG_CONTEXT approach, defaulting to RAG.'
+                logger.warning(warning)
+                llm_response = run_qa_agent_rag_on_image_no_memory(arguments)
+                return llm_response
+            else:
+                llm_response = run_qa_agent_on_image_no_memory(arguments)
+        if response_generation_method == 'RAG':
+            llm_response = run_qa_agent_rag_on_image_no_memory(arguments)
+        return llm_response
+    #pdf,txt QA
+    else:
+        logger.info("Running QA for text modality")
+        if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
+            if response_generation_method == 'LONG_CONTEXT':
+                error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
+                logger.error(error)
 
-        llm_response = run_qa_agent_rag_no_memory(arguments)
-        return llm_response
-
-    bucket_name = os.environ['INPUT_BUCKET']
-
-    # select the methodology based on the input size
-    document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()
-
-    if document_number_of_tokens is None:
-        logger.exception(
-            f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
-        error = JobStatus.ERROR_LOAD_INFO.get_message()
-        status_variables = {
-            'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
-            'answer': error.decode("utf-8"),
-            'jobid': arguments['jobid'],
-            'filename': filename,
-            'question': '',
-            'sources': ['']
-        }
-        send_job_status(status_variables)
-        return ''
-
-    model_max_tokens = get_max_tokens()
-    logger.info(
-        f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
-
-    if response_generation_method == 'RAG':
-        logger.info('Running qa agent with a RAG approach')
-        llm_response = run_qa_agent_rag_no_memory(arguments)
-    else:
-        # LONG CONTEXT
-        # why add 500 ? on top of the document content, we add the prompt. So we keep an extra 500 tokens of space just in case
-        if (document_number_of_tokens + 250) < model_max_tokens:
-            logger.info('Running qa agent with full document in context')
-            llm_response = run_qa_agent_from_single_document_no_memory(arguments)
-        else:
-            logger.info('Running qa agent with a RAG approach due to document size')
             llm_response = run_qa_agent_rag_no_memory(arguments)
-    return llm_response
-_doc_index = None
-_current_doc_index = None
-def run_qa_agent_rag_no_memory(input_params):
-    logger.info("starting qa agent with rag approach without memory :: {input_params}")
-
-    base64_bytes = input_params['question'].encode("utf-8")
-    embedding_model_id = input_params['embeddings_model']['modelId']
-    qa_model_id = input_params['qa_model']['modelId']
-    sample_string_bytes = base64.b64decode(base64_bytes)
-    decoded_question = sample_string_bytes.decode("utf-8")
-
-    logger.info(decoded_question)
-
-    status_variables = {
-        'jobstatus': JobStatus.WORKING.status,
-        'answer': JobStatus.WORKING.get_message(),
-        'jobid': input_params['jobid'],
-        'filename': input_params['filename'],
-        'question': input_params['question'],
-        'sources': ['']
-    }
-    send_job_status(status_variables)
-
-    # 1. Load index and question related content
-    global _doc_index
-    global _current_doc_index
-
-    if _doc_index is None:
-        logger.info("loading opensearch retriever")
-        doc_index = load_vector_db_opensearch(boto3.Session().region_name,
-                                              os.environ.get('OPENSEARCH_API_NAME'),
-                                              os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
-                                              os.environ.get('OPENSEARCH_INDEX'),
-                                              os.environ.get('OPENSEARCH_SECRET_ID'),
-                                              embedding_model_id)
-
-    else:
-        logger.info("_retriever already exists")
-
-    _current_doc_index = _doc_index
-
-    logger.info("Starting similarity search")
-    max_docs = input_params['retrieval']['max_docs']
-    output_file_name = input_params['filename']
-
-    source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
-    logger.info(source_documents)
-    # --------------------------------------------------------------------------
-    # If an output file is specified, filter the response to only include chunks
-    # related to that file. The source metadata is added when embeddings are
-    # created in the ingestion pipeline.
-    #
-    # TODO: Evaluate if this filter can be optimized by using the
-    # OpenSearchVectorSearch.max_marginal_relevance_search() method instead.
-    # See https://github.com/langchain-ai/langchain/issues/10524
-    # --------------------------------------------------------------------------
-    if output_file_name:
-        source_documents = [doc for doc in source_documents if doc.metadata['source'] == output_file_name]
-    logger.info(source_documents)
-    status_variables['sources'] = list(set(doc.metadata['source'] for doc in source_documents))
-
-    # 2 : load llm using the selector
-    streaming = input_params.get("streaming", False)
-    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
-
-    if (_qa_llm is None):
-        logger.info('llm is None, returning')
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
-        error = JobStatus.ERROR_LOAD_LLM.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return status_variables
-
-    # 3. Run it
-    template = """\n\nHuman: {context}
-    Answer from this text: {question}
-    \n\nAssistant:"""
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])
-
-    try:
-        tmp = chain.predict(context=source_documents, question=decoded_question)
-        answer = tmp.removeprefix(' ')
-
-        logger.info(f'answer is: {answer}')
-        llm_answer_bytes = answer.encode("utf-8")
-        base64_bytes = base64.b64encode(llm_answer_bytes)
-        llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-        status_variables['jobstatus'] = JobStatus.DONE.status
-        status_variables['answer'] = llm_answer_base64_string
-        send_job_status(status_variables) if not streaming else None
-
-    except Exception as err:
-        logger.exception(err)
-        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-        error = JobStatus.ERROR_PREDICTION.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-
-    return status_variables
-
-
-_file_content = None
-_current_file_name = None
-
-
-def run_qa_agent_from_single_document_no_memory(input_params):
-    logger.info("starting qa agent without memory single document")
-
-    base64_bytes = input_params['question'].encode("utf-8")
-
-    sample_string_bytes = base64.b64decode(base64_bytes)
-    decoded_question = sample_string_bytes.decode("utf-8")
-
-    logger.info(decoded_question)
-
-    status_variables = {
-        'jobstatus': JobStatus.WORKING.status,
-        'answer': JobStatus.WORKING.get_message(),
-        'jobid': input_params['jobid'],
-        'filename': input_params['filename'],
-        'question': input_params['question'],
-        'sources': ['']
-    }
-    send_job_status(status_variables)
-
-    # 1 : load the document
-    global _file_content
-    global _current_file_name
-
-    bucket_name = os.environ['INPUT_BUCKET']
-    filename = input_params['filename']
-    logger.info(f"Generating response to question for file {filename}")
-
-    if _current_file_name != filename:
-        logger.info('loading file content')
-        _file_content = S3FileLoaderInMemory(bucket_name, filename).load()
-    else:
-        if _file_content is None:
-            logger.info('loading cached file content')
+            return llm_response
+
+        bucket_name = os.environ['INPUT_BUCKET']
+
+        # select the methodology based on the input size
+        document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()
+
+        if document_number_of_tokens is None:
+            logger.exception(
+                f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
+            error = JobStatus.ERROR_LOAD_INFO.get_message()
+            status_variables = {
+                'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
+                'answer': error.decode("utf-8"),
+                'jobid': arguments['jobid'],
+                'filename': filename,
+                'question': '',
+                'sources': ['']
+            }
+            send_job_status(status_variables)
+            return ''
+
+        model_max_tokens = get_max_tokens()
+        logger.info(
+            f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
+
+        if response_generation_method == 'RAG':
+            logger.info('Running qa agent with a RAG approach')
+            llm_response = run_qa_agent_rag_no_memory(arguments)
         else:
-            logger.info('same file as before, but nothing cached')
-            _file_content = S3FileLoaderInMemory(bucket_name, filename).load()
-
-    _current_file_name = filename
-    status_variables['sources'] = [filename]
-    if _file_content is None:
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_DOC.status
-        error = JobStatus.ERROR_LOAD_DOC.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return
-
-    # 2 : run the question
-    streaming = input_params.get("streaming", False)
-    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
-
-    if (_qa_llm is None):
-        logger.info('llm is None, returning')
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
-        error = JobStatus.ERROR_LOAD_LLM.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return status_variables
-
-    # 3: run LLM
-    template = """\n\nHuman: {context}
-    Answer from this text: {question}
-    \n\nAssistant:"""
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])
-
-    try:
-        logger.info(f'file content is: {_file_content}')
-        logger.info(f'decoded_question is: {decoded_question}')
-        tmp = chain.predict(context=_file_content, question=decoded_question)
-        answer = tmp.removeprefix(' ')
-
-        logger.info(f'answer is: {answer}')
-        llm_answer_bytes = answer.encode("utf-8")
-        base64_bytes = base64.b64encode(llm_answer_bytes)
-        llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-        status_variables['jobstatus'] = JobStatus.DONE.status
-        status_variables['answer'] = llm_answer_base64_string
-        send_job_status(status_variables) if not streaming else None
-
-    except Exception as err:
-        logger.exception(err)
-        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-        error = JobStatus.ERROR_PREDICTION.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-
-    return status_variables
+            # LONG CONTEXT
+            # why add 500 ? on top of the document content, we add the prompt. So we keep an extra 500 tokens of space just in case
+            if (document_number_of_tokens + 250) < model_max_tokens:
+                logger.info('Running qa agent with full document in context')
+                llm_response = run_qa_agent_from_single_document_no_memory(arguments)
+            else:
+                logger.info('Running qa agent with a RAG approach due to document size')
+                llm_response = run_qa_agent_rag_no_memory(arguments)
+    return llm_response
+

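For orientation, the reworked run_question_answering now routes on qa_model.modality first (Image vs. Text) and then on responseGenerationMethod (LONG_CONTEXT vs. RAG). It falls back to RAG when the required input (presigned image URL or filename) is missing, and also when a text document plus the roughly 250-token prompt margin would not fit within the model's max tokens. Below is a hypothetical sketch of two resolver payloads exercising the new paths; the field names come from this diff, while the ids, keys, and model ids are invented for illustration.

import base64

# Illustrative payloads only. Field names match the ones read in this diff
# (jobid, question, filename, presignedurl, responseGenerationMethod, qa_model,
# embeddings_model, retrieval, streaming, verbose); all values are hypothetical.

image_long_context_args = {
    'jobid': 'job-0001',
    'question': base64.b64encode(b'What does this architecture diagram show?').decode('utf-8'),
    'filename': '',
    'presignedurl': 'https://example-bucket.s3.amazonaws.com/diagram.png',   # hypothetical URL
    'responseGenerationMethod': 'LONG_CONTEXT',   # image + URL -> run_qa_agent_on_image_no_memory
    'qa_model': {'modelId': 'example-multimodal-model', 'modality': 'Image'},
    'embeddings_model': {'modelId': 'example-embeddings-model'},
    'retrieval': {'max_docs': 3},
    'streaming': False,
    'verbose': False,
}

text_rag_args = {
    'jobid': 'job-0002',
    'question': base64.b64encode(b'Summarize the termination clause.').decode('utf-8'),
    'filename': 'contract.pdf',                   # hypothetical key in INPUT_BUCKET
    'presignedurl': '',
    'responseGenerationMethod': 'RAG',            # text + RAG -> run_qa_agent_rag_no_memory
    'qa_model': {'modelId': 'example-text-model', 'modality': 'Text'},
    'embeddings_model': {'modelId': 'example-embeddings-model'},
    'retrieval': {'max_docs': 5},
    'streaming': False,
    'verbose': False,
}

Two things are worth noting about the added routing code. First, image_url = arguments['presignedurl'] is now read outside the try/except, so callers should always include the presignedurl key (possibly empty) or the resolver will raise a KeyError. Second, the modality lookup qa_model = arguments['qa_model']['modality'] followed by qa_model.get('modality', 'Text') only works if the 'modality' key itself holds a mapping; with a plain string value like the sketch above it would fail with AttributeError, and the likely intent is arguments['qa_model'].get('modality', 'Text').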