Skip to content

Commit a39b362

Browse files
committed
Add error response and handle options and exceptions
1 parent d71aeac commit a39b362

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

modules/api/api.py

+48-43
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
210210
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
211211
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
212212
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
213-
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse])
213+
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse, InvocationsErrorResponse])
214214
self.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)
215215

216216
self.default_script_arg_txt2img = []
@@ -739,51 +739,56 @@ def invocations(self, req: InvocationsRequest):
739739
print('-------invocation------')
740740
print(req)
741741

742-
if req.vae != None:
743-
shared.opts.data['sd_vae'] = req.vae
744-
refresh_vae_list()
745-
746-
if req.model != None:
747-
sd_model_checkpoint = shared.opts.sd_model_checkpoint
748-
shared.opts.sd_model_checkpoint = req.model
749-
with self.queue_lock:
750-
reload_model_weights()
751-
if sd_model_checkpoint == shared.opts.sd_model_checkpoint:
752-
reload_vae_weights()
753-
754-
quality = req.quality
755-
756-
embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
757-
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri
758-
759-
self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
760-
shared.reload_hypernetworks()
761-
762742
try:
763-
if req.task == 'text-to-image':
764-
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
765-
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
766-
response = self.text2imgapi(req.txt2img_payload)
767-
response.images = self.post_invocations(response.images, quality)
768-
return response
769-
elif req.task == 'image-to-image':
770-
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
771-
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
772-
response = self.img2imgapi(req.img2img_payload)
773-
response.images = self.post_invocations(response.images, quality)
774-
return response
775-
elif req.task == 'extras-single-image':
776-
response = self.extras_single_image_api(req.extras_single_payload)
777-
response.image = self.post_invocations([response.image], quality)[0]
778-
return response
779-
elif req.task == 'extras-batch-images':
780-
response = self.extras_batch_images_api(req.extras_batch_payload)
781-
response.images = self.post_invocations(response.images, quality)
782-
return response
783-
else:
784-
raise NotImplementedError
743+
if req.vae != None:
744+
shared.opts.data['sd_vae'] = req.vae
745+
refresh_vae_list()
746+
747+
if req.model != None:
748+
sd_model_checkpoint = shared.opts.sd_model_checkpoint
749+
shared.opts.sd_model_checkpoint = req.model
750+
with self.queue_lock:
751+
reload_model_weights()
752+
if sd_model_checkpoint == shared.opts.sd_model_checkpoint:
753+
reload_vae_weights()
754+
755+
quality = req.quality
756+
757+
embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
758+
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri
759+
760+
self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
761+
shared.reload_hypernetworks()
762+
763+
if req.options != None:
764+
options = json.lods(req.options)
765+
for key in options:
766+
shared.opts.data[key] = options[key]
767+
if req.task == 'text-to-image':
768+
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
769+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
770+
response = self.text2imgapi(req.txt2img_payload)
771+
response.images = self.post_invocations(response.images, quality)
772+
return response
773+
elif req.task == 'image-to-image':
774+
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
775+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
776+
response = self.img2imgapi(req.img2img_payload)
777+
response.images = self.post_invocations(response.images, quality)
778+
return response
779+
elif req.task == 'extras-single-image':
780+
response = self.extras_single_image_api(req.extras_single_payload)
781+
response.image = self.post_invocations([response.image], quality)[0]
782+
return response
783+
elif req.task == 'extras-batch-images':
784+
response = self.extras_batch_images_api(req.extras_batch_payload)
785+
response.images = self.post_invocations(response.images, quality)
786+
return response
787+
else:
788+
return InvocationsErrorResponse(error = f'Invalid task - {req.task}')
785789
except Exception as e:
786790
traceback.print_exc()
791+
return InvocationsErrorResponse(error = str(e))
787792

788793
def ping(self):
789794
print('-------ping------')

modules/api/models.py

+4
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,14 @@ class InvocationsRequest(BaseModel):
296296
model: Optional[str]
297297
vae: Optional[str]
298298
quality: Optional[int]
299+
options: Optional[str]
299300
txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
300301
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
301302
extras_single_payload: Optional[ExtrasSingleImageRequest]
302303
extras_batch_payload: Optional[ExtrasBatchImagesRequest]
303304

305+
class InvocationsErrorResponse(BaseModel):
306+
error: str = Field(title="Invocation error", description="Error response from invocation.")
307+
304308
class PingResponse(BaseModel):
305309
status: str

0 commit comments

Comments
 (0)