Skip to content

Commit 019e8e3

Browse files
committed
Add our changes
1 parent c8794ea commit 019e8e3

File tree

5 files changed

+242
-12
lines changed

5 files changed

+242
-12
lines changed

.gitignore

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ __pycache__
1313
/gfpgan/weights/*.pth
1414
/ui-config.json
1515
/outputs
16-
/config.json
1716
/log
1817
/webui.settings.bat
1918
/embeddings
@@ -37,3 +36,9 @@ notification.mp3
3736
/node_modules
3837
/package-lock.json
3938
/.coverage*
39+
.DS_Store
40+
build.info
41+
**/.DS_Store
42+
models/Stable-diffusion/00v2.1_768-ema-pruned.safetensors
43+
models/Stable-diffusion/00v2.1_768-ema-pruned.yaml
44+
webui-user.bat

config.json

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"samples_format": "jpeg",
3+
"control_net_model_config": "models/cldm_v21.yaml",
4+
"control_net_max_models_num": 3,
5+
"control_net_model_cache_size": 5,
6+
"control_net_monocular_depth_optim": true,
7+
"control_net_cfg_based_guidance": true,
8+
"control_net_control_transfer": true,
9+
"control_net_no_detectmap": true,
10+
"control_net_detectmap_autosaving": false,
11+
"control_net_only_midctrl_hires": true,
12+
"control_net_allow_script_control": true,
13+
"control_net_skip_img2img_processing": false,
14+
"control_net_only_mid_control": false,
15+
"control_net_sync_field_args": false,
16+
"CLIP_stop_at_last_layers": 2,
17+
"multiple_tqdm": false,
18+
"auto_vae_precision": true,
19+
"upcast_attn": true,
20+
"enable_quantization": true,
21+
"img_max_size_mp": 10.0,
22+
"export_for_4chan": true,
23+
"img_downscale_threshold": 6.0,
24+
"target_side_length": 3600.0,
25+
"randn_source": "CPU"
26+
}

modules/api/api.py

+204-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
import os
44
import time
5+
import copy
56
import datetime
67
import uvicorn
78
import ipaddress
@@ -129,6 +130,72 @@ def decode_base64_to_image(encoding):
129130
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
130131

131132

