Skip to content

Commit 5be73c9

Browse files
committed
update webui.py
1 parent 78b3a69 commit 5be73c9

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

modules/api/api.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
137137
self.cache = dict()
138138
self.s3_client = boto3.client('s3')
139139
self.s3_resource= boto3.resource('s3')
140+
self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None)
140141

141142
def add_api_route(self, path: str, endpoint, **kwargs):
142143
if shared.cmd_opts.api_auth:
@@ -399,6 +400,25 @@ def download_s3files(self, s3uri, path):
399400

400401
json.dump(self.cache, open('cache', 'w'))
401402

403+
def post_invocations(self, b64images):
404+
if self.generated_images_s3uri:
405+
bucket, key = self.get_bucket_and_key(self.generated_images_s3uri)
406+
images = []
407+
for b64image in b64images:
408+
image = decode_base64_to_image(b64image).convert('RGB')
409+
output = io.BytesIO()
410+
image.save(output, format='JPEG')
411+
image_id = str(uuid.uuid4())
412+
self.s3_client.put_object(
413+
Body=output.getvalue(),
414+
Bucket=bucket,
415+
Key=f'{key}/{image_id}.jpg'
416+
)
417+
images.append(f's3://{bucket}/{key}/{image_id}.jpg')
418+
return images
419+
else:
420+
return b64images
421+
402422
def invocations(self, req: InvocationsRequest):
403423
print('-------invocation------')
404424
print(req)
@@ -433,24 +453,26 @@ def invocations(self, req: InvocationsRequest):
433453
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
434454
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
435455
response = self.text2imgapi(req.txt2img_payload)
456+
response.images = self.post_invocations(response.images)
436457
shared.opts.data = default_options
437458
return response
438459
elif req.task == 'image-to-image':
439460
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
440461
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
441462
response = self.img2imgapi(req.img2img_payload)
463+
response.images = self.post_invocations(response.images)
442464
shared.opts.data = default_options
443465
return response
444466
elif req.task == 'extras-single-image':
445467
response = self.extras_single_image_api(req.extras_single_payload)
468+
response.image = self.post_invocations([response.image])[0]
446469
shared.opts.data = default_options
447470
return response
448471
elif req.task == 'extras-batch-images':
449472
response = self.extras_batch_images_api(req.extras_batch_payload)
473+
response.images = self.post_invocations(response.images)
450474
shared.opts.data = default_options
451475
return response
452-
elif req.task == 'sd-models':
453-
return self.get_sd_models()
454476
else:
455477
raise NotImplementedError
456478
except Exception as e:
@@ -463,3 +485,9 @@ def ping(self):
463485
def launch(self, server_name, port):
464486
self.app.include_router(self.router)
465487
uvicorn.run(self.app, host=server_name, port=port)
488+
489+
def get_bucket_and_key(self, s3uri):
490+
pos = s3uri.find('/', 5)
491+
bucket = s3uri[5 : pos]
492+
key = s3uri[pos + 1 : ]
493+
return bucket, key

0 commit comments

Comments
 (0)