Skip to content

Commit f29e1c3

Browse files
author
Dinesh Sajwan
committed
feat(visualqa): model validation
1 parent 65ea455 commit f29e1c3

File tree

6 files changed

+23
-18
lines changed

6 files changed

+23
-18
lines changed

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def run_question_answering(arguments):
4646
# set default modality to Text
4747
qa_model= arguments['qa_model']
4848
modality=qa_model.get('modality','Text')
49-
49+
5050
# Visual QA
5151
if modality.lower()=='image':
5252
logger.info("Running QA for Image modality")

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/helper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def load_vector_db_opensearch(region: str,
134134
use_ssl = True,
135135
verify_certs = True,
136136
connection_class = RequestsHttpConnection)
137-
vector_db=""
137+
138138
logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
139139
return vector_db
140140

@@ -191,7 +191,7 @@ def get_presigned_url(bucket,key) -> str:
191191
return url
192192
except Exception as exception:
193193
logger.error(f"Reason: {exception}")
194-
return ""
194+
return None
195195

196196
def download_file(bucket,key )-> str:
197197
try:

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/image_qa.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ def run_qa_agent_rag_on_image_no_memory(input_params):
104104
qa_model_id= qa_model['modelId']
105105
embedding_model_id = input_params['embeddings_model']['modelId']
106106
else:
107-
logger.error(' Either qa_model_id or embedding_model_id is not present, cannot answer question using RAG, returning...')
107+
logger.error(' RAG based QA needs both qa_model_id and embeddings_model_id; one or both are missing, cannot answer question using RAG, returning...')
108108
status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
109109
status_variables['answer'] = JobStatus.ERROR_LOAD_LLM.status
110110
send_job_status(status_variables)
111111
return
112112

113+
113114

114115
global _doc_index
115116
global _current_doc_index
@@ -127,6 +128,7 @@ def get_image_from_semantic_search_in_os(input_params,status_variables):
127128
embeddings_model=input_params['embeddings_model']
128129
embedding_model_id = embeddings_model['modelId']
129130
modality=embeddings_model.get("modality", "Text")
131+
130132
if _doc_index is None:
131133
logger.info("loading opensearch retriever")
132134
doc_index = load_vector_db_opensearch(boto3.Session().region_name,
@@ -169,16 +171,21 @@ def process_visual_qa(input_params,status_variables,filename):
169171

170172
qa_model= input_params['qa_model']
171173
qa_modelId=qa_model['modelId']
174+
175+
# default model provider is Bedrock and default modality is Text
172176
modality=qa_model.get("modality", "Text")
177+
model_provider=qa_model.get("provider","Bedrock")
178+
logger.info(f"model provider is {model_provider} and modality is {modality}")
179+
173180
base64_bytes = input_params['question'].encode("utf-8")
174181
sample_string_bytes = base64.b64decode(base64_bytes)
175182
decoded_question = sample_string_bytes.decode("utf-8")
176-
model_provider = input_params['qa_model']['provider']
177-
183+
178184
if model_provider=='Sagemaker Endpoint':
179185
_qa_llm = MultiModal.sagemakerendpoint_llm(qa_modelId)
180186
if(_qa_llm is not None):
181187
status_variables['answer']=generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,filename,status_variables)
188+
status_variables['jobstatus'] = JobStatus.DONE.status
182189
else:
183190
logger.error('Invalid Model , cannot load LLM , returning..')
184191
status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
@@ -191,6 +198,9 @@ def process_visual_qa(input_params,status_variables,filename):
191198
local_file_path= download_file(bucket_name,filename)
192199
base64_images=encode_image_to_base64(local_file_path,filename)
193200
status_variables['answer']= generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question)
201+
status_variables['jobstatus'] = JobStatus.DONE.status
202+
streaming = input_params.get("streaming", False)
203+
194204
else:
195205
logger.error('Invalid Model , cannot load LLM , returning..')
196206
status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
@@ -233,16 +243,12 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu
233243

234244
status_variables['jobstatus'] = JobStatus.DONE.status
235245
status_variables['answer'] = llm_answer_base64_string
236-
streaming = input_params.get("streaming", False)
237-
238-
send_job_status(status_variables) if not streaming else None
239-
246+
240247
except Exception as err:
241248
logger.exception(err)
242249
status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
243250
error = JobStatus.ERROR_PREDICTION.get_message()
244251
status_variables['answer'] = error.decode("utf-8")
245-
send_job_status(status_variables)
246252

247253
return status_variables
248254

@@ -279,7 +285,6 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images, model_id,decode
279285
},
280286
{
281287
"type": "text",
282-
#"text": "Describe the architecture and code terraform script to deploy it, answer inside <answer></answer> tags."
283288
"text": prompt
284289

285290
}

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/sagemaker_endpoint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def sagemakerendpoint_llm(self,model_id):
5959
return endpoint
6060
except Exception as err:
6161
logger.error(f' Error when accessing sagemaker endpoint :: {model_id} , returning...')
62-
return ''
62+
return None
6363
else:
6464
logger.error(f" The sagemaker model Id {model_id} does not match a sagemaker endpoint name {sageMakerEndpoint}")
65-
return ''
65+
return None
6666

6767

6868

src/patterns/gen-ai/aws-qa-appsync-opensearch/index.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -535,10 +535,11 @@ export class QaAppsyncOpensearch extends BaseClass {
535535
actions: [
536536
'bedrock:InvokeModel',
537537
'bedrock:InvokeModelWithResponseStream',
538+
'bedrock:ListFoundationModels',
538539
],
540+
// ListFoundationModels has no specific resource type
539541
resources: [
540-
'arn:' + Aws.PARTITION + ':bedrock:' + Aws.REGION + '::foundation-model',
541-
'arn:' + Aws.PARTITION + ':bedrock:' + Aws.REGION + '::foundation-model/*',
542+
'*',
542543
],
543544
}),
544545
);

src/patterns/gen-ai/aws-rag-appsync-stepfn-opensearch/index.ts

+1-2
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,7 @@ export class RagAppsyncStepfnOpensearch extends BaseClass {
536536
actions: [
537537
'rekognition:DetectModerationLabels',
538538
],
539-
resources: ['arn:'+ Aws.PARTITION +':rekognition:' + Aws.REGION +':'+ Aws.ACCOUNT_ID + ':project/*'],
540-
}));
539+
resources: ['*'] }));
541540

542541
s3_transformer_job_function_role.addToPolicy(new iam.PolicyStatement({
543542
effect: iam.Effect.ALLOW,

0 commit comments

Comments
 (0)