37
37
import uuid
38
38
import os
39
39
import json
40
- import boto3
41
- cache = dict ()
42
- s3_client = boto3 .client ('s3' )
43
- s3_resource = boto3 .resource ('s3' )
44
- generated_images_s3uri = os .environ .get ('generated_images_s3uri' , None )
45
40
46
41
def upscaler_to_index (name : str ):
47
42
try :
@@ -710,8 +705,8 @@ def get_memory(self):
710
705
return MemoryResponse (ram = ram , cuda = cuda )
711
706
712
707
def post_invocations (self , b64images , quality ):
713
- if generated_images_s3uri :
714
- bucket , key = self .get_bucket_and_key (generated_images_s3uri )
708
+ if shared . generated_images_s3uri :
709
+ bucket , key = shared .get_bucket_and_key (shared . generated_images_s3uri )
715
710
images = []
716
711
for b64image in b64images :
717
712
image = decode_base64_to_image (b64image ).convert ('RGB' )
@@ -726,7 +721,7 @@ def post_invocations(self, b64images, quality):
726
721
image .save (output , format = 'PNG' , quality = 95 )
727
722
728
723
image_id = str (uuid .uuid4 ())
729
- s3_client .put_object (
724
+ shared . s3_client .put_object (
730
725
Body = output .getvalue (),
731
726
Bucket = bucket ,
732
727
Key = f'{ key } /{ image_id } .png'
@@ -759,7 +754,7 @@ def invocations(self, req: InvocationsRequest):
759
754
hypernetwork_s3uri = shared .cmd_opts .hypernetwork_s3uri
760
755
761
756
if hypernetwork_s3uri != '' :
762
- self . download_s3files (hypernetwork_s3uri , shared .cmd_opts .hypernetwork_dir )
757
+ shared . s3_download (hypernetwork_s3uri , shared .cmd_opts .hypernetwork_dir )
763
758
shared .reload_hypernetworks ()
764
759
765
760
if req .options != None :
@@ -769,14 +764,14 @@ def invocations(self, req: InvocationsRequest):
769
764
770
765
if req .task == 'text-to-image' :
771
766
if embeddings_s3uri != '' :
772
- self . download_s3files (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
767
+ shared . s3_download (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
773
768
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
774
769
response = self .text2imgapi (req .txt2img_payload )
775
770
response .images = self .post_invocations (response .images , quality )
776
771
return response
777
772
elif req .task == 'image-to-image' :
778
773
if embeddings_s3uri != '' :
779
- self . download_s3files (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
774
+ shared . s3_download (embeddings_s3uri , shared .cmd_opts .embeddings_dir )
780
775
sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
781
776
response = self .img2imgapi (req .img2img_payload )
782
777
response .images = self .post_invocations (response .images , quality )
@@ -803,38 +798,3 @@ def ping(self):
803
798
def launch (self , server_name , port ):
804
799
self .app .include_router (self .router )
805
800
uvicorn .run (self .app , host = server_name , port = port )
806
-
807
- def get_bucket_and_key (self , s3uri ):
808
- pos = s3uri .find ('/' , 5 )
809
- bucket = s3uri [5 : pos ]
810
- key = s3uri [pos + 1 : ]
811
- return bucket , key
812
-
813
- def download_s3files (self , s3uri , path ):
814
- global cache
815
-
816
- pos = s3uri .find ('/' , 5 )
817
- bucket = s3uri [5 : pos ]
818
- key = s3uri [pos + 1 : ]
819
-
820
- s3_bucket = s3_resource .Bucket (bucket )
821
- objs = list (s3_bucket .objects .filter (Prefix = key ))
822
-
823
- if os .path .isfile ('cache' ):
824
- cache = json .load (open ('cache' , 'r' ))
825
-
826
- for obj in objs :
827
- if obj .key == key :
828
- continue
829
- response = s3_client .head_object (
830
- Bucket = bucket ,
831
- Key = obj .key
832
- )
833
- obj_key = 's3://{0}/{1}' .format (bucket , obj .key )
834
- if obj_key not in cache or cache [obj_key ] != response ['ETag' ]:
835
- filename = obj .key [obj .key .rfind ('/' ) + 1 : ]
836
-
837
- s3_client .download_file (bucket , obj .key , os .path .join (path , filename ))
838
- cache [obj_key ] = response ['ETag' ]
839
-
840
- json .dump (cache , open ('cache' , 'w' ))
0 commit comments