Commit f04ea6b

Author: Dinesh Sajwan (committed)
feat(visualqa): question answer on uploaded image
1 parent 16e8779 commit f04ea6b

File tree

5 files changed: +259 -74 lines changed

@@ -0,0 +1,62 @@
from .helper import send_job_status, JobStatus
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
import base64
from typing import Any, Dict, List, Union

from aws_lambda_powertools import Logger, Tracer, Metrics

logger = Logger(service="QUESTION_ANSWERING")
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")

class StreamingCallbackHandler(BaseCallbackHandler):
    def __init__(self, status_variables: Dict):
        self.status_variables = status_variables
        logger.info("[StreamingCallbackHandler::__init__] Initialized")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Run when streaming is started."""
        logger.info("[StreamingCallbackHandler::on_llm_start] Streaming started!")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        try:
            logger.info(f'[StreamingCallbackHandler::on_llm_new_token] token is: {token}')
            llm_answer_bytes = token.encode("utf-8")
            base64_bytes = base64.b64encode(llm_answer_bytes)
            llm_answer_base64_string = base64_bytes.decode("utf-8")

            self.status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
            self.status_variables['answer'] = llm_answer_base64_string
            send_job_status(self.status_variables)

        except Exception as err:
            logger.exception(err)
            self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
            error = JobStatus.ERROR_PREDICTION.get_message()
            self.status_variables['answer'] = error.decode("utf-8")
            send_job_status(self.status_variables)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        logger.info(f"[StreamingCallbackHandler::on_llm_end] Streaming ended. Response: {response}")
        try:
            self.status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
            self.status_variables['answer'] = ""
            send_job_status(self.status_variables)

        except Exception as err:
            logger.exception(err)
            self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
            error = JobStatus.ERROR_PREDICTION.get_message()
            self.status_variables['answer'] = error.decode("utf-8")
            send_job_status(self.status_variables)

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        """Run when LLM errors."""
        logger.exception(error)
        self.status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
        error = JobStatus.ERROR_PREDICTION.get_message()
        self.status_variables['answer'] = error.decode("utf-8")
        send_job_status(self.status_variables)

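For orientation, here is a minimal sketch of how this handler is meant to be consumed: it is attached as a LangChain callback when streaming is requested (see chain.py and the new image QA module below), and each token it emits arrives base64-encoded in the 'answer' field of the job-status update. The decode_token helper below is hypothetical subscriber-side code, shown only to illustrate the encoding.

import base64

# Hypothetical consumer-side helper: decode one streamed status update
# produced by StreamingCallbackHandler.on_llm_new_token.
def decode_token(status_update: dict) -> str:
    return base64.b64decode(status_update["answer"]).decode("utf-8")

# Illustrative wiring (mirrors the image QA module): the handler is passed to the
# chain/LLM as a callback list only when the request asks for streaming, e.g.
# callbacks = [StreamingCallbackHandler(status_variables)] if streaming else None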
lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py (+10 -52)
@@ -13,11 +13,12 @@
 #
 from .helper import load_vector_db_opensearch, send_job_status, JobStatus
 from .s3inmemoryloader import S3FileLoaderInMemory
+from .StreamingCallbackHandler import StreamingCallbackHandler
 from langchain.prompts import PromptTemplate
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult
 from llms import get_llm, get_max_tokens
 from langchain.chains import LLMChain
+from .image_qa import run_qa_agent_on_image_no_memory
+

 import boto3
 import os
@@ -29,56 +30,7 @@
 logger = Logger(service="QUESTION_ANSWERING")
 tracer = Tracer(service="QUESTION_ANSWERING")
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
-class StreamingCallbackHandler(BaseCallbackHandler):
-    ... (the rest of the removed StreamingCallbackHandler class is elided here; it is identical to the new StreamingCallbackHandler module above) ...
-        send_job_status(self.status_variables)
+


 @tracer.capture_method
@@ -87,11 +39,17 @@ def run_question_answering(arguments):

     try:
         filename = arguments['filename']
+        image_url = arguments['presignedurl']
+
     except:

         filename = ''
         arguments['filename'] = ''

+    if image_url: # if image presigned url is present then do a QA on image file
+        llm_response = run_qa_agent_on_image_no_memory(arguments)
+        return llm_response
+
     if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
         if response_generation_method == 'LONG_CONTEXT':
             error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
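To make the new routing concrete, below is an illustrative arguments payload (all values are assumptions) that takes the image path: run_question_answering reads 'filename' and 'presignedurl', and run_qa_agent_on_image_no_memory additionally expects 'jobid', a base64-encoded 'question', and the 'streaming' and 'verbose' flags.

import base64

# Illustrative payload: field names match what the handlers read, values are made up.
arguments = {
    "jobid": "job-123",
    "filename": "cat.png",
    "presignedurl": "https://example-bucket.s3.amazonaws.com/cat.png?X-Amz-Signature=...",
    "question": base64.b64encode(b"What is in this picture?").decode("utf-8"),
    "streaming": False,
    "verbose": False,
}

# Because 'presignedurl' is set, run_question_answering(arguments) short-circuits
# into run_qa_agent_on_image_no_memory(arguments) instead of the RAG/long-context paths.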
@@ -0,0 +1,109 @@
from .helper import send_job_status, JobStatus
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from .sagemaker_endpoint import Ideficsllm
from .StreamingCallbackHandler import StreamingCallbackHandler
import os
import base64
from aws_lambda_powertools import Logger, Tracer, Metrics

logger = Logger(service="QUESTION_ANSWERING")
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")


def run_qa_agent_on_image_no_memory(input_params):
    logger.info("starting qa agent without memory on uploaded image")

    # the incoming question arrives base64-encoded
    base64_bytes = input_params['question'].encode("utf-8")
    sample_string_bytes = base64.b64decode(base64_bytes)
    decoded_question = sample_string_bytes.decode("utf-8")
    logger.info(decoded_question)

    status_variables = {
        'jobstatus': JobStatus.WORKING.status,
        'answer': JobStatus.WORKING.get_message(),
        'jobid': input_params['jobid'],
        'filename': input_params['filename'],
        'question': input_params['question'],
        'sources': ['']
    }
    send_job_status(status_variables)

    # 1 : load the document
    global _file_content
    global _current_file_name

    bucket_name = os.environ['INPUT_BUCKET']
    filename = input_params['filename']
    image_url = input_params['presignedurl']
    logger.info(f"Generating response to question for file {filename}")

    status_variables['sources'] = [filename]
    if image_url is None:
        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_DOC.status
        error = JobStatus.ERROR_LOAD_DOC.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)
        return

    # 2 : run the question
    streaming = input_params.get("streaming", False)
    # TODO use streaming in response
    callback_manager = [StreamingCallbackHandler(status_variables)] if streaming else None

    # _qa_llm = get_llm(callback_manager, "HuggingFaceM4/idefics-80b-instruct")
    # TODO: update get_llm to support SageMaker as a provider,
    # this needs to be updated with @alain changes
    logger.info('get LLM Ideficsllm')
    _qa_llm = Ideficsllm.sagemakerendpoint_llm("idefics")

    if _qa_llm is None:
        logger.info('llm is None, returning')
        status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
        error = JobStatus.ERROR_LOAD_LLM.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)
        return status_variables

    # 3: run LLM
    # template = "User:{question}![]({image})<end_of_utterance>\nAssistant:"
    template = """\n\nUser: {question}![]({image})<end_of_utterance>
\n\nAssistant:"""
    prompt = PromptTemplate(template=template, input_variables=["image", "question"])
    chain = LLMChain(llm=_qa_llm, prompt=prompt, verbose=input_params['verbose'])

    try:
        logger.info(f'image is: {filename}')
        logger.info(f'decoded_question is: {decoded_question}')
        tmp = chain.predict(image=image_url, question=decoded_question)
        # answer = tmp.removeprefix(' ')
        answer = tmp.split("Assistant:", 1)[1]

        logger.info(f'tmp is: {tmp}')
        logger.info(f'answer is: {answer}')
        llm_answer_bytes = answer.encode("utf-8")
        base64_bytes = base64.b64encode(llm_answer_bytes)
        llm_answer_base64_string = base64_bytes.decode("utf-8")

        status_variables['jobstatus'] = JobStatus.DONE.status
        status_variables['answer'] = llm_answer_base64_string
        if not streaming:
            send_job_status(status_variables)

    except Exception as err:
        logger.exception(err)
        status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
        error = JobStatus.ERROR_PREDICTION.get_message()
        status_variables['answer'] = error.decode("utf-8")
        send_job_status(status_variables)

    return status_variables


def run_qa_agent_rag_image_no_memory(input_params):
    logger.info("starting qa agent RAG without memory on uploaded image")
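As a quick illustration of what the chain above actually sends to the model, here is the prompt the PromptTemplate renders for a sample question and image URL (both values are assumptions); the IDEFICS chat format interleaves the question text and a markdown image reference.

from langchain.prompts import PromptTemplate

template = """\n\nUser: {question}![]({image})<end_of_utterance>
\n\nAssistant:"""
prompt = PromptTemplate(template=template, input_variables=["image", "question"])

# Sample values (assumptions); in the Lambda they come from the AppSync arguments.
rendered = prompt.format(
    question="What is in this picture?",
    image="https://example-bucket.s3.amazonaws.com/cat.png?X-Amz-Signature=...",
)
print(rendered)
# The model echoes the conversation, so the handler above keeps only the text after
# "Assistant:" via tmp.split("Assistant:", 1)[1].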
@@ -0,0 +1,44 @@
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint

import json
import os

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt, model_kwargs) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

class Ideficsllm:

    parameters = {
        "do_sample": True,
        "top_p": 0.2,
        "temperature": 0.4,
        "top_k": 50,
        "max_new_tokens": 512,
        "stop": ["User:", "<end_of_utterance>"]
    }

    @classmethod
    def sagemakerendpoint_llm(cls, model_id):
        return SagemakerEndpoint(
            endpoint_name=model_id,
            region_name=os.environ["AWS_REGION"],
            model_kwargs=cls.parameters,
            content_handler=content_handler,
        )
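For reference, a small sketch of the request body ContentHandler.transform_input produces for the "idefics" endpoint, using the default parameters defined above; the prompt string is an assumption standing in for the rendered IDEFICS template.

import json

# Illustrative prompt (assumption); in practice it is the rendered chat template.
prompt = "\n\nUser: What is in this picture?![](https://example.com/cat.png)<end_of_utterance>\n\nAssistant:"

# Mirrors ContentHandler.transform_input: a JSON body with the prompt and the
# generation parameters, UTF-8 encoded, posted to the SageMaker endpoint.
body = json.dumps({"inputs": prompt, "parameters": Ideficsllm.parameters}).encode("utf-8")

# transform_output then reads the endpoint's JSON response and returns
# response_json[0]["generated_text"].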
