
Commit 9ee580e

Author: Dinesh Sajwan (committed)
feat(visualqa): merged dataingestion changes and updated image transformer
1 parent 10f556f, commit 9ee580e

File tree: 12 files changed, +151 -32 lines changed

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py (+2 -2)

@@ -50,9 +50,9 @@ def get_llm(callbacks=None):
 
         return Bedrock(**kwargs)
 
-def get_embeddings_llm():
+def get_embeddings_llm(model_id):
     bedrock = boto3.client('bedrock-runtime')
-    return BedrockEmbeddings(client=bedrock, model_id="amazon.titan-embed-text-v1")
+    return BedrockEmbeddings(client=bedrock, model_id=model_id)
 
 def get_max_tokens():
     return 200000
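
A minimal usage sketch (not part of the commit), assuming the import path matches the module location and the payload shape read in chain.py below; the embeddings model id now flows in from the request instead of being hard-coded:

from llms.text_generation_llm_selector import get_embeddings_llm

input_params = {"embeddings_model": {"modelId": "amazon.titan-embed-text-v1"}}  # hypothetical payload
embeddings = get_embeddings_llm(input_params["embeddings_model"]["modelId"])
vector = embeddings.embed_query("What does the report conclude?")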

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py (+7 -4)

@@ -98,10 +98,11 @@ def run_question_answering(arguments):
 _doc_index = None
 _current_doc_index = None
 def run_qa_agent_rag_no_memory(input_params):
-    logger.info("starting qa agent with rag approach without memory")
+    logger.info("starting qa agent with rag approach without memory :: {input_params}")
 
     base64_bytes = input_params['question'].encode("utf-8")
-
+    model_id = input_params['embeddings_model']['modelId']
+    print(f'model id :: {model_id}')
     sample_string_bytes = base64.b64decode(base64_bytes)
     decoded_question = sample_string_bytes.decode("utf-8")
 

@@ -127,18 +128,20 @@ def run_qa_agent_rag_no_memory(input_params):
                                                os.environ.get('OPENSEARCH_API_NAME'),
                                                os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
                                                os.environ.get('OPENSEARCH_INDEX'),
-                                               os.environ.get('OPENSEARCH_SECRET_ID'))
+                                               os.environ.get('OPENSEARCH_SECRET_ID'),
+                                               model_id)
 
     else:
         logger.info("_retriever already exists")
 
     _current_doc_index = _doc_index
 
     logger.info("Starting similarity search")
-    max_docs = input_params['max_docs']
+    max_docs = input_params['retrieval']['max_docs']
     output_file_name = input_params['filename']
 
     source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
+    logger.info(source_documents)
     # --------------------------------------------------------------------------
     # If an output file is specified, filter the response to only include chunks
     # related to that file. The source metadata is added when embeddings are
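
For reference, a sketch of the request payload these lookups assume; the field names come from the keys accessed above, while the values are hypothetical. (Note the new log line is a plain string rather than an f-string, so it prints the literal {input_params} placeholder.)

# Hypothetical event payload for run_qa_agent_rag_no_memory after this change.
input_params = {
    "question": "V2hhdCBkb2VzIHRoZSBkaWFncmFtIHNob3c/",  # base64 for "What does the diagram show?"
    "filename": "architecture.png",
    "embeddings_model": {"modelId": "amazon.titan-embed-text-v1"},
    "retrieval": {"max_docs": 5},
}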

lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/helper.py (+3 -2)

@@ -90,7 +90,8 @@ def load_vector_db_opensearch(region: str,
                               opensearch_api_name: str,
                               opensearch_domain_endpoint: str,
                               opensearch_index: str,
-                              secret_id: str) -> OpenSearchVectorSearch:
+                              secret_id: str,
+                              model_id: str) -> OpenSearchVectorSearch:
     print(f"load_vector_db_opensearch, region={region}, "
           f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}")
 

@@ -107,7 +108,7 @@ def load_vector_db_opensearch(region: str,
         opensearch_api_name,
         session_token=credentials.token,
     )
-    embedding_function = get_embeddings_llm()
+    embedding_function = get_embeddings_llm(model_id)
 
     opensearch_url = opensearch_domain_endpoint if opensearch_domain_endpoint.startswith("https://") else f"https://{opensearch_domain_endpoint}"
     vector_db = OpenSearchVectorSearch(index_name=opensearch_index,

lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/helpers/image_loader.py (+4 -4)

@@ -36,15 +36,15 @@ def __init__(self, bucket: str, image_file: str,image_detail_file: str):
         self.bucket = bucket
         self.image_file = image_file
         self.image_detail_file = image_detail_file
-
+        print(f"load image {image_file}, and image txt {image_detail_file} from :: {bucket}")
+
 
 
     @tracer.capture_method
     def load(self):
         """Load documents."""
         try:
             local_file_path = self.download_file(self.image_file)
-            print(f"file downloaded :: {local_file_path}")
 
             with open(f"{local_file_path}", "rb") as image_file:
                 input_image = base64.b64encode(image_file.read()).decode("utf8")

@@ -57,9 +57,9 @@ def load(self):
 
             docs = json.dumps({
                 "inputImage": input_image,
-                #"inputText": raw_text,
+                "inputText": raw_text,
             })
-            #print(f'docs for titan embeddings {docs}')
+            print(f'raw_text for titan embeddings {raw_text}')
             return [Document(page_content=docs, metadata=metadata)]
 
         except Exception as exception:
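
With inputText enabled, each image document now carries both the base64-encoded image and the Rekognition-derived text. A sketch of how a payload with this shape maps onto a Bedrock embedding call, assuming Amazon Titan Multimodal Embeddings is the target model (the commit itself does not name it):

import base64
import json
import boto3

bedrock = boto3.client("bedrock-runtime")

with open("architecture-resized.png", "rb") as f:  # hypothetical local image
    input_image = base64.b64encode(f.read()).decode("utf8")
raw_text = "A person working at a desk with a laptop."  # hypothetical label summary

body = json.dumps({"inputImage": input_image, "inputText": raw_text})
response = bedrock.invoke_model(
    body=body,
    modelId="amazon.titan-embed-image-v1",  # assumed embedding model id
    accept="application/json",
    contentType="application/json",
)
embedding = json.loads(response["body"].read())["embedding"]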

lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/helpers/update_ingestion_status.py (+2)

@@ -44,6 +44,7 @@ def updateIngestionJobStatus(variables):
             files {
                 name
                 status
+                imageurl
             }
             ingestionjobid
         }

@@ -54,6 +55,7 @@ def updateIngestionJobStatus(variables):
     query = query.replace("$files", str(variables['files']).replace("\'", "\""))
     query = query.replace("\"name\"", "name")
     query = query.replace("\"status\"", "status")
+    query = query.replace("\"imageurl\"", "imageurl")
 
     request = {'query':query}
 
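
To illustrate what the added replacement does (file values are hypothetical): the files list is serialized as a Python literal, quotes are normalized, and the quotes around the field names name, status and imageurl are stripped so they render as GraphQL fields rather than string literals.

files = [{"name": "diagram.png", "status": "succeed", "imageurl": "https://example.com/presigned-url"}]
rendered = str(files).replace("'", '"')
for field in ("name", "status", "imageurl"):
    rendered = rendered.replace(f'"{field}"', field)
print(rendered)
# [{name: "diagram.png", status: "succeed", imageurl: "https://example.com/presigned-url"}]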

lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py (+4 -7)

@@ -18,10 +18,8 @@
 import numpy as np
 import tempfile
 from helpers.credentials_helper import get_credentials
-from helpers.csv_loader import csv_loader
 from helpers.image_loader import image_loader
-from helpers.msdoc_loader import msdoc_loader
-from helpers.html_loader import html_loader
+
 
 from helpers.opensearch_helper import check_if_index_exists, process_shard, create_index_for_image
 from helpers.update_ingestion_status import updateIngestionJobStatus

@@ -198,17 +196,16 @@ def handler(event, context: LambdaContext) -> dict:
             for doc in sub_docs:
                 doc.metadata['source'] = filename
             docs.extend(sub_docs)
-        if(extension == '.jpg' or extension == '.jpeg' or extension == '.png'):
+        if(extension == '.jpg' or extension == '.jpeg' or extension == '.png' or extension == '.svg'):
             # Try adding text to document
             #image_detal_file is created by aws rekognition
-            img_load = image_loader(bucket_name, f"{name}-resized.png",f"{name}.txt")
+            img_load = image_loader(bucket_name, f"{name}-resized{extension}",f"{name}.txt")
             sub_docs = img_load.load()
             for doc in sub_docs:
                 doc.metadata['source'] = filename
             docs.extend(sub_docs)
             url = img_load.get_presigned_url()
-            print(f" url set :: {url} ")
-            print(f" prepare os object ")
+            print(f" source :: {filename} ")
             os_document = img_load.prepare_document_for_direct_load()
 
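
A small illustration of the new naming convention (the key is hypothetical): the loader now expects the resized image to keep its original extension rather than always ending in -resized.png, with the Rekognition text file alongside it.

import os

name, extension = os.path.splitext("diagrams/architecture.jpeg")
resized_key = f"{name}-resized{extension}"  # "diagrams/architecture-resized.jpeg" (previously "...-resized.png")
detail_key = f"{name}.txt"                  # "diagrams/architecture.txt", written by the Rekognition step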

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/helpers/image_transformer.py (+5 -3)

@@ -78,15 +78,17 @@ def check_moderation(self)-> str:
 
     @tracer.capture_method
     def detect_image_lables(self)-> str:
         try:
-            labels=[]
+            labels=''
             response = self.rekognition_client.detect_labels(Image=self.image,MaxLabels=20 )
             for label in response['Labels']:
-                print(label)
+                name = label['Name']
                 if(label['Confidence'] > 0.80):
-                    labels.append(label['Name'])
+                    labels = labels + label['Name'] + ","
         except Exception as exp:
             print(f"Couldn't analyze image: {exp}")
         return labels
+
+
 
 
     @tracer.capture_method
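
For comparison, a sketch of an equivalent way to build the comma-separated label string outside the class (bucket and key are placeholders). One caveat worth flagging: Rekognition reports Confidence on a 0-100 scale, so the 0.80 threshold above keeps practically every label; 80 is presumably the intended cut-off.

import boto3

rekognition = boto3.client("rekognition")
response = rekognition.detect_labels(
    Image={"S3Object": {"Bucket": "my-bucket", "Name": "photo.jpg"}},  # placeholder image reference
    MaxLabels=20,
)
names = [label["Name"] for label in response["Labels"] if label["Confidence"] > 80]
labels = ",".join(names)  # same comma-separated result, without the trailing comma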
lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/helpers/pdf_transformer.py (new file, +48)

@@ -0,0 +1,48 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+from typing import List
+
+from langchain.document_loaders.base import BaseLoader
+from helpers.s3inmemoryloader import S3FileLoaderInMemory
+
+from aws_lambda_powertools import Logger, Tracer
+from PyPDF2 import PdfReader
+from io import BytesIO
+
+
+logger = Logger(service="INGESTION_FILE_TRANSFORMER")
+tracer = Tracer(service="INGESTION_FILE_TRANSFORMER")
+
+@tracer.capture_method
+class pdf_transformer(BaseLoader):
+    """Transforming logic for pdf documents from s3 ."""
+
+    def __init__(self, bucket: str, key: str):
+        """Initialize with bucket and key name."""
+        self.bucket = bucket
+        self.key = key
+
+    def load(self) -> str:
+        """Load documents."""
+        try:
+            # TODO: add transformation logic
+            encodedpdf = S3FileLoaderInMemory(self.bucket, self.key).load
+            pdfFile = PdfReader(BytesIO(encodedpdf))
+            raw_text = []
+            for page in pdfFile.pages:
+                raw_text.append(page.extract_text())
+            return '\n'.join(raw_text)
+        except Exception as exception:
+            logger.exception(f"Reason: {exception}")
+            return ""
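
A usage sketch for the new transformer (bucket and key are hypothetical). One thing to watch: S3FileLoaderInMemory(self.bucket, self.key).load as written references the bound method instead of calling it, so PdfReader receives a method object rather than the PDF bytes; .load() looks like the intent.

from helpers.pdf_transformer import pdf_transformer

transformer = pdf_transformer("my-input-bucket", "docs/whitepaper.pdf")
raw_text = transformer.load()  # page texts joined with newlines, or "" if extraction fails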

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/helpers/utils.py (+66 -8)

@@ -5,7 +5,9 @@
 from aws_lambda_powertools import Logger, Tracer, Metrics
 from aws_lambda_powertools.metrics import MetricUnit
 from helpers.image_transformer import image_transformer
+from helpers.pdf_transformer import pdf_transformer
 from botocore.exceptions import ClientError
+from langchain_core.prompts import PromptTemplate
 
 
 

@@ -18,14 +20,22 @@
 
 @tracer.capture_method
 def isvalid_file_format(file_name: str) -> bool:
-    file_format = ['.pdf','.txt','.jpg','.png','.csv','.docx','.ppt','.html','.jpeg']
+    file_format = ['.pdf','.txt','.jpg','.png','.jpeg','.svg']
     if file_name.endswith(tuple(file_format)):
         return True
     else:
         print(f'Invalid file format :: {file_format}')
         return False
 
-
+@tracer.capture_method
+def transform_pdf_document(input_bucket: str,file_name: str,output_bucket: str,output_file_name:str):
+    document_content = pdf_transformer(input_bucket,file_name)
+    if not document_content:
+        return 'Unable to load document'
+    else:
+        encoded_string = document_content.encode("utf-8")
+        s3.Bucket(output_bucket).put_object(Key=output_file_name, Body=encoded_string)
+        return 'File transformed'
 
 @tracer.capture_method
 def transform_image_document(input_bucket: str,file_name: str,output_bucket: str):

@@ -40,19 +50,67 @@ def transform_image_document(input_bucket: str,file_name: str,output_bucket: str
     image_details = {
         "image_lables":result_lables,
         "image_celeb":result_celeb
-    }
+        }
+
     name, extension = os.path.splitext(file_name)
+
+    lables_txt= convert_lables_to_sentence(result_lables)
+    # with open ('/tmp/'+name+'.txt','w') as f:
+    #     f.write(json.dumps(image_details))
+    # checking with senetence, save the senetence instead of lables
+
     with open ('/tmp/'+name+'.txt','w') as f:
-        f.write(json.dumps(image_details))
+        f.write(json.dumps(lables_txt))
+
     s3.upload_file('/tmp/'+name+'.txt',output_bucket,name+".txt")
     downloaded_file = download_file(input_bucket,file_name)
     print(f'downloaded_file:: {downloaded_file}')
 
     resize_image = imt.image_resize()
-    upload_file(output_bucket,resize_image)
+    upload_file(output_bucket,resize_image,file_name)
     #upload_file(output_bucket,file_name)
     return 'File transformed'
 
+
+@tracer.capture_method
+def convert_lables_to_sentence(labels_str)-> str:
+    try:
+        print(f"lables:: {labels_str}")
+        bedrock_client = boto3.client('bedrock-runtime')
+
+        prompt ="""\n\nHuman: Here are the comma seperated list of labels seen in the image:
+        <labels>
+        {labels}
+        </labels>
+        Please provide a human readable and understandable summary based on these labels
+        \n\nAssistant:"""
+
+
+        prompt_template = PromptTemplate.from_template(prompt)
+        prompt_template_for_lables = prompt_template.format(labels=labels_str)
+
+        body = json.dumps({"prompt": prompt_template_for_lables,
+                           "max_tokens_to_sample":300,
+                           "temperature":1,
+                           "top_k":250,
+                           "top_p":0.999,
+                           "stop_sequences":[]
+                           })
+        modelId = 'anthropic.claude-v2'
+        accept = 'application/json'
+        contentType = 'application/json'
+
+        response = bedrock_client.invoke_model(body=body,
+                                               modelId=modelId, accept=accept, contentType=contentType)
+        response_body = json.loads(response.get('body').read())
+        response_text_claud = response_body.get('completion')
+        print(f"response_text_claud:: {response_text_claud}")
+        return response_text_claud
+    except Exception as exp:
+        print(f"Couldn't convert lables to sentence: {exp}")
+
+
+
 def download_file(bucket, object )-> str:
     try:
         file_path = "/tmp/" + os.path.basename(object)

@@ -64,10 +122,10 @@ def download_file(bucket, object )-> str:
     except Exception as exp:
         print(f"Couldn\'t download file : {exp}")
 
-def upload_file(bucket, object )-> str:
+def upload_file(bucket, file_name,key )-> str:
     try:
-        file_path = "/tmp/" + os.path.basename(object)
-        s3.upload_file(file_path, bucket,object)
+        file_path = "/tmp/" + os.path.basename(file_name)
+        s3.upload_file(file_path, bucket,key)
         return file_path
     except ClientError as client_err:
         print(f"Couldn\'t download file {client_err.response['Error']['Message']}")

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/lambda.py (+1 -1)

@@ -19,7 +19,7 @@
 from aws_lambda_powertools import Logger, Tracer, Metrics
 from aws_lambda_powertools.utilities.typing import LambdaContext
 from aws_lambda_powertools.metrics import MetricUnit
-from helpers.utils import isvalid_file_format,transform_csv_document,transform_pdf_document,transform_msdoc_document_file,transform_image_document
+from helpers.utils import isvalid_file_format,transform_pdf_document,transform_image_document
 
 
 

lambda/aws-rag-appsync-stepfn-opensearch/s3_file_transformer/src/requirements.txt (+1 -1)

@@ -2,7 +2,7 @@ aws-lambda-powertools
 aws-xray-sdk
 fastjsonschema
 typing-extensions
-boto3
+boto3>=1.34.29
 requests
 langchain==0.1.4
 pypdf2

src/patterns/gen-ai/aws-rag-appsync-stepfn-opensearch/index.ts (+8)

@@ -521,6 +521,14 @@ export class RagAppsyncStepfnOpensearch extends Construct {
         resources: ['*'],
       }));
 
+    s3_transformer_job_function_role.addToPolicy(new iam.PolicyStatement({
+      effect: iam.Effect.ALLOW,
+      actions: ['bedrock:*'],
+      resources: [
+        'arn:' + Aws.PARTITION + ':bedrock:' + Aws.REGION + '::foundation-model',
+        'arn:' + Aws.PARTITION + ':bedrock:' + Aws.REGION + '::foundation-model/*',
+      ],
+    }));
 
     s3_transformer_job_function_role.addToPolicy(
       new iam.PolicyStatement({
