from langchain.chains import LLMChain
from .sagemaker_endpoint import MultiModal
from aws_lambda_powertools import Logger, Tracer, Metrics
-from .StreamingCallbackHandler import StreamingCallbackHandler
from adapters import registry

from .helper import download_file, load_vector_db_opensearch, send_job_status, JobStatus, get_presigned_url, encode_image_to_base64
@@ -181,7 +180,7 @@ def process_visual_qa(input_params,status_variables,filename):

    qa_model = input_params['qa_model']
    qa_modelId = qa_model['modelId']
-
+    streaming = qa_model.get("streaming", False)
    # default model provider is bedrock and default modality is Text
    modality = qa_model.get("modality", "Text")
    model_provider = qa_model.get("provider", Provider.BEDROCK)
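As context for this hunk, the new flag is now read from the qa_model block of the request rather than from input_params. A minimal sketch of that block, using only the keys accessed above; the values are placeholders, not taken from the source:

qa_model = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder Bedrock model id
    "provider": "Bedrock",                                  # placeholder; resolved against Provider.BEDROCK
    "modality": "Image",
    "streaming": True,                                      # new flag; defaults to False when absent
}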
@@ -207,28 +206,20 @@ def process_visual_qa(input_params,status_variables,filename):
        if (_qa_llm is not None):
            local_file_path = download_file(bucket_name, filename)
            base64_images = encode_image_to_base64(local_file_path, filename)
-            status_variables['answer'] = generate_vision_answer_bedrock(_qa_llm, base64_images, qa_modelId, decoded_question)
-            if (status_variables['answer'] is None):
-                status_variables['answer'] = JobStatus.ERROR_PREDICTION.status
-                error = JobStatus.ERROR_PREDICTION.get_message()
-                status_variables['answer'] = error.decode("utf-8")
-                status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-            else:
-                status_variables['jobstatus'] = JobStatus.DONE.status
-            streaming = input_params.get("streaming", False)
-
+            generate_vision_answer_bedrock(_qa_llm, base64_images, qa_modelId, decoded_question, status_variables, streaming)
        else:
            logger.error('Invalid Model, cannot load LLM, returning..')
            status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
            error = JobStatus.ERROR_LOAD_LLM.get_message()
            status_variables['answer'] = error.decode("utf-8")
+            send_job_status(status_variables)
    else:
        logger.error('Invalid Model provider, cannot load LLM, returning..')
        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
        error = JobStatus.ERROR_LOAD_LLM.get_message()
        status_variables['answer'] = error.decode("utf-8")
+        send_job_status(status_variables)

-    send_job_status(status_variables)
    return status_variables

def generate_vision_answer_sagemaker(_qa_llm, input_params, decoded_question, status_variables, filename):
@@ -269,18 +260,9 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu

    return status_variables

-def generate_vision_answer_bedrock(bedrock_client, base64_images, model_id, decoded_question):
-    system_prompt = ""
-    # use system prompt for fine tuning the performance
-    # system_prompt = """
-    # You have perfect vision and pay great attention to detail which
-    # makes you an expert at answering architecture diagram question.
-    # Answer question in <question></question> tags. Before answer,
-    # think step by step in <thinking> tags and analyze every part of the diagram.
-    # """
-    # Create a prompt with the question
-    prompt = f"<question>{decoded_question}</question>. Answer must be a numbered list in a small paragraph inside <answer></answer> tag."
-
+def generate_vision_answer_bedrock(bedrock_client, base64_images, model_id,
+                                   decoded_question, status_variables, streaming):
+
    claude_config = {
        'max_tokens': 1000,
        'temperature': 0,
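The hunks only show fragments of the request this function assembles. As a rough sketch (not the verbatim source), the Anthropic Messages payload sent to Bedrock for one image plus the question looks like the following; the anthropic_version value, media type, and image data are assumptions or placeholders:

import json

claude_config = {
    'anthropic_version': 'bedrock-2023-05-31',  # assumed; Bedrock requires this key for Claude 3 models
    'max_tokens': 1000,
    'temperature': 0,
}
messages = {
    "role": "user",
    "content": [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",         # placeholder media type
                "data": "<base64-encoded image>",  # placeholder payload
            },
        },
        {"type": "text", "text": "What does this architecture diagram show?"},
    ],
}
body = json.dumps({"messages": [messages], **claude_config})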
@@ -302,34 +284,57 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded
            },
            {
                "type": "text",
-                "text": prompt
+                "text": decoded_question

            }
        ]
    }

-    body = json.dumps({'messages': [messages], **claude_config, "system": system_prompt})
+    body = json.dumps({'messages': [messages], **claude_config})
+
    try:
-        response = bedrock_client.invoke_model(
-            body=body, modelId=model_id, accept="application/json",
-            contentType="application/json"
-        )
+        if streaming:
+            response = bedrock_client.invoke_model_with_response_stream(
+                body=body, modelId=model_id, accept="application/json",
+                contentType="application/json"
+            )
+            for event in response.get("body"):
+                chunk = json.loads(event["chunk"]["bytes"])
+
+                if chunk['type'] == 'message_delta':
+                    status_variables['answer'] = ''
+                    status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
+                    send_job_status(status_variables)
+
+                if chunk['type'] == 'content_block_delta':
+                    if chunk['delta']['type'] == 'text_delta':
+                        logger.info(chunk['delta']['text'])
+                        chunked_text = chunk['delta']['text']
+                        llm_answer_bytes = json.dumps(chunked_text).encode("utf-8")
+                        base64_bytes = base64.b64encode(llm_answer_bytes)
+                        llm_answer_base64_string = base64_bytes.decode("utf-8")
+                        status_variables['answer'] = llm_answer_base64_string
+                        status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
+                        send_job_status(status_variables)
+
+
+        else:
+            response = bedrock_client.invoke_model(
+                body=body, modelId=model_id, accept="application/json",
+                contentType="application/json"
+            )
+            response_body = json.loads(response.get('body').read())
+            logger.info(f'answer is: {response_body}')
+            output_list = response_body.get("content", [])
+            for output in output_list:
+                llm_answer_bytes = json.dumps(output["text"]).encode("utf-8")
+                base64_bytes = base64.b64encode(llm_answer_bytes)
+                llm_answer_base64_string = base64_bytes.decode("utf-8")
+                status_variables['jobstatus'] = JobStatus.DONE.status
+                status_variables['answer'] = llm_answer_base64_string
+                send_job_status(status_variables)
+
    except Exception as err:
        logger.exception(f'Error occurred, Reason: {err}')
        return None
-
-    response = json.loads(response['body'].read().decode('utf-8'))
-
-    formated_response = response['content'][0]['text']
-    answer = re.findall(r'<answer>(.*?)</answer>', formated_response, re.DOTALL)
-    formatted_answer = answer[0]
-    llm_answer_bytes = formatted_answer.encode("utf-8")
-    print(f'formatted_answer {formatted_answer}')
-    base64_bytes = base64.b64encode(llm_answer_bytes)
-    print(f'base64_bytes')
-    llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-    print(f'llm_answer_base64_string {llm_answer_base64_string}')
-
-    return llm_answer_base64_string
-
+
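Because each streamed token is json.dumps-ed, UTF-8 encoded, and base64 encoded before send_job_status publishes it, a consumer has to reverse those steps to rebuild the answer. A minimal sketch, assuming the subscriber receives the jobstatus and answer fields unchanged and that the streaming status strings take the hypothetical values noted in the comments:

import base64
import json

def decode_streamed_answer(events):
    # events: iterable of job-status payloads as published by send_job_status (assumed shape)
    parts = []
    for event in events:
        if event.get("jobstatus") == "StreamingNewToken":   # assumed JobStatus.STREAMING_NEW_TOKEN.status value
            token = json.loads(base64.b64decode(event["answer"]).decode("utf-8"))
            parts.append(token)
        elif event.get("jobstatus") == "StreamingEnded":     # assumed JobStatus.STREAMING_ENDED.status value
            break
    return "".join(parts)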