@@ -11,19 +11,21 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
-from .helper import load_vector_db_opensearch, send_job_status, JobStatus
-from .s3inmemoryloader import S3FileLoaderInMemory
-from .StreamingCallbackHandler import StreamingCallbackHandler
-from langchain.prompts import PromptTemplate
-from llms import get_llm, get_max_tokens
-from langchain.chains import LLMChain
-from .image_qa import run_qa_agent_on_image_no_memory
 
 
 import boto3
 import os
 import base64
+
+from langchain.chains import LLMChain
+from llms import get_llm, get_max_tokens
 from typing import Any, Dict, List, Union
+from langchain.prompts import PromptTemplate
+from .s3inmemoryloader import S3FileLoaderInMemory
+from .StreamingCallbackHandler import StreamingCallbackHandler
+from .helper import load_vector_db_opensearch, send_job_status, JobStatus
+from .image_qa import run_qa_agent_on_image_no_memory, run_qa_agent_rag_on_image_no_memory
+from .doc_qa import run_qa_agent_rag_no_memory, run_qa_agent_from_single_document_no_memory
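+# NOTE: the doc QA agents are now imported from .doc_qa; their definitions are removed from this file below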
 
 from aws_lambda_powertools import Logger, Tracer, Metrics
 
@@ -38,258 +40,79 @@ def run_question_answering(arguments):
     response_generation_method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')
 
     try:
-        filename = arguments['filename']
-        image_url = arguments['presignedurl']
-
+        filename = arguments['filename']
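+        # 'filename' may be missing from the payload; the except branch below defaults it to ''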
     except:
-
         filename = ''
         arguments['filename'] = ''
-
-    if image_url: # if image presigned url is present then do a QA on image file
-        llm_response = run_qa_agent_on_image_no_memory(arguments)
-        return llm_response
 
-    if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
-        if response_generation_method == 'LONG_CONTEXT':
-            error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
-            logger.error(error)
+    image_url = arguments['presignedurl']
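+    # assumed to always be present in the payload (possibly empty); a missing key would raise a KeyError here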
+
+    # set default modality to text
+    qa_model = arguments['qa_model']
+    modality = qa_model.get('modality', 'Text')
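+    # e.g. a payload of {'qa_model': {'modality': 'Image', 'modelId': '...'}} (assumed shape) takes the image path;
+    # any other modality value falls through to the text path below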
+
+    # Visual QA
+    if modality.lower() == 'image':
+        logger.info("Running QA for Image modality")
+
+        # user didn't provide an image url as input, we use the RAG source against the entire knowledge base
+        if response_generation_method == 'LONG_CONTEXT':
+            if not image_url:
+                warning = 'Warning: Image presigned url is required for LONG_CONTEXT approach, defaulting to RAG.'
+                logger.warning(warning)
+                llm_response = run_qa_agent_rag_on_image_no_memory(arguments)
+                return llm_response
+            else:
+                llm_response = run_qa_agent_on_image_no_memory(arguments)
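+                # falls through to the shared return below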
+        if response_generation_method == 'RAG':
+            llm_response = run_qa_agent_rag_on_image_no_memory(arguments)
+        return llm_response
+    # pdf, txt QA
+    else:
+        logger.info("Running QA for text modality")
+        if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
+            if response_generation_method == 'LONG_CONTEXT':
+                error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
+                logger.error(error)
 
-        llm_response = run_qa_agent_rag_no_memory(arguments)
-        return llm_response
-
-    bucket_name = os.environ['INPUT_BUCKET']
-
-    # select the methodology based on the input size
-    document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()
-
-    if document_number_of_tokens is None:
-        logger.exception(
-            f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
-        error = JobStatus.ERROR_LOAD_INFO.get_message()
-        status_variables = {
-            'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
-            'answer': error.decode("utf-8"),
-            'jobid': arguments['jobid'],
-            'filename': filename,
-            'question': '',
-            'sources': ['']
-        }
-        send_job_status(status_variables)
-        return ''
-
-    model_max_tokens = get_max_tokens()
-    logger.info(
-        f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
-
-    if response_generation_method == 'RAG':
-        logger.info('Running qa agent with a RAG approach')
-        llm_response = run_qa_agent_rag_no_memory(arguments)
-    else:
-        # LONG CONTEXT
-        # why add 500 ? on top of the document content, we add the prompt. So we keep an extra 500 tokens of space just in case
-        if (document_number_of_tokens + 250) < model_max_tokens:
-            logger.info('Running qa agent with full document in context')
-            llm_response = run_qa_agent_from_single_document_no_memory(arguments)
-        else:
-            logger.info('Running qa agent with a RAG approach due to document size')
             llm_response = run_qa_agent_rag_no_memory(arguments)
-    return llm_response
-_doc_index = None
-_current_doc_index = None
-def run_qa_agent_rag_no_memory(input_params):
-    logger.info("starting qa agent with rag approach without memory :: {input_params}")
-
-    base64_bytes = input_params['question'].encode("utf-8")
-    embedding_model_id = input_params['embeddings_model']['modelId']
-    qa_model_id = input_params['qa_model']['modelId']
-    sample_string_bytes = base64.b64decode(base64_bytes)
-    decoded_question = sample_string_bytes.decode("utf-8")
-
-    logger.info(decoded_question)
-
-    status_variables = {
-        'jobstatus': JobStatus.WORKING.status,
-        'answer': JobStatus.WORKING.get_message(),
-        'jobid': input_params['jobid'],
-        'filename': input_params['filename'],
-        'question': input_params['question'],
-        'sources': ['']
-    }
-    send_job_status(status_variables)
-
-    # 1. Load index and question related content
-    global _doc_index
-    global _current_doc_index
-
-    if _doc_index is None:
-        logger.info("loading opensearch retriever")
-        doc_index = load_vector_db_opensearch(boto3.Session().region_name,
-                                              os.environ.get('OPENSEARCH_API_NAME'),
-                                              os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
-                                              os.environ.get('OPENSEARCH_INDEX'),
-                                              os.environ.get('OPENSEARCH_SECRET_ID'),
-                                              embedding_model_id)
-
-    else:
-        logger.info("_retriever already exists")
-
-    _current_doc_index = _doc_index
-
-    logger.info("Starting similarity search")
-    max_docs = input_params['retrieval']['max_docs']
-    output_file_name = input_params['filename']
-
-    source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
-    logger.info(source_documents)
-    # --------------------------------------------------------------------------
-    # If an output file is specified, filter the response to only include chunks
-    # related to that file. The source metadata is added when embeddings are
-    # created in the ingestion pipeline.
-    #
-    # TODO: Evaluate if this filter can be optimized by using the
-    # OpenSearchVectorSearch.max_marginal_relevance_search() method instead.
-    # See https://github.com/langchain-ai/langchain/issues/10524
-    # --------------------------------------------------------------------------
-    if output_file_name:
-        source_documents = [doc for doc in source_documents if doc.metadata['source'] == output_file_name]
-    logger.info(source_documents)
-    status_variables['sources'] = list(set(doc.metadata['source'] for doc in source_documents))
-
-    # 2 : load llm using the selector
-    streaming = input_params.get("streaming", False)
-    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
-
-    if (_qa_llm is None):
-        logger.info('llm is None, returning')
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
-        error = JobStatus.ERROR_LOAD_LLM.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return status_variables
-
-    # 3. Run it
-    template = """\n\nHuman: {context}
-    Answer from this text: {question}
-    \n\nAssistant:"""
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])
-
-    try:
-        tmp = chain.predict(context=source_documents, question=decoded_question)
-        answer = tmp.removeprefix(' ')
-
-        logger.info(f'answer is: {answer}')
-        llm_answer_bytes = answer.encode("utf-8")
-        base64_bytes = base64.b64encode(llm_answer_bytes)
-        llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-        status_variables['jobstatus'] = JobStatus.DONE.status
-        status_variables['answer'] = llm_answer_base64_string
-        send_job_status(status_variables) if not streaming else None
-
-    except Exception as err:
-        logger.exception(err)
-        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-        error = JobStatus.ERROR_PREDICTION.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-
-    return status_variables
-
-
-_file_content = None
-_current_file_name = None
-
-
-def run_qa_agent_from_single_document_no_memory(input_params):
-    logger.info("starting qa agent without memory single document")
-
-    base64_bytes = input_params['question'].encode("utf-8")
-
-    sample_string_bytes = base64.b64decode(base64_bytes)
-    decoded_question = sample_string_bytes.decode("utf-8")
-
-    logger.info(decoded_question)
-
-    status_variables = {
-        'jobstatus': JobStatus.WORKING.status,
-        'answer': JobStatus.WORKING.get_message(),
-        'jobid': input_params['jobid'],
-        'filename': input_params['filename'],
-        'question': input_params['question'],
-        'sources': ['']
-    }
-    send_job_status(status_variables)
-
-    # 1 : load the document
-    global _file_content
-    global _current_file_name
-
-    bucket_name = os.environ['INPUT_BUCKET']
-    filename = input_params['filename']
-    logger.info(f"Generating response to question for file {filename}")
-
-    if _current_file_name != filename:
-        logger.info('loading file content')
-        _file_content = S3FileLoaderInMemory(bucket_name, filename).load()
-    else:
-        if _file_content is None:
-            logger.info('loading cached file content')
+            return llm_response
+
+        bucket_name = os.environ['INPUT_BUCKET']
+
+        # select the methodology based on the input size
+        document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()
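+        # None means the token count could not be computed (e.g. unreadable file); handled just below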
+
+        if document_number_of_tokens is None:
+            logger.exception(
+                f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
+            error = JobStatus.ERROR_LOAD_INFO.get_message()
+            status_variables = {
+                'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
+                'answer': error.decode("utf-8"),
+                'jobid': arguments['jobid'],
+                'filename': filename,
+                'question': '',
+                'sources': ['']
+            }
+            send_job_status(status_variables)
+            return ''
+
+        model_max_tokens = get_max_tokens()
+        logger.info(
+            f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
+
+        if response_generation_method == 'RAG':
+            logger.info('Running qa agent with a RAG approach')
+            llm_response = run_qa_agent_rag_no_memory(arguments)
         else:
-            logger.info('same file as before, but nothing cached')
-            _file_content = S3FileLoaderInMemory(bucket_name, filename).load()
-
-    _current_file_name = filename
-    status_variables['sources'] = [filename]
-    if _file_content is None:
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_DOC.status
-        error = JobStatus.ERROR_LOAD_DOC.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return
-
-    # 2 : run the question
-    streaming = input_params.get("streaming", False)
-    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None
-    _qa_llm = get_llm(callback_manager)
-
-    if (_qa_llm is None):
-        logger.info('llm is None, returning')
-        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
-        error = JobStatus.ERROR_LOAD_LLM.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-        return status_variables
-
-    # 3: run LLM
-    template = """\n\nHuman: {context}
-    Answer from this text: {question}
-    \n\nAssistant:"""
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])
-
-    try:
-        logger.info(f'file content is: {_file_content}')
-        logger.info(f'decoded_question is: {decoded_question}')
-        tmp = chain.predict(context=_file_content, question=decoded_question)
-        answer = tmp.removeprefix(' ')
-
-        logger.info(f'answer is: {answer}')
-        llm_answer_bytes = answer.encode("utf-8")
-        base64_bytes = base64.b64encode(llm_answer_bytes)
-        llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-        status_variables['jobstatus'] = JobStatus.DONE.status
-        status_variables['answer'] = llm_answer_base64_string
-        send_job_status(status_variables) if not streaming else None
-
-    except Exception as err:
-        logger.exception(err)
-        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-        error = JobStatus.ERROR_PREDICTION.get_message()
-        status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
-
-    return status_variables
+            # LONG CONTEXT
+            # why add 250 ? on top of the document content, we add the prompt. So we keep an extra 250 tokens of space just in case
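+            # e.g. a 4096-token model and a 3900-token document: 3900 + 250 > 4096, so we fall back to RAG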
+            if (document_number_of_tokens + 250) < model_max_tokens:
+                logger.info('Running qa agent with full document in context')
+                llm_response = run_qa_agent_from_single_document_no_memory(arguments)
+            else:
+                logger.info('Running qa agent with a RAG approach due to document size')
+                llm_response = run_qa_agent_rag_no_memory(arguments)
+        return llm_response
+