@@ -104,12 +104,13 @@ def run_qa_agent_rag_on_image_no_memory(input_params):
         qa_model_id = qa_model['modelId']
         embedding_model_id = input_params['embeddings_model']['modelId']
     else:
-        logger.error('Either qa_model_id or embedding_model_id is not present, cannot answer question using RAG, returning...')
+        logger.error('RAG based QA needs both qa_model_id and embeddings_model_id, either one or both are missing, cannot answer question using RAG, returning...')
         status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
         status_variables['answer'] = JobStatus.ERROR_LOAD_LLM.status
         send_job_status(status_variables)
         return
 
+
 
     global _doc_index
     global _current_doc_index
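The reworded message clarifies that RAG based QA requires both the QA model and the embeddings model in the request. For illustration, a hypothetical well-formed input_params that satisfies this guard (the key names are taken from the diff, the values are made up):

    input_params = {
        "qa_model": {"modelId": "anthropic.claude-3-sonnet-20240229-v1:0"},
        "embeddings_model": {"modelId": "amazon.titan-embed-image-v1", "modality": "Image"},
        # the question arrives base64-encoded; process_visual_qa below decodes it
        "question": "V2hhdCBpcyBzaG93biBpbiB0aGlzIGltYWdlPw==",
    }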
@@ -127,6 +128,7 @@ def get_image_from_semantic_search_in_os(input_params,status_variables):
     embeddings_model = input_params['embeddings_model']
     embedding_model_id = embeddings_model['modelId']
     modality = embeddings_model.get("modality", "Text")
+
     if _doc_index is None:
         logger.info("loading opensearch retriever")
         doc_index = load_vector_db_opensearch(boto3.Session().region_name,
@@ -169,16 +171,21 @@ def process_visual_qa(input_params,status_variables,filename):
 
     qa_model = input_params['qa_model']
     qa_modelId = qa_model['modelId']
+
+    # default model provider is Bedrock and default modality is Text
     modality = qa_model.get("modality", "Text")
+    model_provider = qa_model.get("provider", "Bedrock")
+    logger.info(f"model provider is {model_provider} and modality is {modality}")
+
     base64_bytes = input_params['question'].encode("utf-8")
     sample_string_bytes = base64.b64decode(base64_bytes)
     decoded_question = sample_string_bytes.decode("utf-8")
-    model_provider = input_params['qa_model']['provider']
-
+
     if model_provider == 'Sagemaker Endpoint':
         _qa_llm = MultiModal.sagemakerendpoint_llm(qa_modelId)
         if (_qa_llm is not None):
             status_variables['answer'] = generate_vision_answer_sagemaker(_qa_llm, input_params, decoded_question, filename, status_variables)
+            status_variables['jobstatus'] = JobStatus.DONE.status
         else:
             logger.error('Invalid Model, cannot load LLM, returning..')
             status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
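The new lookups use dict.get with defaults, so a request that omits provider or modality now falls back to Bedrock and Text instead of raising a KeyError the way the removed input_params['qa_model']['provider'] access could. A minimal, self-contained sketch of that defaulting plus the base64 question decoding (values are illustrative only):

    import base64

    qa_model = {"modelId": "anthropic.claude-3-sonnet-20240229-v1:0"}  # no provider/modality keys
    modality = qa_model.get("modality", "Text")            # -> "Text"
    model_provider = qa_model.get("provider", "Bedrock")   # -> "Bedrock" instead of a KeyError

    encoded = base64.b64encode("What is shown in this image?".encode("utf-8"))
    decoded_question = base64.b64decode(encoded).decode("utf-8")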
@@ -191,6 +198,9 @@ def process_visual_qa(input_params,status_variables,filename):
             local_file_path = download_file(bucket_name, filename)
             base64_images = encode_image_to_base64(local_file_path, filename)
             status_variables['answer'] = generate_vision_answer_bedrock(_qa_llm, base64_images, qa_modelId, decoded_question)
+            status_variables['jobstatus'] = JobStatus.DONE.status
+            streaming = input_params.get("streaming", False)
+
         else:
             logger.error('Invalid Model, cannot load LLM, returning..')
             status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
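With jobstatus now set to DONE in the caller, the streaming flag read here presumably decides whether the final status is pushed right away or left to the streaming path; that is how the send_job_status(...) if not streaming else None line removed further down behaved. A sketch of that gate, under that assumption:

    streaming = input_params.get("streaming", False)
    if not streaming:
        # non-streaming clients get a single, final status update
        send_job_status(status_variables)
    # streaming clients receive tokens incrementally, so the final status
    # would be emitted by the streaming handler instead (assumption)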
@@ -233,16 +243,12 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu
 
         status_variables['jobstatus'] = JobStatus.DONE.status
         status_variables['answer'] = llm_answer_base64_string
-        streaming = input_params.get("streaming", False)
-
-        send_job_status(status_variables) if not streaming else None
-
+
     except Exception as err:
         logger.exception(err)
         status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
         error = JobStatus.ERROR_PREDICTION.get_message()
         status_variables['answer'] = error.decode("utf-8")
-        send_job_status(status_variables)
 
     return status_variables
 
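This hunk removes the status publishing from inside generate_vision_answer_sagemaker, so the helper now only fills in status_variables and returns it; pushing the result (or an error) to the client stays with the caller in process_visual_qa. A reduced sketch of that division of responsibility, with a simplified signature and names taken from the diff:

    def generate_vision_answer_sagemaker(_qa_llm, input_params, decoded_question, status_variables):
        try:
            # ... invoke the SageMaker endpoint and store the answer ...
            status_variables['jobstatus'] = JobStatus.DONE.status
        except Exception as err:
            logger.exception(err)
            status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
        return status_variables  # the caller decides when to call send_job_status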
@@ -279,7 +285,6 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images, model_id,decode
                     },
                     {
                         "type": "text",
-                        #"text": "Describe the architecture and code terraform script to deploy it, answer inside <answer></answer> tags."
                         "text": prompt
 
                     }
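The block above is part of the Anthropic messages payload sent to Bedrock: one image content block plus one text block carrying the caller-supplied prompt (the hard-coded Terraform prompt is dropped in favor of the prompt variable). A minimal sketch of how such a payload can be assembled and sent with boto3, assuming a Claude 3 model on Bedrock; the helper name and defaults are illustrative:

    import json
    import boto3

    def ask_model_about_image(base64_image, prompt,
                              model_id="anthropic.claude-3-sonnet-20240229-v1:0"):
        # One user turn containing an image block followed by a text block.
        body = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 512,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image",
                     "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_image}},
                    {"type": "text", "text": prompt},
                ],
            }],
        }
        client = boto3.client("bedrock-runtime")
        response = client.invoke_model(modelId=model_id, body=json.dumps(body))
        return json.loads(response["body"].read())["content"][0]["text"]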