2
2
import requests as reqs
3
3
from aws_lambda_powertools import Logger , Tracer , Metrics
4
4
from requests_aws4auth import AWS4Auth
5
-
5
+ from datetime import datetime
6
6
7
7
logger = Logger (service = "IMAGE_GENERATION" )
8
8
tracer = Tracer (service = "IMAGE_GENERATION" )
21
21
)
22
22
23
23
class image_generator ():
24
- """Generate image ."""
24
+ """Generate Image based on consfigured modelId .
25
+ Implements text and image moderation with Amazon Rekognition and
26
+ Amazon Comprehend.
27
+ """
25
28
26
- def __init__ (self ,input_text ,file_name , rekognition_client ,comprehend_client ,bedrock_client ,bucket ):
29
+ def __init__ (self ,input_text , rekognition_client ,comprehend_client ,bedrock_client ,bucket ):
27
30
"""Initialize with bucket , key and rekognition_client."""
28
31
29
- self .file_name = file_name
30
32
self .rekognition_client = rekognition_client
31
33
self .comprehend_client = comprehend_client
32
34
self .input_text = input_text
@@ -36,19 +38,21 @@ def __init__(self,input_text,file_name, rekognition_client,comprehend_client,bed
36
38
37
39
38
40
@tracer .capture_method
39
- def upload_file_to_s3 (self ,imgbase64encoded ):
41
+ def upload_file_to_s3 (self ,imgbase64encoded , file_name ):
40
42
41
43
"""Upload generated file to S3 bucket"""
42
-
43
- logger .info (f"uploading file to s3 bucket: { self .bucket } , key: { self .file_name } " )
44
+
45
+ logger .info (f"uploading file to s3 bucket: { self .bucket } , key: { file_name } " )
46
+ current_datetime = datetime .now ().strftime ("%Y-%m-%d_%H:%M:%S.%f" )
47
+ upload_file_name = file_name + current_datetime + ".jpg"
44
48
try :
45
- respImg = s3 .Object (self .bucket , self . file_name ).put (Body = base64 .b64decode (imgbase64encoded ))
49
+ respImg = s3 .Object (self .bucket , upload_file_name ).put (Body = base64 .b64decode (imgbase64encoded ))
46
50
47
51
except Exception as e :
48
52
logger .error (f"Error occured :: { e } " )
49
53
return False
50
54
return {
51
- "file_name" :self . file_name ,
55
+ "file_name" :upload_file_name ,
52
56
"bucket_name" :self .bucket ,
53
57
}
54
58
@@ -81,7 +85,7 @@ def text_moderation(self):
81
85
return response
82
86
83
87
@tracer .capture_method
84
- def image_moderation (self ):
88
+ def image_moderation (self , file_name ):
85
89
86
90
"""Detect image moderation on the generated image to avoid any toxicity/nudity"""
87
91
@@ -94,7 +98,7 @@ def image_moderation(self):
94
98
Image = {
95
99
'S3Object' :{
96
100
'Bucket' :self .bucket ,
97
- 'Name' :self . file_name }
101
+ 'Name' :file_name }
98
102
}
99
103
)
100
104
for label in rekognition_response ['ModerationLabels' ]:
@@ -109,47 +113,33 @@ def image_moderation(self):
109
113
def generate_image (self ,input_params ):
110
114
111
115
"""Generate image using Using bedrock with configured modelid and params"""
116
+
112
117
113
118
input_text = self .input_text
114
-
119
+ print ( f' input_params :: { input_params } ' )
115
120
# add default negative prompts
116
- if 'negative_prompts' in input_params :
121
+ if 'negative_prompts' in input_params and input_params [ 'negative_prompts' ] is None :
117
122
sample_string_bytes = base64 .b64decode (input_params ['negative_prompts' ])
118
123
decoded_negative_prompts = sample_string_bytes .decode ("utf-8" )
119
124
logger .info (f"decoded negative prompts are :: { decoded_negative_prompts } " )
120
125
negative_prompts = decoded_negative_prompts
121
126
else :
122
- negative_prompts = ["poorly rendered" ,"poor background details" ]
123
-
127
+ negative_prompts = ["poorly rendered" ,"poor background details" , "poorly drawn mountains" , "disfigured mountain features" ]
128
+
124
129
model_id = input_params ['model_config' ]['modelId' ]
125
130
126
131
model_kwargs = input_params ['model_config' ]['model_kwargs' ]
127
132
params = get_inference_parameters (model_kwargs )
128
133
129
- logger .info (f'SD params :: { params } ' )
130
134
131
-
132
- request = json .dumps ({
133
- "text_prompts" : (
134
- [{"text" : input_text , "weight" : 1.0 }]
135
- + [{"text" : negprompt , "weight" : - 1.0 } for negprompt in negative_prompts ]
136
- ),
137
- "cfg_scale" :params ['cfg_scale' ],
138
- "seed" : params ['seed' ],
139
- "steps" : params ['steps' ],
140
- "style_preset" : params ['style_preset' ],
141
- "clip_guidance_preset" : params ['clip_guidance_preset' ],
142
- "sampler" : params ['sampler' ],
143
- "width" : params ['width' ],
144
- "height" : params ['height' ]
145
- })
146
-
135
+ body = get_model_payload (model_id ,params ,input_text ,negative_prompts )
136
+ print (f' body :: { body } ' )
147
137
try :
148
138
return self .bedrock_client .invoke_model (
149
139
modelId = model_id ,
150
140
contentType = "application/json" ,
151
141
accept = "application/json" ,
152
- body = request
142
+ body = body
153
143
)
154
144
except Exception as e :
155
145
logger .error (f"Error occured during generating image:: { e } " )
@@ -198,7 +188,44 @@ def send_job_status(self,variables):
198
188
)
199
189
logger .info ('res :: {}' ,responseJobstatus )
200
190
191
def get_model_payload(modelid, params, input_text, negative_prompts):
    """Build the JSON request body for the configured Bedrock image model.

    :param modelid: Bedrock modelId ('stability.stable-diffusion-xl' or
        'amazon.titan-image-generator-v1').
    :param params: inference parameters as produced by get_inference_parameters().
    :param input_text: positive prompt text.
    :param negative_prompts: list of negative prompt strings.
    :return: JSON request string for invoke_model; '' for an unknown modelId.
    """
    if modelid == 'stability.stable-diffusion-xl':
        body = json.dumps({
            "text_prompts": (
                [{"text": input_text, "weight": 1.0}]
                + [{"text": negprompt, "weight": -1.0} for negprompt in negative_prompts]
            ),
            "cfg_scale": params['cfg_scale'],
            "seed": params['seed'],
            "steps": params['steps'],
            "style_preset": params['style_preset'],
            "clip_guidance_preset": params['clip_guidance_preset'],
            "sampler": params['sampler'],
            "width": params['width'],
            "height": params['height']
        })
    elif modelid == 'amazon.titan-image-generator-v1':
        body = json.dumps({
            "taskType": "TEXT_IMAGE",
            "textToImageParams": {
                "text": input_text,
                # "negativeText": negative_prompts
                # NOTE(review): Titan expects a single string here, not a list —
                # left disabled as in the original; join the prompts before enabling.
            },
            "imageGenerationConfig": {
                "numberOfImages": params['numberOfImages'],
                "quality": params['quality'],
                "height": params['height'],
                "width": params['width'],
                "cfgScale": params['cfg_scale'],
                "seed": params['seed']
            }
        })
    else:
        # Bug fix: the original fell off the end and implicitly returned None
        # for an unrecognized modelId even though body was initialized to ''.
        body = ''
    return body
228
+
202
229
def get_inference_parameters (model_kwargs ):
203
230
""" Read inference parameters and set default values"""
204
231
if 'seed' in model_kwargs :
@@ -233,6 +260,15 @@ def get_inference_parameters(model_kwargs):
233
260
sampler = model_kwargs ['sampler' ]
234
261
else :
235
262
sampler = 'K_DPMPP_2S_ANCESTRAL'
263
+ if 'numberOfImages' in model_kwargs :
264
+ numberOfImages = model_kwargs ['numberOfImages' ]
265
+ else :
266
+ numberOfImages = 1
267
+ if 'quality' in model_kwargs :
268
+ quality = model_kwargs ['quality' ]
269
+ else :
270
+ quality = "standard"
271
+
236
272
return {
237
273
"cfg_scale" : cfg_scale ,
238
274
"seed" : seed ,
@@ -242,4 +278,6 @@ def get_inference_parameters(model_kwargs):
242
278
"sampler" : sampler ,
243
279
"width" : width ,
244
280
"height" : height ,
245
- }
281
+ "numberOfImages" : numberOfImages ,
282
+ "quality" : quality
283
+ }
0 commit comments