Skip to content

Commit 8496001

Browse files
author
Dinesh Sajwan
committed
feat(visualqa): fixed bugs
1 parent 6cbb5e0 commit 8496001

File tree

10 files changed

+104
-101
lines changed

10 files changed

+104
-101
lines changed

lambda/aws-qa-appsync-opensearch/question_answering/src/lambda.py

+1-30
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
tracer = Tracer(service="QUESTION_ANSWERING")
2424
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
2525

26-
#@logger.inject_lambda_context(log_event=True)
26+
@logger.inject_lambda_context(log_event=True)
2727
@tracer.capture_lambda_handler
2828
@metrics.log_metrics(capture_cold_start_metric=True)
2929
def handler(event, context: LambdaContext) -> dict:
@@ -41,32 +41,3 @@ def handler(event, context: LambdaContext) -> dict:
4141

4242
print(f"llm_response is {llm_response}")
4343
return llm_response
44-
45-
input ={"detail": {
46-
"jobid": "111",
47-
"jobstatus": "",
48-
"qa_model": {
49-
"provider": "Bedrock",
50-
"modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
51-
"streaming": True,
52-
"modality": "Image"
53-
},
54-
"embeddings_model": {
55-
"provider": "Bedrock",
56-
"modelId": "amazon.titan-embed-image-v1",
57-
"streaming": True
58-
},
59-
"retrieval": {
60-
"max_docs": 1,
61-
"index_name": "",
62-
"filter_filename": ""
63-
},
64-
"filename": "two_cats.jpeg",
65-
"presignedurl": "",
66-
"question": "d2hhdCBhcmUgdGhlIGNhdHMgZG9pbmc/",
67-
"verbose": False,
68-
"responseGenerationMethod": "LONG_CONTEXT"
69-
}
70-
}
71-
72-
handler(input, None)

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
2828

2929

30+
3031
def get_llm(callbacks=None,model_id="anthropic.claude-v2:1"):
3132
bedrock = boto3.client('bedrock-runtime')
3233

@@ -63,7 +64,7 @@ def get_embeddings_llm(model_id,modality):
6364
def get_bedrock_fm(model_id,modality):
6465
bedrock_client = boto3.client('bedrock-runtime')
6566
validation_status= validate_model_id_in_bedrock(model_id,modality)
66-
print(f' validation_status :: {validation_status}')
67+
logger.info(f' validation_status :: {validation_status}')
6768
if(validation_status['status']):
6869
return bedrock_client
6970
else:
@@ -73,9 +74,16 @@ def get_bedrock_fm(model_id,modality):
7374

7475

