Skip to content

Commit 2a3e9bf

Browse files
author
Dinesh Sajwan
committed
feat(cosntruct): updated documentation
1 parent 8a3a83b commit 2a3e9bf

File tree

5 files changed

+204
-105
lines changed

5 files changed

+204
-105
lines changed

lambda/aws-contentgen-appsync-lambda/src/image_generator.py

+72-34
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import requests as reqs
33
from aws_lambda_powertools import Logger, Tracer, Metrics
44
from requests_aws4auth import AWS4Auth
5-
5+
from datetime import datetime
66

77
logger = Logger(service="IMAGE_GENERATION")
88
tracer = Tracer(service="IMAGE_GENERATION")
@@ -21,12 +21,14 @@
2121
)
2222

2323
class image_generator():
24-
"""Generate image ."""
24+
"""Generate Image based on consfigured modelId .
25+
Implements text and omage moderation with Amazon Rekognition and
26+
Amzon Comprehend.
27+
"""
2528

26-
def __init__(self,input_text,file_name, rekognition_client,comprehend_client,bedrock_client,bucket):
29+
def __init__(self,input_text, rekognition_client,comprehend_client,bedrock_client,bucket):
2730
"""Initialize with bucket , key and rekognition_client."""
2831

29-
self.file_name = file_name
3032
self.rekognition_client = rekognition_client
3133
self.comprehend_client = comprehend_client
3234
self.input_text =input_text
@@ -36,19 +38,21 @@ def __init__(self,input_text,file_name, rekognition_client,comprehend_client,bed
3638

3739

3840
@tracer.capture_method
39-
def upload_file_to_s3(self,imgbase64encoded):
41+
def upload_file_to_s3(self,imgbase64encoded,file_name):
4042

4143
"""Upload generated file to S3 bucket"""
42-
43-
logger.info(f"uploading file to s3 bucket: {self.bucket}, key: {self.file_name}")
44+
45+
logger.info(f"uploading file to s3 bucket: {self.bucket}, key: {file_name}")
46+
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f")
47+
upload_file_name=file_name+current_datetime+".jpg"
4448
try:
45-
respImg= s3.Object(self.bucket, self.file_name).put(Body=base64.b64decode(imgbase64encoded))
49+
respImg= s3.Object(self.bucket, upload_file_name).put(Body=base64.b64decode(imgbase64encoded))
4650

4751
except Exception as e:
4852
logger.error(f"Error occured :: {e}")
4953
return False
5054
return {
51-
"file_name":self.file_name,
55+
"file_name":upload_file_name,
5256
"bucket_name":self.bucket,
5357
}
5458

@@ -81,7 +85,7 @@ def text_moderation(self):
8185
return response
8286

8387
@tracer.capture_method
84-
def image_moderation(self):
88+
def image_moderation(self,file_name):
8589

8690
"""Detect image moderation on the generated image to avoid any toxicity/nudity"""
8791

@@ -94,7 +98,7 @@ def image_moderation(self):
9498
Image={
9599
'S3Object':{
96100
'Bucket':self.bucket,
97-
'Name':self.file_name}
101+
'Name':file_name}
98102
}
99103
)
100104
for label in rekognition_response['ModerationLabels']:
@@ -109,47 +113,33 @@ def image_moderation(self):
109113
def generate_image(self,input_params):
110114

111115
"""Generate image using Using bedrock with configured modelid and params"""
116+
112117

113118
input_text=self.input_text
114-
119+
print(f' input_params :: {input_params}')
115120
# add default negative prompts
116-
if 'negative_prompts' in input_params:
121+
if 'negative_prompts' in input_params and input_params['negative_prompts'] is None:
117122
sample_string_bytes = base64.b64decode(input_params['negative_prompts'])
118123
decoded_negative_prompts = sample_string_bytes.decode("utf-8")
119124
logger.info(f"decoded negative prompts are :: {decoded_negative_prompts}")
120125
negative_prompts= decoded_negative_prompts
121126
else:
122-
negative_prompts= ["poorly rendered","poor background details"]
123-
127+
negative_prompts= ["poorly rendered","poor background details","poorly drawn mountains","disfigured mountain features"]
128+
124129
model_id=input_params['model_config']['modelId']
125130

126131
model_kwargs=input_params['model_config']['model_kwargs']
127132
params= get_inference_parameters(model_kwargs)
128133

129-
logger.info(f'SD params :: {params}')
130134