133+
user_input_data = {}
134+
135+
136+
def set_img_exif_dict():
137+
title = "AIRI"
138+
date_taken = "2001:01:01 01:01:01"
139+
global user_input_data
140+
if "date_taken" in user_input_data:
141+
date_taken = user_input_data['date_taken']
142+
copyright = "© AIRI Lab. All Rights Reserved."
143+
camera_maker = "AIRI Lab"
144+
camera_model = "AIRI Model 1.0"
145+
user_id = "AIRI tester"
146+
if "user_id" in user_input_data:
147+
user_id = user_input_data['user_id']
148+
keywords = "Generated in AIRI platform. https://airilab.com"
149+
description = "An image processed by the AIRI platform."
150+
software = "AIRI Platform v1.0"
151+
# imageid = "imageid?"
152+
# imagenum = "imagenum?"
153+
# seed = "seed?"
154+
exif_dict = {
155+
"0th": {
156+
piexif.ImageIFD.ImageDescription: description.encode('utf-8'),
157+
piexif.ImageIFD.Make: camera_maker.encode('utf-8'),
158+
piexif.ImageIFD.Model: camera_model.encode('utf-8'),
159+
piexif.ImageIFD.Copyright: copyright.encode('utf-8'),
160+
piexif.ImageIFD.Artist: user_id.encode('utf-8'),
161+
piexif.ImageIFD.ProcessingSoftware: software.encode('utf-8'),
162+
piexif.ImageIFD.Software: software.encode('utf-8'),
163+
piexif.ImageIFD.DateTime: date_taken.encode('utf-8'),
164+
piexif.ImageIFD.HostComputer: software.encode('utf-8'),
165+
# piexif.ImageIFD.ImageID: imageid.encode('utf-8'),
166+
# piexif.ImageIFD.ImageNumber: imagenum.encode('utf-8'),
167+
piexif.ImageIFD.ImageHistory: keywords.encode('utf-8'),
168+
# piexif.ImageIFD.ImageResources: description.encode('utf-8'),
169+
# piexif.ImageIFD.Noise: seed.encode('utf-8'),
170+
piexif.ImageIFD.Predictor: camera_model.encode('utf-8'),
171+
piexif.ImageIFD.OriginalRawFileData: keywords.encode('utf-8'),
172+
# piexif.ImageIFD.OriginalRawFileName: imageid.encode('utf-8'),
173+
piexif.ImageIFD.ProfileCopyright: copyright.encode('utf-8'),
174+
piexif.ImageIFD.ProfileEmbedPolicy: software.encode('utf-8'),
175+
piexif.ImageIFD.Rating: "5".encode('utf-8'),
176+
piexif.ImageIFD.ProfileName: user_id.encode('utf-8'),
177+
# piexif.ImageIFD.XPAuthor: user_id.encode('utf-8'),
178+
# piexif.ImageIFD.XPTitle: title.encode('utf-8'),
179+
# piexif.ImageIFD.XPKeywords: keywords.encode('utf-8'),
180+
# piexif.ImageIFD.XPComment: description.encode('utf-8'),
181+
# piexif.ImageIFD.XPSubject: copyright.encode('utf-8'),
182+
},
183+
"Exif": {
184+
piexif.ExifIFD.DateTimeOriginal: date_taken.encode('utf-8'),
185+
piexif.ExifIFD.CameraOwnerName: user_id.encode('utf-8'),
186+
piexif.ExifIFD.DateTimeDigitized: date_taken.encode('utf-8'),
187+
piexif.ExifIFD.DeviceSettingDescription: camera_model.encode('utf-8'),
188+
piexif.ExifIFD.FileSource: keywords.encode('utf-8'),
189+
# piexif.ExifIFD.ImageUniqueID: imageid.encode('utf-8'),
190+
piexif.ExifIFD.LensMake: camera_maker.encode('utf-8'),
191+
piexif.ExifIFD.LensModel: camera_model.encode('utf-8'),
192+
piexif.ExifIFD.MakerNote: description.encode('utf-8'),
193+
piexif.ExifIFD.UserComment: description.encode('utf-8'),
194+
}
195+
}
196+
return exif_dict
197+
198+
132199
def encode_pil_to_base64(image):
133200
with io.BytesIO() as output_bytes:
134201

@@ -144,10 +211,14 @@ def encode_pil_to_base64(image):
144211
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
145212
if image.mode == "RGBA":
146213
image = image.convert("RGB")
147-
parameters = image.info.get('parameters', None)
148-
exif_bytes = piexif.dump({
149-
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
150-
})
214+
# parameters = image.info.get('parameters', None)
215+
# exif_bytes = piexif.dump({
216+
# "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
217+
# })
218+
219+
# Convert dict to bytes
220+
exif_bytes = piexif.dump(set_img_exif_dict())
221+
151222
if opts.samples_format.lower() in ("jpg", "jpeg"):
152223
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
153224
else:
@@ -381,6 +452,12 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel
381452

382453
# Now check for always on scripts
383454
if request.alwayson_scripts:
455+
global user_input_data
456+
user_input_data = {}
457+
if "user_input" in request.alwayson_scripts:
458+
user_input_data = request.alwayson_scripts["user_input"]
459+
request.alwayson_scripts.pop("user_input")
460+
384461
for alwayson_script_name in request.alwayson_scripts.keys():
385462
alwayson_script = self.get_script(alwayson_script_name, script_runner)
386463
if alwayson_script is None:
@@ -535,7 +612,7 @@ def pnginfoapi(self, req: models.PNGInfoRequest):
535612
if(not req.image.strip()):
536613
return models.PNGInfoResponse(info="")
537614

538-
image = decode_base64_to_image(req.image.strip())
615+
image = decode_to_image(req.image.strip())
539616
if image is None:
540617
return models.PNGInfoResponse(info="")
541618

@@ -580,7 +657,7 @@ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
580657
if image_b64 is None:
581658
raise HTTPException(status_code=404, detail="Image not found")
582659

583-
img = decode_base64_to_image(image_b64)
660+
img = decode_to_image(image_b64)
584661
img = img.convert('RGB')
585662

586663
# Override object param
@@ -623,6 +700,9 @@ def get_config(self):
623700

624701
return options
625702

