Skip to content

Commit 2dc849a

Browse files
committed
Merge branch 'api' into api_airi_new_merge
2 parents 1cc9b93 + 63613b8 commit 2dc849a

8 files changed

+472
-158
lines changed

modules/api/api.py

+48-63
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@
3838
import uuid
3939
import os
4040
import json
41-
import boto3
42-
cache = dict()
43-
s3_client = boto3.client('s3')
44-
s3_resource= boto3.resource('s3')
45-
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)
4641

4742
def upscaler_to_index(name: str):
4843
try:
@@ -107,6 +102,35 @@ def encode_pil_to_base64(image):
107102

108103
return base64.b64encode(bytes_data)
109104

105+
def export_pil_to_bytes(image):
106+
with io.BytesIO() as output_bytes:
107+
108+
if opts.samples_format.lower() == 'png':
109+
use_metadata = False
110+
metadata = PngImagePlugin.PngInfo()
111+
for key, value in image.info.items():
112+
if isinstance(key, str) and isinstance(value, str):
113+
metadata.add_text(key, value)
114+
use_metadata = True
115+
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
116+
117+
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
118+
parameters = image.info.get('parameters', None)
119+
exif_bytes = piexif.dump({
120+
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
121+
})
122+
if opts.samples_format.lower() in ("jpg", "jpeg"):
123+
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
124+
else:
125+
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
126+
127+
else:
128+
raise HTTPException(status_code=500, detail="Invalid image format")
129+
130+
bytes_data = output_bytes.getvalue()
131+
132+
return bytes_data
133+
110134
def api_middleware(app: FastAPI):
111135
rich_available = True
112136
try:
@@ -211,7 +235,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
211235
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
212236
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
213237
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
214-
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse,MemoryResponse,List[SDModelItem],List[UpscalerItem],OptionsModel,List[SamplerItem],FlagsModel,ProgressResponse])
238+
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse, InvocationsErrorResponse, InterrogateResponse, MemoryResponse, List[SDModelItem], List[UpscalerItem], OptionsModel, List[SamplerItem], FlagsModel, ProgressResponse])
215239
self.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)
216240

217241
self.default_script_arg_txt2img = []
@@ -715,28 +739,21 @@ def get_memory(self):
715739
return MemoryResponse(ram = ram, cuda = cuda)
716740