131-
132-
request = json.dumps({
133-
"text_prompts": (
134-
[{"text": input_text, "weight": 1.0}]
135-
+ [{"text": negprompt, "weight": -1.0} for negprompt in negative_prompts]
136-
),
137-
"cfg_scale":params['cfg_scale'],
138-
"seed": params['seed'],
139-
"steps": params['steps'],
140-
"style_preset": params['style_preset'],
141-
"clip_guidance_preset": params['clip_guidance_preset'],
142-
"sampler": params['sampler'],
143-
"width": params['width'],
144-
"height": params['height']
145-
})
146-
135+
body=get_model_payload(model_id,params,input_text,negative_prompts)
136+
print(f' body :: {body}')
147137
try:
148138
return self.bedrock_client.invoke_model(
149139
modelId= model_id,
150140
contentType= "application/json",
151141
accept= "application/json",
152-
body=request
142+
body=body
153143
)
154144
except Exception as e:
155145
logger.error(f"Error occured during generating image:: {e}")
@@ -198,7 +188,44 @@ def send_job_status(self,variables):
198188
)
199189
logger.info('res :: {}',responseJobstatus)
200190

191+
def get_model_payload(modelid,params,input_text,negative_prompts):
192+
193+
body=''
194+
if modelid=='stability.stable-diffusion-xl' :
195+
body = json.dumps({
196+
"text_prompts": (
197+
[{"text": input_text, "weight": 1.0}]
198+
+ [{"text": negprompt, "weight": -1.0} for negprompt in negative_prompts]
199+
),
200+
"cfg_scale":params['cfg_scale'],
201+
"seed": params['seed'],
202+
"steps": params['steps'],
203+
"style_preset": params['style_preset'],
204+
"clip_guidance_preset": params['clip_guidance_preset'],
205+
"sampler": params['sampler'],
206+
"width": params['width'],
207+
"height": params['height']
208+
})
209+
return body
210+
if modelid=='amazon.titan-image-generator-v1' :
201211

212+
body = json.dumps({
213+
"taskType": "TEXT_IMAGE",
214+
"textToImageParams": {
215+
"text": input_text,
216+
#"negativeText": negative_prompts
217+
},
218+
"imageGenerationConfig": {
219+
"numberOfImages": params['numberOfImages'],
220+
"quality":params['quality'],
221+
"height": params['height'],
222+
"width": params['width'],
223+
"cfgScale": params['cfg_scale'],
224+
"seed": params['seed']
225+
}
226+
})
227+
return body
228+
202229
def get_inference_parameters(model_kwargs):
203230
""" Read inference parameters and set default values"""
204231
if 'seed' in model_kwargs:
@@ -233,6 +260,15 @@ def get_inference_parameters(model_kwargs):
233260
sampler= model_kwargs['sampler']
234261
else:
235262
sampler='K_DPMPP_2S_ANCESTRAL'
263+
if 'numberOfImages' in model_kwargs:
264+
numberOfImages= model_kwargs['numberOfImages']
265+
else:
266+
numberOfImages=1
267+
if 'quality' in model_kwargs:
268+
quality= model_kwargs['quality']
269+
else:
270+
quality="standard"
271+
236272
return {
237273
"cfg_scale": cfg_scale,
238274
"seed": seed,
@@ -242,4 +278,6 @@ def get_inference_parameters(model_kwargs):
242278
"sampler": sampler,
243279
"width": width,
244280
"height": height,
245-
}
281+
"numberOfImages": numberOfImages,
282+
"quality": quality
283+
}

lambda/aws-contentgen-appsync-lambda/src/lambda.py

+55-35
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from aws_lambda_powertools.utilities.typing import LambdaContext
2121
from aws_lambda_powertools.metrics import MetricUnit
2222
from aws_lambda_powertools.utilities.validation import validate, SchemaValidationError
23-
from datetime import datetime
23+
2424

2525

2626
logger = Logger(service="IMAGE_GENERATION")
@@ -30,7 +30,6 @@
3030
aws_region = boto3.Session().region_name
3131
bucket = os.environ['OUTPUT_BUCKET']
3232

33-
3433
bedrock_client = boto3.client('bedrock-runtime')
3534
rekognition_client=boto3.client('rekognition')
3635
comprehend_client=boto3.client('comprehend', region_name=aws_region)
@@ -50,17 +49,18 @@ def handler(event, context: LambdaContext) -> dict:
5049
logger.info(f"event is {event}")
5150

5251
input_params=event['detail']['imageInput']
53-
#image_name="public/"+input_params['filename']
5452
input_text=input_params['input_text']
53+
model_id=input_params['model_config']['modelId']
5554

56-
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
57-
file_name="generatedimage_"+current_datetime
55+
# add timestamp to file name when uploading it to s3
56+
file_name="generatedimage_"
5857

5958
response={
6059
"filename":file_name,
6160
"generatedImagePath":"",
61+
"status":"Generating",
62+
"image_path":bucket,
6263
"input_text":input_params['input_text'],
63-
"generateImageJobStatus":"Failed",
6464
"jobid":input_params["jobid"],
6565
"message":''
6666
}
@@ -72,54 +72,74 @@ def handler(event, context: LambdaContext) -> dict:
7272
response["message"]="Input text is empty."
7373
return response
7474

