@@ -210,7 +210,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
210
210
self .add_api_route ("/sdapi/v1/unload-checkpoint" , self .unloadapi , methods = ["POST" ])
211
211
self .add_api_route ("/sdapi/v1/reload-checkpoint" , self .reloadapi , methods = ["POST" ])
212
212
self .add_api_route ("/sdapi/v1/scripts" , self .get_scripts_list , methods = ["GET" ], response_model = ScriptsList )
213
- self .add_api_route ("/invocations" , self .invocations , methods = ["POST" ], response_model = Union [TextToImageResponse , ImageToImageResponse , ExtrasSingleImageResponse , ExtrasBatchImagesResponse ])
213
+ self .add_api_route ("/invocations" , self .invocations , methods = ["POST" ], response_model = Union [TextToImageResponse , ImageToImageResponse , ExtrasSingleImageResponse , ExtrasBatchImagesResponse , InvocationsErrorResponse ])
214
214
self .add_api_route ("/ping" , self .ping , methods = ["GET" ], response_model = PingResponse )
215
215
216
216
self .default_script_arg_txt2img = []
@@ -739,51 +739,56 @@ def invocations(self, req: InvocationsRequest):
739
739
print ('-------invocation------' )
740
740
print (req )
741
741
742
- if req .vae != None :
743
- shared .opts .data ['sd_vae' ] = req .vae
744
- refresh_vae_list ()
745
-
746
- if req .model != None :
747
- sd_model_checkpoint = shared .opts .sd_model_checkpoint
748
- shared .opts .sd_model_checkpoint = req .model
749
- with self .queue_lock :
750
- reload_model_weights ()
751
- if sd_model_checkpoint == shared .opts .sd_model_checkpoint :
752
- reload_vae_weights ()
753
-
754
- quality = req .quality
755
-
756
- embeddings_s3uri = shared .cmd_opts .embeddings_s3uri
757
- hypernetwork_s3uri = shared .cmd_opts .hypernetwork_s3uri
758
-
759
- self .download_s3files (hypernetwork_s3uri , os .path .join (script_path , shared .cmd_opts .hypernetwork_dir ))
760
- shared .reload_hypernetworks ()
761
-
762
742
try :
763
- if req .task == 'text-to-image' :
764
- self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
765
- sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
766
- response = self .text2imgapi (req .txt2img_payload )
767
- response .images = self .post_invocations (response .images , quality )
768
- return response
769
- elif req .task == 'image-to-image' :
770
- self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
771
- sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
772
- response = self .img2imgapi (req .img2img_payload )
773
- response .images = self .post_invocations (response .images , quality )
774
- return response
775
- elif req .task == 'extras-single-image' :
776
- response = self .extras_single_image_api (req .extras_single_payload )
777
- response .image = self .post_invocations ([response .image ], quality )[0 ]
778
- return response
779
- elif req .task == 'extras-batch-images' :
780
- response = self .extras_batch_images_api (req .extras_batch_payload )
781
- response .images = self .post_invocations (response .images , quality )
782
- return response
783
- else :
784
- raise NotImplementedError
743
+ if req .vae != None :
744
+ shared .opts .data ['sd_vae' ] = req .vae
745
+ refresh_vae_list ()
746
+
747
+ if req .model != None :
748
+ sd_model_checkpoint = shared .opts .sd_model_checkpoint
749
+ shared .opts .sd_model_checkpoint = req .model
750
+ with self .queue_lock :
751
+ reload_model_weights ()
752
+ if sd_model_checkpoint == shared .opts .sd_model_checkpoint :
753
+ reload_vae_weights ()
754
+
755
+ quality = req .quality
756
+
757
+ embeddings_s3uri = shared .cmd_opts .embeddings_s3uri
758
+ hypernetwork_s3uri = shared .cmd_opts .hypernetwork_s3uri
759
+
760
+ self .download_s3files (hypernetwork_s3uri , os .path .join (script_path , shared .cmd_opts .hypernetwork_dir ))
761
+ shared .reload_hypernetworks ()
762
+
763
+ if req .options != None :
764
+ options = json .lods (req .options )
765
+ for key in options :
766
+ shared .opts .data [key ] = options [key ]
767
+ if req .task == 'text-to-image' :
768
+ self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
769
+ sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
770
+ response = self .text2imgapi (req .txt2img_payload )
771
+ response .images = self .post_invocations (response .images , quality )
772
+ return response
773
+ elif req .task == 'image-to-image' :
774
+ self .download_s3files (embeddings_s3uri , os .path .join (script_path , shared .cmd_opts .embeddings_dir ))
775
+ sd_hijack .model_hijack .embedding_db .load_textual_inversion_embeddings ()
776
+ response = self .img2imgapi (req .img2img_payload )
777
+ response .images = self .post_invocations (response .images , quality )
778
+ return response
779
+ elif req .task == 'extras-single-image' :
780
+ response = self .extras_single_image_api (req .extras_single_payload )
781
+ response .image = self .post_invocations ([response .image ], quality )[0 ]
782
+ return response
783
+ elif req .task == 'extras-batch-images' :
784
+ response = self .extras_batch_images_api (req .extras_batch_payload )
785
+ response .images = self .post_invocations (response .images , quality )
786
+ return response
787
+ else :
788
+ return InvocationsErrorResponse (error = f'Invalid task - { req .task } ' )
785
789
except Exception as e :
786
790
traceback .print_exc ()
791
+ return InvocationsErrorResponse (error = str (e ))
787
792
788
793
def ping (self ):
789
794
print ('-------ping------' )
0 commit comments