@@ -137,6 +137,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
137
137
self .cache = dict ()
138
138
self .s3_client = boto3 .client ('s3' )
139
139
self .s3_resource = boto3 .resource ('s3' )
140
+ self .generated_images_s3uri = os .environ .get ('generated_images_s3uri' , None )
140
141
141
142
def add_api_route (self , path : str , endpoint , ** kwargs ):
142
143
if shared .cmd_opts .api_auth :
@@ -399,6 +400,25 @@ def download_s3files(self, s3uri, path):
399
400
400
401
json .dump (self .cache , open ('cache' , 'w' ))
401
402
403
+ def post_invocations (self , b64images ):
404
+ if self .generated_images_s3uri :
405
+ bucket , key = self .get_bucket_and_key (self .generated_images_s3uri )
406
+ images = []
407
+ for b64image in b64images :
408
+ image = decode_base64_to_image (b64image ).convert ('RGB' )
409
+ output = io .BytesIO ()
410
+ image .save (output , format = 'JPEG' )
411
+ image_id = str (uuid .uuid4 ())
412
+ self .s3_client .put_object (
413
+ Body = output .getvalue (),
414
+ Bucket = bucket ,
415
+ Key = f'{ key } /{ image_id } .jpg'
416
+ )
417
+ images .append (f's3://{ bucket } /{ key } /{ image_id } .jpg' )
418
+ return images
419
+ else :
420
+ return b64images
421
+
402
422
def invocations (self , req : InvocationsRequest ):
403
423
print ('-------invocation------' )
404
424
print (req )
@@ -433,24 +453,26 @@ def invocations(self, req: InvocationsRequest):
433
453
self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
434
454
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
435
455
response = self .text2imgapi (req .txt2img_payload )
456
+ response .images = self .post_invocations (response .images )
436
457
shared .opts .data = default_options
437
458
return response
438
459
elif req .task == 'image-to-image' :
439
460
self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
440
461
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
441
462
response = self .img2imgapi (req .img2img_payload )
463
+ response .images = self .post_invocations (response .images )
442
464
shared .opts .data = default_options
443
465
return response
444
466
elif req .task == 'extras-single-image' :
445
467
response = self .extras_single_image_api (req .extras_single_payload )
468
+ response .image = self .post_invocations ([response .image ])[0 ]
446
469
shared .opts .data = default_options
447
470
return response
448
471
elif req .task == 'extras-batch-images' :
449
472
response = self .extras_batch_images_api (req .extras_batch_payload )
473
+ response .images = self .post_invocations (response .images )
450
474
shared .opts .data = default_options
451
475
return response
452
- elif req .task == 'sd-models' :
453
- return self .get_sd_models ()
454
476
else :
455
477
raise NotImplementedError
456
478
except Exception as e :
@@ -463,3 +485,9 @@ def ping(self):
463
485
def launch (self , server_name , port ):
464
486
self .app .include_router (self .router )
465
487
uvicorn .run (self .app , host = server_name , port = port )
488
+
489
+ def get_bucket_and_key (self , s3uri ):
490
+ pos = s3uri .find ('/' , 5 )
491
+ bucket = s3uri [5 : pos ]
492
+ key = s3uri [pos + 1 : ]
493
+ return bucket , key
0 commit comments