75-
img = image_generator(decoded_input_text,file_name, rekognition_client,comprehend_client,bedrock_client,bucket)
75+
img = image_generator(decoded_input_text, rekognition_client,comprehend_client,bedrock_client,bucket)
7676

7777
text_moderation_response=img.text_moderation()
7878
if(text_moderation_response['isToxic']==True):
7979
response["message"]="In appropriate input prompt. Please change the prompt."
80+
response["status"]='Blocked'
8081
else:
8182
bedrock_response = img.generate_image(input_params)
82-
parsed_reponse = parse_response(bedrock_response)
83+
parsed_reponse = parse_response(bedrock_response,model_id)
8384
if(parsed_reponse['image_generated_status']=='Failed'):
84-
response["message"]="No image generated by bedrock API, Please check the prompt"
85+
response["message"]="No image generated by bedrock API, Please check the prompt."
86+
response["status"]='Blocked'
8587
else:
86-
imgbase64encoded= parsed_reponse['image_generated']
87-
imageGenerated=img.upload_file_to_s3(imgbase64encoded)
88-
89-
image_moderation_response=img.image_moderation()
90-
if(image_moderation_response['isToxic']==True):
91-
response["message"]="In-appropriate image generated."
92-
else:
93-
response={
94-
"filename":file_name,
95-
"image_path":bucket,
96-
"input_text":decoded_input_text,
97-
"status":"Completed",
98-
"jobid":input_params["jobid"],
99-
"message":"Image generated successfully"
100-
}
101-
102-
print (f"response :: {response}")
103-
img.send_job_status(response)
88+
num_of_images=0 #if multiple image geneated iterate through all
89+
for image in parsed_reponse['image_generated']:
90+
print(f'num_of_images {num_of_images}')
91+
if model_id=='stability.stable-diffusion-xl' :
92+
imgbase64encoded= parsed_reponse['image_generated'][num_of_images]["base64"]
93+
if model_id=='amazon.titan-image-generator-v1' :
94+
imgbase64encoded= parsed_reponse['image_generated'][num_of_images]
95+
imageGenerated=img.upload_file_to_s3(imgbase64encoded,file_name)
96+
num_of_images=+1
97+
uploaded_file_name=imageGenerated['file_name']
98+
image_moderation_response=img.image_moderation(uploaded_file_name)
99+
if(image_moderation_response['isToxic']==True):
100+
response["message"]="In-appropriate image generated."
101+
response["status"]='Blocked'
102+
else:
103+
response={
104+
"filename":uploaded_file_name,
105+
"generatedImagePath":"",
106+
"status":"Completed",
107+
"image_path":bucket,
108+
"input_text":input_params['input_text'],
109+
"jobid":input_params["jobid"],
110+
"message":'Image generated successfully'
111+
}
112+
print (f"response :: {response}")
113+
img.send_job_status(response)
104114

105115
return response
106116

107117

108-
def parse_response(query_response):
118+
def parse_response(query_response,model_id):
109119
"""Parse response and return generated image and the prompt"""
110-
print(f'query_response:: {query_response}')
111-
if(not query_response):
112-
parsed_reponse['image_generated_status']='Failed'
113-
response_dict = json.loads(query_response["body"].read())
114120
parsed_reponse={
115121
"image_generated":'',
116122
"image_generated_status":'Success'
117123
}
118-
if(response_dict['artifacts'] is None):
124+
if(not query_response):
119125
parsed_reponse['image_generated_status']='Failed'
126+
return parsed_reponse
120127
else:
121-
parsed_reponse['image_generated']=response_dict['artifacts'][0]["base64"]
122-
return parsed_reponse
128+
response_dict = json.loads(query_response["body"].read())
123129

130+
if model_id=='stability.stable-diffusion-xl' :
124131

132+
if(response_dict['artifacts'] is None):
133+
parsed_reponse['image_generated_status']='Failed'
134+
else:
135+
parsed_reponse['image_generated']=response_dict['artifacts']
136+
137+
if model_id=='amazon.titan-image-generator-v1' :
138+
if(response_dict['images'] is None):
139+
parsed_reponse['image_generated_status']='Failed'
140+
else:
141+
numiofimages=response_dict['images']
142+
print(f' number of images ::{len(numiofimages)}')
143+
parsed_reponse['image_generated']=response_dict['images']
125144

145+
return parsed_reponse
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
aws-lambda-powertools
22
aws-xray-sdk
33
fastjsonschema
4-
typing-extensions
5-
aiohttp
64
boto3>=1.28.69
75
botocore>=1.31.69
86
requests==2.31.0
97
requests-aws4auth==1.2.3
10-
opensearch-py==2.3.1
11-
numpy
12-
langchain==0.0.329
13-
opensearch-py
8+

0 commit comments

Comments
 (0)