38
38
import uuid
39
39
import os
40
40
import json
41
- import boto3
42
- cache = dict ()
43
- s3_client = boto3 .client ('s3' )
44
- s3_resource = boto3 .resource ('s3' )
45
- generated_images_s3uri = os .environ .get ('generated_images_s3uri' , None )
46
41
47
42
def upscaler_to_index (name : str ):
48
43
try :
@@ -107,6 +102,35 @@ def encode_pil_to_base64(image):
107
102
108
103
return base64 .b64encode (bytes_data )
109
104
105
+ def export_pil_to_bytes (image ):
106
+ with io .BytesIO () as output_bytes :
107
+
108
+ if opts .samples_format .lower () == 'png' :
109
+ use_metadata = False
110
+ metadata = PngImagePlugin .PngInfo ()
111
+ for key , value in image .info .items ():
112
+ if isinstance (key , str ) and isinstance (value , str ):
113
+ metadata .add_text (key , value )
114
+ use_metadata = True
115
+ image .save (output_bytes , format = "PNG" , pnginfo = (metadata if use_metadata else None ), quality = opts .jpeg_quality )
116
+
117
+ elif opts .samples_format .lower () in ("jpg" , "jpeg" , "webp" ):
118
+ parameters = image .info .get ('parameters' , None )
119
+ exif_bytes = piexif .dump ({
120
+ "Exif" : { piexif .ExifIFD .UserComment : piexif .helper .UserComment .dump (parameters or "" , encoding = "unicode" ) }
121
+ })
122
+ if opts .samples_format .lower () in ("jpg" , "jpeg" ):
123
+ image .save (output_bytes , format = "JPEG" , exif = exif_bytes , quality = opts .jpeg_quality )
124
+ else :
125
+ image .save (output_bytes , format = "WEBP" , exif = exif_bytes , quality = opts .jpeg_quality )
126
+
127
+ else :
128
+ raise HTTPException (status_code = 500 , detail = "Invalid image format" )
129
+
130
+ bytes_data = output_bytes .getvalue ()
131
+
132
+ return bytes_data
133
+
110
134
def api_middleware (app : FastAPI ):
111
135
rich_available = True
112
136
try :
@@ -211,7 +235,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
211
235
self .add_api_route ("/sdapi/v1/unload-checkpoint" , self .unloadapi , methods = ["POST" ])
212
236
self .add_api_route ("/sdapi/v1/reload-checkpoint" , self .reloadapi , methods = ["POST" ])
213
237
self .add_api_route ("/sdapi/v1/scripts" , self .get_scripts_list , methods = ["GET" ], response_model = ScriptsList )
214
- self .add_api_route ("/invocations" , self .invocations , methods = ["POST" ], response_model = Union [TextToImageResponse , ImageToImageResponse , ExtrasSingleImageResponse , ExtrasBatchImagesResponse ,MemoryResponse ,List [SDModelItem ],List [UpscalerItem ],OptionsModel ,List [SamplerItem ],FlagsModel ,ProgressResponse ])
238
+ self .add_api_route ("/invocations" , self .invocations , methods = ["POST" ], response_model = Union [TextToImageResponse , ImageToImageResponse , ExtrasSingleImageResponse , ExtrasBatchImagesResponse , InvocationsErrorResponse , InterrogateResponse , MemoryResponse , List [SDModelItem ], List [UpscalerItem ], OptionsModel , List [SamplerItem ], FlagsModel , ProgressResponse ])
215
239
self .add_api_route ("/ping" , self .ping , methods = ["GET" ], response_model = PingResponse )
216
240
217
241
self .default_script_arg_txt2img = []
@@ -715,28 +739,21 @@ def get_memory(self):
715
739
return MemoryResponse (ram = ram , cuda = cuda )
716
740
717
741
def post_invocations (self , b64images , quality ):
718
- if generated_images_s3uri :
719
- bucket , key = self .get_bucket_and_key (generated_images_s3uri )
742
+ if shared .generated_images_s3uri :
743
+ bucket , key = shared .get_bucket_and_key (shared .generated_images_s3uri )
744
+ if key .endswith ('/' ):
745
+ key = key [ : - 1 ]
720
746
images = []
721
747
for b64image in b64images :
722
- image = decode_base64_to_image (b64image ).convert ('RGB' )
723
- output = io .BytesIO ()
724
-
725
- try :
726
- if not quality :
727
- quality = 95
728
-
729
- image .save (output , format = 'PNG' , quality = quality )
730
- except Exception :
731
- image .save (output , format = 'PNG' , quality = 95 )
732
-
733
- image_id = str (uuid .uuid4 ())
734
- s3_client .put_object (
735
- Body = output .getvalue (),
748
+ bytes_data = export_pil_to_bytes (decode_base64_to_image (b64image ))
749
+ image_id = datetime .datetime .now ().strftime (f"%Y%m%d%H%M%S-{ uuid .uuid4 ()} " )
750
+ suffix = opts .samples_format .lower ()
751
+ shared .s3_client .put_object (
752
+ Body = bytes_data ,
736
753
Bucket = bucket ,
737
- Key = f'{ key } /{ image_id } .png '
754
+ Key = f'{ key } /{ image_id } .{ suffix } '
738
755
)
739
- images .append (f's3://{ bucket } /{ key } /{ image_id } .png ' )
756
+ images .append (f's3://{ bucket } /{ key } /{ image_id } .{ suffix } ' )
740
757
return images
741
758
else :
742
759
return b64images
@@ -804,7 +821,7 @@ def invocations(self, req: InvocationsRequest):
804
821
hypernetwork_s3uri = shared .cmd_opts .hypernetwork_s3uri
805
822
806
823
if hypernetwork_s3uri != '' :
807
- self . download_s3files (hypernetwork_s3uri , shared .cmd_opts .hypernetwork_dir )
824
+ shared . s3_download (hypernetwork_s3uri , shared .cmd_opts .hypernetwork_dir )
808
825
shared .reload_hypernetworks ()
809
826
810
827
if req .options != None :
@@ -816,14 +833,14 @@ def invocations(self, req: InvocationsRequest):
816
833
817
834
if req .task == 'text-to-image' :
818
835
if embeddings_s3uri != '' :
819
- self . download_s3files (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
836
+ shared . s3_download (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
820
837
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
821
838
response = self .text2imgapi (req .txt2img_payload )
822
839
response .images = self .post_invocations (response .images , quality )
823
840
return response
824
841
elif req .task == 'image-to-image' :
825
842
if embeddings_s3uri != '' :
826
- self . download_s3files (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
843
+ shared . s3_download (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
827
844
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
828
845
response = self .img2imgapi (req .img2img_payload )
829
846
response .images = self .post_invocations (response .images , quality )
@@ -836,6 +853,10 @@ def invocations(self, req: InvocationsRequest):
836
853
response = self .extras_batch_images_api (req .extras_batch_payload )
837
854
response .images = self .post_invocations (response .images , quality )
838
855
return response
856
+ elif req .task == 'interrogate' :
857
+ response = self .interrogateapi (req .interrogate_payload )
858
+ return response
859
+
839
860
elif req .task == 'get-progress' :
840
861
response = self .progressapi (req .progress_payload )
841
862
print ("____________getting progress result: " )
@@ -873,44 +894,8 @@ def invocations(self, req: InvocationsRequest):
873
894
return InvocationsErrorResponse (error = str (e ))
874
895
875
896
def ping (self ):
876
- print ('-------ping------' )
877
897
return {'status' : 'Healthy' }
878
898
879
899
def launch (self , server_name , port ):
880
900
self .app .include_router (self .router )
881
901
uvicorn .run (self .app , host = server_name , port = port )
882
-
883
- def get_bucket_and_key (self , s3uri ):
884
- pos = s3uri .find ('/' , 5 )
885
- bucket = s3uri [5 : pos ]
886
- key = s3uri [pos + 1 : ]
887
- return bucket , key
888
-
889
- def download_s3files (self , s3uri , path ):
890
- global cache
891
-
892
- pos = s3uri .find ('/' , 5 )
893
- bucket = s3uri [5 : pos ]
894
- key = s3uri [pos + 1 : ]
895
-
896
- s3_bucket = s3_resource .Bucket (bucket )
897
- objs = list (s3_bucket .objects .filter (Prefix = key ))
898
-
899
- if os .path .isfile ('cache' ):
900
- cache = json .load (open ('cache' , 'r' ))
901
-
902
- for obj in objs :
903
- if obj .key == key :
904
- continue
905
- response = s3_client .head_object (
906
- Bucket = bucket ,
907
- Key = obj .key
908
- )
909
- obj_key = 's3://{0}/{1}' .format (bucket , obj .key )
910
- if obj_key not in cache or cache [obj_key ] != response ['ETag' ]:
911
- filename = obj .key [obj .key .rfind ('/' ) + 1 : ]
912
-
913
- s3_client .download_file (bucket , obj .key , os .path .join (path , filename ))
914
- cache [obj_key ] = response ['ETag' ]
915
-
916
- json .dump (cache , open ('cache' , 'w' ))
0 commit comments