717741
def post_invocations(self, b64images, quality):
718-
if generated_images_s3uri:
719-
bucket, key = self.get_bucket_and_key(generated_images_s3uri)
742+
if shared.generated_images_s3uri:
743+
bucket, key = shared.get_bucket_and_key(shared.generated_images_s3uri)
744+
if key.endswith('/'):
745+
key = key[ : -1]
720746
images = []
721747
for b64image in b64images:
722-
image = decode_base64_to_image(b64image).convert('RGB')
723-
output = io.BytesIO()
724-
725-
try:
726-
if not quality:
727-
quality = 95
728-
729-
image.save(output, format='PNG', quality=quality)
730-
except Exception:
731-
image.save(output, format='PNG', quality=95)
732-
733-
image_id = str(uuid.uuid4())
734-
s3_client.put_object(
735-
Body=output.getvalue(),
748+
bytes_data = export_pil_to_bytes(decode_base64_to_image(b64image))
749+
image_id = datetime.datetime.now().strftime(f"%Y%m%d%H%M%S-{uuid.uuid4()}")
750+
suffix = opts.samples_format.lower()
751+
shared.s3_client.put_object(
752+
Body=bytes_data,
736753
Bucket=bucket,
737-
Key=f'{key}/{image_id}.png'
754+
Key=f'{key}/{image_id}.{suffix}'
738755
)
739-
images.append(f's3://{bucket}/{key}/{image_id}.png')
756+
images.append(f's3://{bucket}/{key}/{image_id}.{suffix}')
740757
return images
741758
else:
742759
return b64images
@@ -804,7 +821,7 @@ def invocations(self, req: InvocationsRequest):
804821
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri
805822

806823
if hypernetwork_s3uri !='':
807-
self.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
824+
shared.s3_download(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
808825
shared.reload_hypernetworks()
809826

810827
if req.options != None:
@@ -816,14 +833,14 @@ def invocations(self, req: InvocationsRequest):
816833

817834
if req.task == 'text-to-image':
818835
if embeddings_s3uri != '':
819-
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
836+
shared.s3_download(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
820837
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
821838
response = self.text2imgapi(req.txt2img_payload)
822839
response.images = self.post_invocations(response.images, quality)
823840
return response
824841
elif req.task == 'image-to-image':
825842
if embeddings_s3uri != '':
826-
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
843+
shared.s3_download(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
827844
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
828845
response = self.img2imgapi(req.img2img_payload)
829846
response.images = self.post_invocations(response.images, quality)
@@ -836,6 +853,10 @@ def invocations(self, req: InvocationsRequest):
836853
response = self.extras_batch_images_api(req.extras_batch_payload)
837854
response.images = self.post_invocations(response.images, quality)
838855
return response
856+
elif req.task == 'interrogate':
857+
response = self.interrogateapi(req.interrogate_payload)
858+
return response
859+
839860
elif req.task == 'get-progress':
840861
response = self.progressapi(req.progress_payload)
841862
print("____________getting progress result: ")
@@ -873,44 +894,8 @@ def invocations(self, req: InvocationsRequest):
873894
return InvocationsErrorResponse(error = str(e))
874895

875896
def ping(self):
876-
print('-------ping------')
877897
return {'status': 'Healthy'}
878898

879899
def launch(self, server_name, port):
880900
self.app.include_router(self.router)
881901
uvicorn.run(self.app, host=server_name, port=port)
882-
883-
def get_bucket_and_key(self, s3uri):
884-
pos = s3uri.find('/', 5)
885-
bucket = s3uri[5 : pos]
886-
key = s3uri[pos + 1 : ]
887-
return bucket, key
888-
889-
def download_s3files(self, s3uri, path):
890-
global cache
891-
892-
pos = s3uri.find('/', 5)
893-
bucket = s3uri[5 : pos]
894-
key = s3uri[pos + 1 : ]
895-
896-
s3_bucket = s3_resource.Bucket(bucket)
897-
objs = list(s3_bucket.objects.filter(Prefix=key))
898-
899-
if os.path.isfile('cache'):
900-
cache = json.load(open('cache', 'r'))
901-
902-
for obj in objs:
903-
if obj.key == key:
904-
continue
905-
response = s3_client.head_object(
906-
Bucket = bucket,
907-
Key = obj.key
908-
)
909-
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
910-
if obj_key not in cache or cache[obj_key] != response['ETag']:
911-
filename = obj.key[obj.key.rfind('/') + 1 : ]
912-
913-
s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
914-
cache[obj_key] = response['ETag']
915-
916-
json.dump(cache, open('cache', 'w'))

modules/api/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ class InvocationsRequest(BaseModel):
302302
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
303303
extras_single_payload: Optional[ExtrasSingleImageRequest]
304304
extras_batch_payload: Optional[ExtrasBatchImagesRequest]
305+
interrogate_payload: Optional[InterrogateRequest]
305306
progress_payload:Optional[ProgressRequest]
306307
post_options_payload:Optional[dict]
307308

modules/cmd_args.py

+1
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,4 @@
109109
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')
110110
parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
111111
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
112+
parser.add_argument('--region-name', default='', type=str, help='Region name')

modules/script_callbacks.py

+9
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self, imgs, cols, rows):
9393
callbacks_infotext_pasted=[],
9494
callbacks_script_unloaded=[],
9595
callbacks_before_ui=[],
96+
callbacks_update_cn_models=[]
9697
)
9798

9899

@@ -224,6 +225,12 @@ def before_ui_callback():
224225
except Exception:
225226
report_exception(c, 'before_ui')
226227

228+
def update_cn_models_callback():
229+
for c in callback_map['callbacks_update_cn_models']:
230+
try:
231+
c.callback()
232+
except Exception:
233+
report_exception(c, 'callbacks_update_cn_models')
227234

228235
def add_callback(callbacks, fun):
229236
stack = [x for x in inspect.stack() if x.filename != __file__]
@@ -247,6 +254,8 @@ def remove_callbacks_for_function(callback_func):
247254
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
248255
callback_list.remove(callback_to_remove)
249256

257+
def on_update_cn_models(callback):
258+
add_callback(callback_map['callbacks_update_cn_models'], callback)
250259

251260
def on_app_started(callback):
252261
"""register a function to be called when the webui started, the gradio `Block` component and

modules/shared.py

+147-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import sys
66
import time
7-
7+
import threading
88
from PIL import Image
99
import gradio as gr
1010
import tqdm
@@ -15,9 +15,59 @@
1515
import modules.devices as devices
1616
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
1717
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
18+
from botocore.exceptions import ClientError
19+
import glob
1820

1921
demo = None
2022

23+
models_s3_bucket = None
24+
s3_folder_sd = None
25+
s3_folder_cn = None
26+
s3_folder_lora = None
27+
syncLock = threading.Lock()
28+
tmp_models_dir = '/tmp/models'
29+
tmp_cache_dir = '/tmp/model_sync_cache'
30+
class ModelsRef:
31+
def __init__(self):
32+
self.models_ref = {}
33+
34+
def get_models_ref_dict(self):
35+
return self.models_ref
36+
37+
def add_models_ref(self, model_name):
38+
if model_name in self.models_ref:
39+
self.models_ref[model_name] += 1
40+
else:
41+
self.models_ref[model_name] = 0
42+
43+
def remove_model_ref(self,model_name):
44+
if self.models_ref.get(model_name):
45+
del self.models_ref[model_name]
46+
47+
def get_models_ref(self, model_name):
48+
return self.models_ref.get(model_name)
49+
50+
def get_least_ref_model(self):
51+
sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1])
52+
if sorted_models:
53+
least_ref_model, least_counter = sorted_models[0]
54+
return least_ref_model,least_counter
55+
else:
56+
return None,None
57+
58+
def pop_least_ref_model(self):
59+
sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1])
60+
if sorted_models:
61+
least_ref_model, least_counter = sorted_models[0]
62+
del self.models_ref[least_ref_model]
63+
return least_ref_model,least_counter
64+
else:
65+
return None,None
66+
67+
sd_models_Ref = ModelsRef()
68+
cn_models_Ref = ModelsRef()
69+
lora_models_Ref = ModelsRef()
70+
2171
parser = cmd_args.parser
2272

2373
script_loading.preload_extensions(extensions_dir, parser)
@@ -657,3 +707,99 @@ def html(filename):
657707
return file.read()
658708

659709
return ""
710+
711+
import boto3
712+
import requests
713+
714+
cache = dict()
715+
region_name = boto3.session.Session().region_name if not cmd_opts.train else cmd_opts.region_name
716+
s3_client = boto3.client('s3', region_name=region_name)
717+
endpointUrl = s3_client.meta.endpoint_url
718+
s3_client = boto3.client('s3', endpoint_url=endpointUrl, region_name=region_name)
719+
s3_resource= boto3.resource('s3')
720+
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)
721+
722+
def get_bucket_and_key(s3uri):
723+
pos = s3uri.find('/', 5)
724+
bucket = s3uri[5 : pos]
725+
key = s3uri[pos + 1 : ]
726+
return bucket, key
727+
728+
def s3_download(s3uri, path):
729+
global cache
730+
731+
print('---path---', path)
732+
os.system(f'ls -l {os.path.dirname(path)}')
733+
734+
pos = s3uri.find('/', 5)
735+
bucket = s3uri[5 : pos]
736+
key = s3uri[pos + 1 : ]
737+
738+
objects = []
739+
paginator = s3_client.get_paginator('list_objects_v2')
740+
page_iterator = paginator.paginate(Bucket=bucket, Prefix=key)
741+
for page in page_iterator:
742+
if 'Contents' in page:
743+
for obj in page['Contents']:
744+
objects.append(obj)
745+
if 'NextContinuationToken' in page:
746+
page_iterator = paginator.paginate(Bucket=bucket, Prefix=key,
747+
ContinuationToken=page['NextContinuationToken'])
748+
749+
if os.path.isfile('cache'):
750+
cache = json.load(open('cache', 'r'))
751+
752+
for obj in objects:
753+
if obj['Size'] == 0:
754+
continue
755+
response = s3_client.head_object(
756+
Bucket = bucket,
757+
Key = obj['Key']
758+
)
759+
obj_key = 's3://{0}/{1}'.format(bucket, obj['Key'])
760+
if obj_key not in cache or cache[obj_key] != response['ETag']:
761+
filename = obj['Key'][obj['Key'].rfind('/') + 1 : ]
762+
763+
s3_client.download_file(bucket, obj['Key'], os.path.join(path, filename))
764+
cache[obj_key] = response['ETag']
765+
766+
json.dump(cache, open('cache', 'w'))
767+
768+
def http_download(httpuri, path):
769+
with requests.get(httpuri, stream=True) as r:
770+
r.raise_for_status()
771+
with open(path, 'wb') as f:
772+
for chunk in r.iter_content(chunk_size=8192):
773+
f.write(chunk)
774+
775+
def upload_s3files(s3uri, file_path_with_pattern):
776+
pos = s3uri.find('/', 5)
777+
bucket = s3uri[5 : pos]
778+
key = s3uri[pos + 1 : ]
779+
780+
try:
781+
for file_path in glob.glob(file_path_with_pattern):
782+
file_name = os.path.basename(file_path)
783+
__s3file = f'{key}{file_name}'
784+
print(file_path, __s3file)
785+
s3_client.upload_file(file_path, bucket, __s3file)
786+
except ClientError as e:
787+
print(e)
788+
return False
789+
return True
790+
791+
def upload_s3folder(s3uri, file_path):
792+
pos = s3uri.find('/', 5)
793+
bucket = s3uri[5 : pos]
794+
key = s3uri[pos + 1 : ]
795+
796+
try:
797+
for path, _, files in os.walk(file_path):
798+
for file in files:
799+
dest_path = path.replace(file_path,"")
800+
__s3file = f'{key}{dest_path}/{file}'
801+
__local_file = os.path.join(path, file)
802+
print(__local_file, __s3file)
803+
s3_client.upload_file(__local_file, bucket, __s3file)
804+
except Exception as e:
805+
print(e)

0 commit comments

Comments
 (0)