
Commit dafcdc0

Authored by Dinesh Sajwan (dineshSajwan), with co-authors github-actions and Scott Schreckengaust (scottschreckengaust)
Feature/defect fixes (#345)
* feat(bugfix): fixed version error
* feat(construct): updated image generation resource policy and graphql api type
* chore: self mutation (Signed-off-by: github-actions <[email protected]>)
* feat(operationalmetric): updated description format
* Update apidocs/classes/ContentGenerationAppSyncLambda.md (Co-authored-by: Scott Schreckengaust <[email protected]>; Signed-off-by: Dinesh Sajwan <[email protected]>)
* Update src/patterns/gen-ai/aws-contentgen-appsync-lambda/index.ts (Co-authored-by: Scott Schreckengaust <[email protected]>; Signed-off-by: Dinesh Sajwan <[email protected]>)
* Update src/patterns/gen-ai/aws-contentgen-appsync-lambda/index.ts (Co-authored-by: Scott Schreckengaust <[email protected]>; Signed-off-by: Dinesh Sajwan <[email protected]>)
* feat(construct): added streaming support for image

Signed-off-by: github-actions <[email protected]>
Signed-off-by: Dinesh Sajwan <[email protected]>
Co-authored-by: Dinesh Sajwan <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Scott Schreckengaust <[email protected]>
1 parent a130ca8 commit dafcdc0

File tree: 2 files changed, +53 -48 lines changed


lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/helper.py (+1 -1)

@@ -180,7 +180,7 @@ def send_job_status(variables):
             auth=aws_auth_appsync,
             timeout=10
         )
-    logger.info('res :: {}',responseJobstatus)
+    #logger.info('res :: {}',responseJobstatus)
 
 def get_presigned_url(bucket,key) -> str:
     try:
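For context, send_job_status is the helper that pushes each job-status update to the AppSync GraphQL endpoint; the change above only silences the per-call response log. A minimal sketch of that flow, assuming a hypothetical GRAPHQL_URL environment variable and mutation name (the real construct wires these through its own configuration):

# Sketch only: approximate shape of send_job_status() in helper.py.
# GRAPHQL_URL and the mutation name/fields below are assumptions, not the construct's actual values.
import os
import boto3
import requests
from requests_aws4auth import AWS4Auth

GRAPHQL_URL = os.environ.get("GRAPHQL_URL", "")  # assumed environment variable

UPDATE_STATUS_MUTATION = """
mutation UpdateJobStatus($jobid: ID!, $jobstatus: String, $answer: String) {
  updateJobStatus(jobid: $jobid, jobstatus: $jobstatus, answer: $answer) {
    jobid
    jobstatus
    answer
  }
}
"""  # hypothetical mutation

def send_job_status(variables):
    # Sign the request with the Lambda role credentials (SigV4, AppSync IAM auth).
    credentials = boto3.Session().get_credentials()
    aws_auth_appsync = AWS4Auth(
        credentials.access_key,
        credentials.secret_key,
        os.environ.get("AWS_REGION", "us-east-1"),
        "appsync",
        session_token=credentials.token,
    )
    responseJobstatus = requests.post(
        GRAPHQL_URL,
        json={"query": UPDATE_STATUS_MUTATION, "variables": variables},
        auth=aws_auth_appsync,
        timeout=10,
    )
    # The commit comments out the noisy response log:
    # logger.info('res :: {}', responseJobstatus)
    return responseJobstatus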

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/image_qa.py (+52 -47)
@@ -24,7 +24,6 @@
 from langchain.chains import LLMChain
 from .sagemaker_endpoint import MultiModal
 from aws_lambda_powertools import Logger, Tracer, Metrics
-from .StreamingCallbackHandler import StreamingCallbackHandler
 from adapters import registry
 
 from .helper import download_file, load_vector_db_opensearch,send_job_status, JobStatus,get_presigned_url,encode_image_to_base64
@@ -181,7 +180,7 @@ def process_visual_qa(input_params,status_variables,filename):
 
     qa_model= input_params['qa_model']
     qa_modelId=qa_model['modelId']
-
+    streaming = qa_model.get("streaming", False)
     # default model provider is bedrock and defalut modality is tEXT
     modality=qa_model.get("modality", "Text")
     model_provider=qa_model.get("provider",Provider.BEDROCK)
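The streaming toggle is now read from the qa_model block of the request rather than from the top-level input (the old streaming = input_params.get("streaming", False) is removed in the next hunk), so callers opt in per request. An illustrative event shape for process_visual_qa follows; only the qa_model keys appear in this diff, and the remaining field names are assumptions:

# Illustrative input_params for process_visual_qa(); only the qa_model keys are taken from this diff,
# the other field names (jobid, question, filename) are assumed for the example.
input_params = {
    "jobid": "job-1234",                                  # assumed
    "question": "V2hhdCBpcyBpbiB0aGlzIGRpYWdyYW0/",       # assumed: base64-encoded question text
    "filename": "architecture-diagram.png",               # assumed
    "qa_model": {
        "provider": "Bedrock",
        "modality": "Image",
        "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model ID
        "streaming": True,   # new: read via qa_model.get("streaming", False)
    },
}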
@@ -207,28 +206,20 @@ def process_visual_qa(input_params,status_variables,filename):
         if(_qa_llm is not None):
             local_file_path= download_file(bucket_name,filename)
             base64_images=encode_image_to_base64(local_file_path,filename)
-            status_variables['answer']= generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question)
-            if(status_variables['answer'] is None):
-                status_variables['answer'] = JobStatus.ERROR_PREDICTION.status
-                error = JobStatus.ERROR_PREDICTION.get_message()
-                status_variables['answer'] = error.decode("utf-8")
-                status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
-            else:
-                status_variables['jobstatus'] = JobStatus.DONE.status
-            streaming = input_params.get("streaming", False)
-
+            generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question,status_variables,streaming)
         else:
             logger.error('Invalid Model , cannot load LLM , returning..')
             status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
             error = JobStatus.ERROR_LOAD_LLM.get_message()
             status_variables['answer'] = error.decode("utf-8")
+            send_job_status(status_variables)
     else:
         logger.error('Invalid Model provider, cannot load LLM , returning..')
         status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
         error = JobStatus.ERROR_LOAD_LLM.get_message()
         status_variables['answer'] = error.decode("utf-8")
+        send_job_status(status_variables)
 
-    send_job_status(status_variables)
     return status_variables
 
 def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,status_variables,filename):
@@ -269,18 +260,9 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu
 
     return status_variables
 
-def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded_question):
-    system_prompt=""
-    # use system prompt for fine tuning the performamce
-    # system_prompt= """
-    # You have perfect vision and pay great attention to detail which
-    # makes you an expert at answering architecture diagram question.
-    # Answer question in <question></question> tags. Before answer,
-    # think step by step in <thinking> tags and analyze every part of the diagram.
-    # """
-    #Create a prompt with the question
-    prompt =f"<question>{decoded_question}</question>. Answer must be a numbered list in a small paragraph inside <answer></answer> tag."
-
+def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,
+                                   decoded_question,status_variables,streaming):
+
     claude_config = {
         'max_tokens': 1000,
         'temperature': 0,
@@ -302,34 +284,57 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded
             },
             {
                 "type": "text",
-                "text": prompt
+                "text": decoded_question
 
             }
         ]
     }
 
-    body=json.dumps({'messages': [messages],**claude_config, "system": system_prompt})
+    body=json.dumps({'messages': [messages],**claude_config})
+
     try:
-        response = bedrock_client.invoke_model(
-            body=body, modelId=model_id, accept="application/json",
-            contentType="application/json"
-        )
+        if streaming:
+            response = bedrock_client.invoke_model_with_response_stream(
+                body=body, modelId=model_id, accept="application/json",
+                contentType="application/json"
+            )
+            for event in response.get("body"):
+                chunk = json.loads(event["chunk"]["bytes"])
+
+                if chunk['type'] == 'message_delta':
+                    status_variables['answer']=''
+                    status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
+                    send_job_status(status_variables)
+
+                if chunk['type'] == 'content_block_delta':
+                    if chunk['delta']['type'] == 'text_delta':
+                        logger.info(chunk['delta']['text'], end="")
+                        chuncked_text=chunk['delta']['text']
+                        llm_answer_bytes = json.dumps(chuncked_text).encode("utf-8")
+                        base64_bytes = base64.b64encode(llm_answer_bytes)
+                        llm_answer_base64_string = base64_bytes.decode("utf-8")
+                        status_variables['answer']=llm_answer_base64_string
+                        status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
+                        send_job_status(status_variables)
+
+
+        else:
+            response = bedrock_client.invoke_model(
+                body=body, modelId=model_id, accept="application/json",
+                contentType="application/json"
+            )
+            response_body = json.loads(response.get('body').read())
+            logger.info(f'answer is: {response_body}')
+            output_list = response_body.get("content", [])
+            for output in output_list:
+                llm_answer_bytes = json.dumps(output["text"]).encode("utf-8")
+                base64_bytes = base64.b64encode(llm_answer_bytes)
+                llm_answer_base64_string = base64_bytes.decode("utf-8")
+                status_variables['jobstatus'] = JobStatus.DONE.status
+                status_variables['answer']=llm_answer_base64_string
+                send_job_status(status_variables)
+
     except Exception as err:
         logger.exception(f'Error occurred , Reason :{err}')
         return None
-
-    response = json.loads(response['body'].read().decode('utf-8'))
-
-    formated_response= response['content'][0]['text']
-    answer = re.findall(r'<answer>(.*?)</answer>', formated_response, re.DOTALL)
-    formatted_answer=answer[0]
-    llm_answer_bytes = formatted_answer.encode("utf-8")
-    print(f' formatted_answer {formatted_answer}')
-    base64_bytes = base64.b64encode(llm_answer_bytes)
-    print(f' base64_bytes')
-    llm_answer_base64_string = base64_bytes.decode("utf-8")
-
-    print(f' llm_answer_base64_string {llm_answer_base64_string}')
-
-    return llm_answer_base64_string
+
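Taken together, the reworked generate_vision_answer_bedrock maps Anthropic messages-API stream events onto job-status updates: each content_block_delta/text_delta token is base64-encoded and published as STREAMING_NEW_TOKEN, message_delta marks STREAMING_ENDED, and the non-streaming branch keeps a single DONE response. A self-contained sketch of that streaming loop follows, assuming a Claude 3 Sonnet model ID, a local "diagram.png", and a publish() stand-in for send_job_status:

# Minimal sketch of the streaming path added in this commit, not the construct's exact code.
# Assumptions: Claude 3 Sonnet model ID, a local "diagram.png", publish() standing in for send_job_status().
import base64
import json
import boto3

bedrock = boto3.client("bedrock-runtime")

def publish(status, token=""):
    # Stand-in for send_job_status(); the real helper posts a GraphQL mutation to AppSync.
    print(status, token)

with open("diagram.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

body = json.dumps({
    "anthropic_version": "bedrock-2023-05-31",  # required by the Bedrock messages API
    "max_tokens": 1000,
    "temperature": 0,
    "messages": [{
        "role": "user",
        "content": [
            {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": image_b64}},
            {"type": "text", "text": "What does this architecture diagram show?"},
        ],
    }],
})

response = bedrock.invoke_model_with_response_stream(
    body=body,
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model ID
    accept="application/json",
    contentType="application/json",
)

for event in response["body"]:
    chunk = json.loads(event["chunk"]["bytes"])
    if chunk["type"] == "content_block_delta" and chunk["delta"]["type"] == "text_delta":
        # Each partial token is base64-encoded before being published, mirroring the commit.
        token_b64 = base64.b64encode(json.dumps(chunk["delta"]["text"]).encode("utf-8")).decode("utf-8")
        publish("STREAMING_NEW_TOKEN", token_b64)
    elif chunk["type"] == "message_delta":
        publish("STREAMING_ENDED")

Publishing these incremental updates through AppSync is what lets a subscribed client render the image-QA answer token by token instead of waiting for the full completion.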