703+
def get_all_config(self):
704+
return shared.opts.data
705+
626706
def set_config(self, req: Dict[str, Any]):
627707
checkpoint_name = req.get("sd_model_checkpoint", None)
628708
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
@@ -865,12 +945,53 @@ def post_invocations(self, b64images, quality):
865945
else:
866946
return b64images
867947

948+
def truncate_content(self, value, limit=1000):
949+
if isinstance(value, str): # Only truncate if the value is a string
950+
if len(value) > limit:
951+
return value[:limit] + '...'
952+
return value
953+
954+
def req_logging(self, obj, indent=1):
955+
if "__dict__" in dir(obj): # if value is an object, dive into it
956+
items = obj.__dict__.items()
957+
elif isinstance(obj, dict): # if value is a dictionary, get items
958+
items = obj.items()
959+
elif isinstance(obj, list): # if value is a list, enumerate items
960+
items = enumerate(obj)
961+
else: # if value is not an object or dict or list, just print it
962+
print(" " * indent + f"{self.truncate_content(obj)}")
963+
return
964+
965+
for attr, value in items:
966+
if value is None or value == {} or value == []:
967+
continue
968+
if isinstance(value, (list, dict)) or "__dict__" in dir(value):
969+
print(" " * indent + f"{attr}:")
970+
self.req_logging(value, indent + 1)
971+
else:
972+
print(" " * indent + f"{attr}: {self.truncate_content(value)}")
973+
868974
def invocations(self, req: models.InvocationsRequest):
869975
with self.invocations_lock:
870-
print('-------invocation------')
871-
print(req.task)
872-
976+
print("\n ----------------------------invocation--------------------------- ")
873977
try:
978+
print("")
979+
self.req_logging(req)
980+
except Exception as e:
981+
print("console Log ran into issue: ", e)
982+
# print(f"log@{datetime.datetime.now().strftime(f'%Y%m%d%H%M%S')} req in invocations: {req}")
983+
global user_input_data
984+
user_input_data = {}
985+
# if 'alwayson_scripts' in req:
986+
# if "user_input" in req.alwayson_scripts:
987+
# user_input_data = req.alwayson_scripts["user_input"]
988+
# req.alwayson_scripts.pop("user_input")
989+
if req.user_input != None:
990+
user_input_data = req.user_input
991+
# print(f"log@{datetime.datetime.now().strftime(f'%Y%m%d%H%M%S')} user_input processed in invocations")
992+
# req.pop('user_input', None)
993+
994+
try:
874995
if req.vae != None:
875996
shared.opts.data['sd_vae'] = req.vae
876997
refresh_vae_list()
@@ -906,6 +1027,15 @@ def invocations(self, req: models.InvocationsRequest):
9061027
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
9071028
response = self.text2imgapi(req.txt2img_payload)
9081029
response.images = self.post_invocations(response.images, quality)
1030+
response.parameters.clear()
1031+
oldinfo = json.loads(response.info)
1032+
if "all_prompts" in oldinfo:
1033+
oldinfo.pop("all_prompts", None)
1034+
if "all_negative_prompts" in oldinfo:
1035+
oldinfo.pop("all_negative_prompts", None)
1036+
if "infotexts" in oldinfo:
1037+
oldinfo.pop("infotexts", None)
1038+
response.info = json.dumps(oldinfo)
9091039
return response
9101040
elif req.task == 'image-to-image':
9111041
response = requests.get('http://0.0.0.0:8080/controlnet/model_list', params={'update': True})
@@ -916,10 +1046,51 @@ def invocations(self, req: models.InvocationsRequest):
9161046
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
9171047
response = self.img2imgapi(req.img2img_payload)
9181048
response.images = self.post_invocations(response.images, quality)
1049+
response.parameters.clear()
1050+
oldinfo = json.loads(response.info)
1051+
if "all_prompts" in oldinfo:
1052+
oldinfo.pop("all_prompts", None)
1053+
if "all_negative_prompts" in oldinfo:
1054+
oldinfo.pop("all_negative_prompts", None)
1055+
if "infotexts" in oldinfo:
1056+
oldinfo.pop("infotexts", None)
1057+
response.info = json.dumps(oldinfo)
9191058
return response
1059+
elif req.task == 'upscale_from_feed':
1060+
# only get the one image (in base64)
1061+
intermediate_image = self.img2imgapi(req.img2img_payload).images
1062+
print('finished intermediate img2img')
1063+
try:
1064+
# update the base64 image # note might need to change to req.extras_single_payload['image'] if this does not work
1065+
req.extras_single_payload.image = intermediate_image[0]
1066+
response = self.extras_single_image_api(req.extras_single_payload)
1067+
response.image = self.post_invocations([response.image], quality)[0]
1068+
response.parameters.clear()
1069+
oldinfo = json.loads(response.info)
1070+
if "all_prompts" in oldinfo:
1071+
oldinfo.pop("all_prompts", None)
1072+
if "all_negative_prompts" in oldinfo:
1073+
oldinfo.pop("all_negative_prompts", None)
1074+
if "infotexts" in oldinfo:
1075+
oldinfo.pop("infotexts", None)
1076+
response.info = json.dumps(oldinfo)
1077+
# print(f"log@{datetime.datetime.now().strftime(f'%Y%m%d%H%M%S')} ### get_cmd_flags is {self.get_cmd_flags()}")
1078+
return response
1079+
except Exception as e: # this is in fact obselete, because there will be a earlier return if OOM, won't reach here, but leaving here just in case
1080+
print(
1081+
f"An error occurred: {e}, step one upscale failed, reverting to just 4x upscale without Img2Img process")
9201082
elif req.task == 'extras-single-image':
9211083
response = self.extras_single_image_api(req.extras_single_payload)
9221084
response.image = self.post_invocations([response.image], quality)[0]
1085+
if "info" in response:
1086+
oldinfo = json.loads(response.info)
1087+
if "all_prompts" in oldinfo:
1088+
oldinfo.pop("all_prompts", None)
1089+
if "all_negative_prompts" in oldinfo:
1090+
oldinfo.pop("all_negative_prompts", None)
1091+
if "infotexts" in oldinfo:
1092+
oldinfo.pop("infotexts", None)
1093+
response.info = json.dumps(oldinfo)
9231094
return response
9241095
elif req.task == 'extras-batch-images':
9251096
response = self.extras_batch_images_api(req.extras_batch_payload)
@@ -928,6 +1099,30 @@ def invocations(self, req: models.InvocationsRequest):
9281099
elif req.task == 'interrogate':
9291100
response = self.interrogateapi(req.interrogate_payload)
9301101
return response
1102+
1103+
elif req.task == 'get-progress':
1104+
response = self.progressapi(req.progress_payload)
1105+
print(response)
1106+
return response
1107+
elif req.task == 'get-options':
1108+
response = self.get_config()
1109+
return response
1110+
elif req.task == 'get-SDmodels':
1111+
response = self.get_sd_models()
1112+
return response
1113+
elif req.task == 'get-upscalers':
1114+
response = self.get_upscalers()
1115+
return response
1116+
elif req.task == 'get-memory':
1117+
response = self.get_memory()
1118+
return response
1119+
elif req.task == 'get-cmd-flags':
1120+
response = self.get_cmd_flags()
1121+
return response
1122+
elif req.task == 'do-nothing':
1123+
print("nothing has happened")
1124+
return "nothing has happened"
1125+
9311126
elif req.task.startswith('/'):
9321127
if req.extra_payload:
9331128
response = requests.post(url=f'http://0.0.0.0:8080{req.task}', json=req.extra_payload)

modules/api/models.py

+3
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ class InvocationsRequest(BaseModel):
326326
extras_batch_payload: Optional[ExtrasBatchImagesRequest]
327327
interrogate_payload: Optional[InterrogateRequest]
328328
extra_payload: Optional[dict]
329+
user_input: Optional[dict]
330+
progress_payload: Optional[ProgressRequest]
331+
cn_x3_image: Optional[str] # this is now obselete, but kept here in case of rollback
329332

330333

331334
class InvocationsErrorResponse(BaseModel):

modules/shared.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ def get_bucket_and_key(s3uri):
159159
def s3_download(s3uri, path):
160160
global cache
161161

162-
print('---path---', path)
163-
os.system(f'ls -l {os.path.dirname(path)}')
162+
#TODO: to delete
163+
# print('---path---', path)
164+
# os.system(f'ls -l {os.path.dirname(path)}')
164165

165166
pos = s3uri.find('/', 5)
166167
bucket = s3uri[5 : pos]

0 commit comments

Comments
 (0)