from .helper import send_job_status, JobStatus
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from .sagemaker_endpoint import Ideficsllm
from .StreamingCallbackHandler import StreamingCallbackHandler
import os
import base64
from aws_lambda_powertools import Logger, Tracer, Metrics

logger = Logger(service="QUESTION_ANSWERING")
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")


def run_qa_agent_on_image_no_memory(input_params):
    logger.info("starting qa agent without memory on uploaded image")

    # the question arrives base64-encoded; decode it back to plain text
    question_bytes = input_params['question'].encode("utf-8")
    decoded_question = base64.b64decode(question_bytes).decode("utf-8")

    logger.info(decoded_question)

    status_variables = {
        'jobstatus': JobStatus.WORKING.status,
        'answer': JobStatus.WORKING.get_message(),
        'jobid': input_params['jobid'],
        'filename': input_params['filename'],
        'question': input_params['question'],
        'sources': ['']
    }
    send_job_status(status_variables)
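
    # from here on, every early-exit path publishes an error status through
    # send_job_status so the client is not left waiting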

    # 1: load the document
    bucket_name = os.environ['INPUT_BUCKET']
    filename = input_params['filename']
    # the image is accessed through a presigned URL rather than directly from S3
    image_url = input_params['presignedurl']
    logger.info(f"Generating response to question for file {filename}")

    status_variables['sources'] = [filename]
    if image_url is None:
        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_DOC.status
        error = JobStatus.ERROR_LOAD_DOC.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)
        return

    # 2: run the question
    streaming = input_params.get("streaming", False)
    # TODO: use streaming in the response
    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None

    # _qa_llm = get_llm(callback_manager, "HuggingFaceM4/idefics-80b-instruct")
    # TODO: update get_llm to support SageMaker as a provider;
    # this needs to be updated with @alain's changes
    logger.info('get LLM Ideficsllm')
    _qa_llm = Ideficsllm.sagemakerendpoint_llm("idefics")

    if _qa_llm is None:
        logger.info('llm is None, returning')
        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
        error = JobStatus.ERROR_LOAD_LLM.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)
        return status_variables

    # 3: run the LLM
    # template = "User:{question}<end_of_utterance>\nAssistant:"
    template = """\n\nUser: {question}<end_of_utterance>\n\nAssistant:"""
    prompt = PromptTemplate(template=template, input_variables=["image", "question"])
    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])
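
    # note: the template only interpolates {question}; the image URL is passed
    # separately at predict time and is presumably consumed by the Ideficsllm
    # wrapper (an assumption based on the predict call below)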

    try:
        logger.info(f'image is: {filename}')
        logger.info(f'decoded_question is: {decoded_question}')
        tmp = chain.predict(image=image_url, question=decoded_question)
        # the model output echoes the prompt; keep only the text after
        # "Assistant:" ([-1] falls back to the full output if the marker is absent)
        answer = tmp.split("Assistant:", 1)[-1]

        logger.info(f'tmp is: {tmp}')
        logger.info(f'answer is: {answer}')
        # base64-encode the answer for transport, mirroring the encoded question
        llm_answer_bytes = answer.encode("utf-8")
        base64_bytes = base64.b64encode(llm_answer_bytes)
        llm_answer_base64_string = base64_bytes.decode("utf-8")

        status_variables['jobstatus'] = JobStatus.DONE.status
        status_variables['answer'] = llm_answer_base64_string
        # when streaming, updates are delivered through the callback handler instead
        if not streaming:
            send_job_status(status_variables)

    except Exception as err:
        logger.exception(err)
        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
        error = JobStatus.ERROR_PREDICTION.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)

    return status_variables
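
# Example invocation (a sketch; all field values below are hypothetical):
#
# run_qa_agent_on_image_no_memory({
#     'question': base64.b64encode(b"What is in this image?").decode("utf-8"),
#     'jobid': "1234",
#     'filename': "photo.png",
#     'presignedurl': "https://<bucket>.s3.amazonaws.com/photo.png?X-Amz-...",
#     'streaming': False,
#     'verbose': True,
# })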


def run_qa_agent_rag_image_no_memory(input_params):
    logger.info("starting qa agent RAG without memory on uploaded image")