7576
#TODO -add max token based on model id
76-
def get_max_tokens():
77-
return 200000
78-
77+
def get_max_tokens(model_id):
78+
match model_id:
79+
case "anthropic.claude-v2:1":
80+
return 200000
81+
case "anthropic.claude-3-sonnet-20240229-v1:0":
82+
return 200000
83+
case _:
84+
return 4096
85+
86+
7987
def validate_model_id_in_bedrock(model_id,modality):
8088
"""
8189
Validate if the listed model id is supported with given modality
@@ -92,19 +100,16 @@ def validate_model_id_in_bedrock(model_id,modality):
92100
for model in models:
93101
if model["modelId"].lower() == model_id.lower():
94102
response["message"]=f"model {model_id} does not support modality {modality} "
95-
print(f' modality :: {model["inputModalities"]}')
96103
for inputModality in model["inputModalities"]:
97104
if inputModality.lower() == modality.lower():
98-
print(f' modality supported')
99105
response["message"]=f"model {model_id} with modality {modality} is supported with bedrock "
100106
response["status"] = True
101107

102-
print(f' response :: {response}')
108+
logger.info(f' response :: {response}')
103109
return response
104110
except ClientError as ce:
105111
message=f"error occured while validating model in bedrock {ce}"
106112
logger.error(message)
107113
response["status"] = False
108114
response["message"] = message
109-
print(f' response :: {response}')
110115
return response

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/helper.py

+28-22
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,25 @@
1010
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
1111
# and limitations under the License.
1212
#
13+
import os
14+
import boto3
15+
import json
16+
import base64
1317
from pathlib import Path
1418
from aiohttp import ClientError
1519
from langchain_community.vectorstores import OpenSearchVectorSearch
16-
#from opensearchpy import RequestsHttpConnection
20+
from opensearchpy import RequestsHttpConnection
1721
from llms import get_embeddings_llm
1822
import requests
19-
import os
20-
import boto3
21-
import json
22-
import base64
2323
from enum import Enum
2424
from requests_aws4auth import AWS4Auth
2525
s3 = boto3.client('s3')
26+
from aws_lambda_powertools import Logger, Tracer, Metrics
27+
28+
29+
logger = Logger(service="QUESTION_ANSWERING")
30+
tracer = Tracer(service="QUESTION_ANSWERING")
31+
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
2632

2733

2834
class JobStatus(Enum):
@@ -101,7 +107,7 @@ def load_vector_db_opensearch(region: str,
101107
secret_id: str,
102108
model_id: str,
103109
modality: str) -> OpenSearchVectorSearch:
104-
print(f"load_vector_db_opensearch, region={region}, "
110+
logger.info(f"load_vector_db_opensearch, region={region}, "
105111
f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}")
106112

107113
# if the secret id is not provided
@@ -120,15 +126,16 @@ def load_vector_db_opensearch(region: str,
120126
embedding_function = get_embeddings_llm(model_id,modality)
121127

122128
opensearch_url = opensearch_domain_endpoint if opensearch_domain_endpoint.startswith("https://") else f"https://{opensearch_domain_endpoint}"
123-
# vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
124-
# embedding_function=embedding_function,
125-
# opensearch_url=opensearch_url,
126-
# http_auth=http_auth,
127-
# use_ssl = True,
128-
# verify_certs = True,
129-
# connection_class = RequestsHttpConnection)
129+
130+
vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
131+
embedding_function=embedding_function,
132+
opensearch_url=opensearch_url,
133+
http_auth=http_auth,
134+
use_ssl = True,
135+
verify_certs = True,
136+
connection_class = RequestsHttpConnection)
130137
vector_db=""
131-
print(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
138+
logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
132139
return vector_db
133140

134141
def send_job_status(variables):
@@ -159,8 +166,7 @@ def send_job_status(variables):
159166

160167
print(request)
161168

162-
#GRAPHQL_URL = os.environ['GRAPHQL_URL']
163-
GRAPHQL_URL ="https://j2uzmlvujbhbzoduvpctgkpu2e.appsync-api.us-east-1.amazonaws.com/graphql"
169+
GRAPHQL_URL = os.environ['GRAPHQL_URL']
164170
HEADERS={
165171
"Content-Type": "application/json",
166172
}
@@ -172,7 +178,7 @@ def send_job_status(variables):
172178
auth=aws_auth_appsync,
173179
timeout=10
174180
)
175-
print('res :: {}',responseJobstatus)
181+
logger.info('res :: {}',responseJobstatus)
176182

177183
def get_presigned_url(bucket,key) -> str:
178184
try:
@@ -181,23 +187,23 @@ def get_presigned_url(bucket,key) -> str:
181187
Params={'Bucket': bucket, 'Key': key},
182188
ExpiresIn=900
183189
)
184-
print(f"presigned url generated for {key} from {bucket}")
190+
logger.info(f"presigned url generated for {key} from {bucket}")
185191
return url
186192
except Exception as exception:
187-
print(f"Reason: {exception}")
193+
logger.error(f"Reason: {exception}")
188194
return ""
189195

190196
def download_file(bucket,key )-> str:
191197
try:
192198
file_path = "/tmp/" + os.path.basename(key)
193199
s3.download_file(bucket, key,file_path)
194-
print(f"file downloaded {file_path}")
200+
logger.info(f"file downloaded {file_path}")
195201
return file_path
196202
except ClientError as client_err:
197-
print(f"Couldn\'t download file {client_err.response['Error']['Message']}")
203+
logger.error(f"Couldn\'t download file {client_err.response['Error']['Message']}")
198204

199205
except Exception as exp:
200-
print(f"Couldn\'t download file : {exp}")
206+
logger.error(f"Couldn\'t download file : {exp}")
201207

202208
def encode_image_to_base64(image_file_path,image_file) -> str:
203209
with open(image_file_path, "rb") as image_file:

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/image_qa.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535

3636

3737

38-
#bucket_name = os.environ['INPUT_BUCKET']
39-
bucket_name="persistencestack-inputassets7d1d3f52-qert2sgpwhtu"
38+
bucket_name = os.environ['INPUT_BUCKET']
4039

4140
def run_qa_agent_on_image_no_memory(input_params):
4241
logger.info("starting qa agent without memory on uploaded image")

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/sagemaker_endpoint.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
1-
2-
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
3-
from aws_lambda_powertools import Logger, Tracer, Metrics
4-
1+
#
2+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
# with the License. A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
10+
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
11+
# and limitations under the License.
12+
#
513
import json
614
import os
15+
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
16+
from aws_lambda_powertools import Logger, Tracer, Metrics
717
logger = Logger(service="QUESTION_ANSWERING")
818

19+
20+
sageMakerEndpoint= os.environ['SAGEMAKER_ENDPOINT']
21+
922
class ContentHandler(LLMContentHandler):
1023
content_type = "application/json"
1124
accepts = "application/json"
@@ -35,17 +48,22 @@ class MultiModal():
3548

3649
@classmethod
3750
def sagemakerendpoint_llm(self,model_id):
38-
try:
39-
endpoint= SagemakerEndpoint(
40-
endpoint_name=model_id,
41-
region_name=os.environ["AWS_REGION"],
42-
model_kwargs=self.parameters,
43-
content_handler=content_handler,
44-
)
45-
return endpoint
46-
except Exception as err:
47-
logger.error(' Error when accessing sagemaker endpoint :: {model_id} , returning...')
48-
return ''
51+
if(sageMakerEndpoint ==model_id):
52+
try:
53+
endpoint= SagemakerEndpoint(
54+
endpoint_name=model_id,
55+
region_name=os.environ["AWS_REGION"],
56+
model_kwargs=self.parameters,
57+
content_handler=content_handler,
58+
)
59+
return endpoint
60+
except Exception as err:
61+
logger.error(f' Error when accessing sagemaker endpoint :: {model_id} , returning...')
62+
return ''
63+
else:
64+
logger.error(f" The sagemaker model Id {model_id} do not match a sagemaker endpoint name {sageMakerEndpoint}")
65+
return ''
66+
4967

5068

5169

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/helpers/utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,14 @@ def transform_pdf_document(input_bucket: str,file_name: str,output_bucket: str,o
4747
if not document_content:
4848
return 'Unable to load document'
4949
else:
50+
try:
5051
encoded_string = document_content.encode("utf-8")
5152
s3.Bucket(output_bucket).put_object(Key=output_file_name, Body=encoded_string)
5253
return 'File transformed'
54+
except Exception as e:
55+
logger.error(f'Error in uploading {output_file_name} to {output_bucket} :: {e}')
56+
return 'File transformed Failed'
57+
5358

5459
@tracer.capture_method
5560
def transform_image_document(input_bucket: str,file_name: str,output_bucket: str):

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/lambda.py

-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def handler(event, context: LambdaContext) -> dict:
8787
response['name'] = output_file_name
8888
if(extension == '.pdf'):
8989
response['status'] = transform_pdf_document(input_bucket,file_name,output_bucket,output_file_name)
90-
print(f' pdf processed ::' )
9190
elif(extension == '.jpg'or extension == '.jpeg' or extension == '.png' or extension == '.svg'):
9291
response['status'] = transform_image_document(input_bucket,file_name,output_bucket)
9392
#TODO add csv, doc, docx file type support as well.

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/requirements.txt

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
aws-lambda-powertools
2-
aws-xray-sdk
3-
fastjsonschema
4-
typing-extensions
2+
aws-xray-sdk==3.5.4
3+
fastjsonschema==2.19.1
4+
typing-extensions==4.7.0
55
boto3>=1.34.29
6-
requests
6+
requests==2.31.0
77
langchain==0.1.4
88
pypdf2==3.0.1
99
Pillow==10.2.0

src/patterns/gen-ai/aws-qa-appsync-opensearch/index.ts

+18-8
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ export interface QaAppsyncOpensearchProps {
152152
* and settings instead of the existing
153153
*/
154154
readonly customDockerLambdaProps?: DockerLambdaCustomProps | undefined;
155+
156+
/**
157+
* Optional. Allows to provide custom lambda code
158+
* and settings instead of the existing
159+
*/
160+
readonly sagemakerEndpointName?: string
155161
}
156162

157163
/**
@@ -466,14 +472,16 @@ export class QaAppsyncOpensearch extends Construct {
466472
resources: ['*'],
467473
}),
468474
);
469-
question_answering_function_role.addToPolicy(
470-
new iam.PolicyStatement({
471-
effect: iam.Effect.ALLOW,
472-
actions: ['sagemaker:InvokeEndpoint'],
473-
resources: ['*'],
474-
}),
475-
);
476-
475+
476+
if(props.sagemakerEndpointName){
477+
question_answering_function_role.addToPolicy(
478+
new iam.PolicyStatement({
479+
effect: iam.Effect.ALLOW,
480+
actions: ['sagemaker:InvokeEndpoint'],
481+
resources: ['arn:'+ Aws.PARTITION +':sagemaker:' + Aws.ACCOUNT_ID + ':endpoint/*' ],
482+
}),
483+
);
484+
}
477485
// The lambda will access the opensearch credentials
478486
if (props.openSearchSecret) {
479487
props.openSearchSecret.grantRead(question_answering_function_role);
@@ -553,6 +561,7 @@ export class QaAppsyncOpensearch extends Construct {
553561
true,
554562
);
555563

564+
const sagemakerEndpointNamestr = props.sagemakerEndpointName || ""
556565
const construct_docker_lambda_props = {
557566
code: lambda.DockerImageCode.fromImageAsset(
558567
path.join(
@@ -576,6 +585,7 @@ export class QaAppsyncOpensearch extends Construct {
576585
OPENSEARCH_DOMAIN_ENDPOINT: opensearch_helper.getOpenSearchEndpoint(props),
577586
OPENSEARCH_INDEX: props.openSearchIndexName,
578587
OPENSEARCH_SECRET_ID: SecretId,
588+
SAGEMAKER_ENDPOINT:sagemakerEndpointNamestr
579589
},
580590
...(props.lambdaProvisionedConcurrency && {
581591
currentVersionOptions: {

src/patterns/gen-ai/aws-rag-appsync-stepfn-opensearch/index.ts

+1-11
Original file line numberDiff line numberDiff line change
@@ -555,19 +555,9 @@ export class RagAppsyncStepfnOpensearch extends Construct {
555555
s3_transformer_job_function_role.addToPolicy(new iam.PolicyStatement({
556556
effect: iam.Effect.ALLOW,
557557
actions: [
558-
'rekognition:CompareFaces',
559-
'rekognition:DetectFaces',
560-
'rekognition:DetectLabels',
561-
'rekognition:ListFaces',
562-
'rekognition:SearchFaces',
563-
'rekognition:SearchFacesByImage',
564-
'rekognition:DetectText',
565-
'rekognition:GetCelebrityInfo',
566-
'rekognition:RecognizeCelebrities',
567558
'rekognition:DetectModerationLabels',
568559
],
569-
//TODO: change the resource to specific arn
570-
resources: ['*'],
560+
resources: ['arn:'+ Aws.PARTITION +':rekognition:' + Aws.ACCOUNT_ID + ':project/*' ],
571561
}));
572562

573563
s3_transformer_job_function_role.addToPolicy(new iam.PolicyStatement({

0 commit comments

Comments
 (0)