Skip to content

Commit 37bb2b1

Browse files
author
xie river
committed
model file loaded dynamically from s3
1 parent 7146b33 commit 37bb2b1

File tree

5 files changed

+161
-11
lines changed

5 files changed

+161
-11
lines changed

localizations/zh_CN.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@
840840
"Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)",
841841
"A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。",
842842
"Job detail":"训练任务详情",
843-
"S3 bucket name for uploading train images":"上传训练图片集的S3桶名",
843+
"S3 bucket name for uploading/downloading images":"上传训练图片集或者下载生成图片的S3桶名",
844844
"Output S3 folder":"S3文件夹目录",
845845
"Upload Train Images to S3":"上传训练图片到S3",
846846
"Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名",

modules/api/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def invocations(self, req: InvocationsRequest):
457457
traceback.print_exc()
458458

459459
def ping(self):
460-
print('-------ping------')
460+
# print('-------ping------')
461461
return {'status': 'Healthy'}
462462

463463
def launch(self, server_name, port):

modules/shared.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def refresh_sagemaker_endpoints(username):
423423
}))
424424

425425
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
426-
"train_files_s3bucket":OptionInfo("","S3 bucket name for uploading train images",component_args=hide_dirs),
426+
"train_files_s3bucket":OptionInfo("","S3 bucket name for uploading/downloading images",component_args=hide_dirs),
427427
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
428428
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
429429
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),

modules/ui.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,29 @@
8686
def gr_show(visible=True):
    """Return an update payload carrying the requested visibility.

    The result is a plain dict holding the ``visible`` flag together with a
    ``__type__`` marker set to ``"update"``.
    """
    update = {"visible": visible}
    update["__type__"] = "update"
    return update
8888

89+
## Begin output images uploaded to s3 by River
90+
s3_resource = boto3.resource('s3')
91+
92+
def save_images_to_s3(full_fillnames,timestamp):
    """Upload local files to the S3 bucket configured in the settings.

    Each file is stored under
    ``output-images/<username>/<sagemaker_endpoint>/<timestamp>/<basename>``
    in the bucket named by the ``train_files_s3bucket`` option.

    Args:
        full_fillnames: Iterable of local file paths to upload.
        timestamp: String used as the last component of the S3 folder.

    Returns:
        ``s3://<bucket>/<folder>`` on success; a user-facing error string when
        no bucket is configured; or the caught exception object when an upload
        fails (kept as-is for backward compatibility - callers interpolate the
        result into a status message).
    """
    username = shared.username
    sagemaker_endpoint = shared.opts.sagemaker_endpoint
    bucket_name = opts.train_files_s3bucket
    # Treat both '' and None as "not configured".
    if not bucket_name:
        return 'Error, please configure a S3 bucket at settings page first'

    s3_bucket = s3_resource.Bucket(bucket_name)
    folder_name = f"output-images/{username}/{sagemaker_endpoint}/{timestamp}"
    try:
        for i, fname in enumerate(full_fillnames):
            # os.path.basename handles the platform's path separator, unlike
            # a manual split on '/'.
            filename = os.path.basename(fname)
            object_name = f"{folder_name}/{filename}"
            s3_bucket.upload_file(fname, object_name)
            print (f'upload file [{i}]:{filename} to s3://{bucket_name}/{object_name}')
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        # boto3's managed upload reports failed uploads as
        # S3UploadFailedError rather than ClientError, so both must be caught
        # for the error to reach the UI instead of propagating.
        print(e)
        return e
    return f"s3://{bucket_name}/{folder_name}"
110+
## End output images uploaded to s3 by River
111+
89112

90113
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
91114
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
@@ -147,7 +170,7 @@ def __init__(self, d=None):
147170

148171
os.makedirs(opts.outdir_save, exist_ok=True)
149172

150-
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
173+
with open(os.path.join(opts.outdir_save, "log.csv"), "w", encoding="utf8", newline='') as file:
151174
at_start = file.tell() == 0
152175
writer = csv.writer(file)
153176
if at_start:
@@ -163,16 +186,19 @@ def __init__(self, d=None):
163186
break
164187

165188
fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
166-
167189
filename = os.path.relpath(fullfn, path)
190+
print(f'fullfn:{fullfn},\n txt_fullfn:{txt_fullfn} \nfilename:{filename}')
168191
filenames.append(filename)
169192
fullfns.append(fullfn)
170193
if txt_fullfn:
171194
filenames.append(os.path.basename(txt_fullfn))
172195
fullfns.append(txt_fullfn)
173-
174196
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
175-
197+
198+
timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')
199+
logfile = os.path.join(opts.outdir_save, "log.csv")
200+
s3folder = save_images_to_s3(fullfns,timestamp)
201+
save_images_to_s3([logfile],timestamp)
176202
# Make Zip
177203
if do_make_zip:
178204
zip_filepath = os.path.join(path, "images.zip")
@@ -184,7 +210,7 @@ def __init__(self, d=None):
184210
zip_file.writestr(filenames[i], f.read())
185211
fullfns.insert(0, zip_filepath)
186212

187-
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
213+
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}, \nS3 folder:\n{s3folder}")
188214

189215

190216

@@ -1466,7 +1492,6 @@ def update_orig(image, state):
14661492
with gr.Row().style(equal_height=False):
14671493
with gr.Tabs(elem_id="train_tabs"):
14681494
## Begin add s3 images upload interface by River
1469-
s3_resource = boto3.resource('s3')
14701495
def upload_to_s3(imgs):
14711496
username = shared.username
14721497
timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')

webui.py

+127-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from modules.shared import cmd_opts, opts
3737
import modules.hypernetworks.hypernetwork
3838
import boto3
39+
import threading
40+
import time
41+
3942
import traceback
4043
from botocore.exceptions import ClientError
4144
import requests
@@ -64,6 +67,21 @@ def initialize():
6467
modules.scripts.load_scripts()
6568
return
6669

70+
## auto reload new models from s3 add by River
71+
sd_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/"
72+
cn_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/ControlNet/"
73+
session = boto3.Session()
74+
region_name = session.region_name
75+
sts_client = session.client('sts')
76+
account_id = sts_client.get_caller_identity()['Account']
77+
sg_defaul_bucket_name = f"sagemaker-{region_name}-{account_id}"
78+
s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion"
79+
s3_folder_cn = "stable-diffusion-webui/models/ControlNet"
80+
81+
sync_s3_folder(sg_defaul_bucket_name,s3_folder_sd,sd_models_tmp_dir,'sd')
82+
sync_s3_folder(sg_defaul_bucket_name,s3_folder_cn,cn_models_tmp_dir,'cn')
83+
## end
84+
6785
modelloader.cleanup_models()
6886
modules.sd_models.setup_model()
6987
codeformer.setup_model(cmd_opts.codeformer_models_path)
@@ -182,6 +200,114 @@ def user_auth(username, password):
182200

183201
return response.status_code == 200
184202

203+
204+
def register_sd_models(sd_models_dir):
    """Report the Stable Diffusion checkpoints found in a folder to the API.

    Does nothing unless the ``endpoint_name`` environment variable is set and
    ``api_endpoint`` holds an http(s) URL.  Otherwise every ``.ckpt`` /
    ``.safetensors`` file directly inside ``sd_models_dir`` is hashed with
    ``modules.sd_models.model_hash`` and the resulting list is POSTed to
    ``{api_endpoint}/sd/models`` with ``module=Stable-diffusion``.

    Args:
        sd_models_dir: Directory that is scanned for checkpoint files.

    Returns:
        None.
    """
    print ('---register_sd_models()----')
    endpoint_name = os.environ.get('endpoint_name')
    if endpoint_name is None:
        return

    # 'api_endpoint' may be missing even when 'endpoint_name' is set; fall
    # back to '' so we skip registration instead of raising KeyError.
    api_endpoint = os.environ.get('api_endpoint', '')
    print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}')
    if not api_endpoint.startswith(('http://', 'https://')):
        # Nothing would be posted, so don't spend time hashing model files.
        return

    items = []
    for file in os.listdir(sd_models_dir):
        model_path = os.path.join(sd_models_dir, file)
        if not os.path.isfile(model_path):
            continue
        if not file.endswith(('.ckpt', '.safetensors')):
            continue
        model_hash = modules.sd_models.model_hash(model_path)
        items.append({
            'model_name': file,
            'config': '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml',
            # NOTE(review): the reported path is fixed and does not follow
            # sd_models_dir - kept unchanged; confirm this is intended.
            'filename': '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file),
            'hash': model_hash,
            'title': '{0} [{1}]'.format(file, model_hash),
            'endpoint_name': endpoint_name,
        })

    inputs = {
        'items': items
    }
    params = {
        'module': 'Stable-diffusion'
    }
    # NOTE(review): no timeout is set on this request; consider adding one.
    response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
    print(response)
232+
def register_cn_models(cn_models_dir):
    """Report the ControlNet models found in a folder to the API.

    Does nothing unless the ``endpoint_name`` environment variable is set and
    ``api_endpoint`` holds an http(s) URL.  Otherwise every ``.pt`` / ``.pth``
    / ``.ckpt`` / ``.safetensors`` file directly inside ``cn_models_dir`` is
    hashed with ``modules.sd_models.model_hash`` and the resulting list is
    POSTed to ``{api_endpoint}/sd/models`` with ``module=ControlNet``.

    Args:
        cn_models_dir: Directory that is scanned for model files.

    Returns:
        None.
    """
    print ('---register_cn_models()----')
    endpoint_name = os.environ.get('endpoint_name')
    if endpoint_name is None:
        return

    # 'api_endpoint' may be missing even when 'endpoint_name' is set; fall
    # back to '' so we skip registration instead of raising KeyError.
    api_endpoint = os.environ.get('api_endpoint', '')
    print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}')
    if not api_endpoint.startswith(('http://', 'https://')):
        # Nothing would be posted, so don't spend time hashing model files.
        return

    items = []
    for file in os.listdir(cn_models_dir):
        model_path = os.path.join(cn_models_dir, file)
        if not os.path.isfile(model_path):
            continue
        if not file.endswith(('.pt', '.pth', '.ckpt', '.safetensors')):
            continue
        model_hash = modules.sd_models.model_hash(model_path)
        items.append({
            'model_name': file,
            # The title uses the file name without its extension.
            'title': '{0} [{1}]'.format(os.path.splitext(file)[0], model_hash),
            'endpoint_name': endpoint_name,
        })

    inputs = {
        'items': items
    }
    params = {
        'module': 'ControlNet'
    }
    # NOTE(review): no timeout is set on this request; consider adding one.
    response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
    print(response)
259+
260+
261+
def sync_s3_folder(bucket_name, s3_folder, local_folder,mode):
    """Mirror the files of an S3 folder into a local folder and keep polling.

    One synchronisation pass runs immediately in the calling thread (errors
    propagate to the caller).  A background thread then repeats the pass
    every 60 seconds.  A pass downloads files that exist in S3 but not
    locally, removes local files that no longer exist in S3 and, if anything
    changed, re-registers the models according to ``mode``.

    Args:
        bucket_name: Name of the S3 bucket to read from.
        s3_folder: Key prefix ("folder") inside the bucket.
        local_folder: Local directory that mirrors the S3 folder.
        mode: ``'sd'`` to call ``register_sd_models`` after a change,
            ``'cn'`` to call ``register_cn_models``; anything else is only
            reported as unsupported.

    Returns:
        The started background ``threading.Thread``.
    """
    print(f"sync S3 bucket '{bucket_name}', folder '{s3_folder}' for new files...")
    # Create the mirror folder itself; taking dirname() first would only
    # create the parent when local_folder has no trailing slash.
    os.makedirs(local_folder, exist_ok=True)
    print(f'create dir: {local_folder}')
    # Create an S3 client
    s3 = boto3.client('s3')
    # Always list with a trailing '/' so sibling prefixes such as
    # "<s3_folder>-old/" are not picked up.
    prefix = s3_folder.rstrip('/') + '/'

    def list_s3_files():
        # Collect the names of the files stored directly under the prefix.
        names = set()
        # A single list_objects_v2 call is capped at 1000 keys; without
        # pagination, files beyond the first page would look "vanished" and
        # their local copies would be deleted.
        paginator = s3.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
            for obj in page.get('Contents', []):
                # Strip only the leading prefix, not every occurrence of it.
                name = obj['Key'][len(prefix):]
                # Skip the zero-byte "folder" placeholder (empty name) and
                # keys in nested folders, which this flat mirror can't hold.
                if name and '/' not in name:
                    names.add(name)
        return names

    def sync():
        s3_files = list_s3_files()
        # Only regular files take part in the comparison; os.remove() would
        # fail on sub-directories.
        local_files = {
            name for name in os.listdir(local_folder)
            if os.path.isfile(os.path.join(local_folder, name))
        }

        new_files = s3_files - local_files
        del_files = local_files - s3_files

        # Copy new files to local folder
        for file in new_files:
            local_path = os.path.join(local_folder, file)
            s3.download_file(bucket_name, prefix + file, local_path)
            print(f'download_file:from {bucket_name}/{s3_folder}/{file} to {local_path}')

        # Delete vanished files from local folder
        for file in del_files:
            local_path = os.path.join(local_folder, file)
            os.remove(local_path)
            print(f'remove file {local_path}')

        # If there are changes, publish the refreshed model list
        if new_files or del_files:
            if mode == 'sd':
                register_sd_models(local_folder)
            elif mode == 'cn':
                register_cn_models(local_folder)
            else:
                print(f'unsupported mode:{mode}')

    # Thread function that keeps syncing with the S3 folder
    def sync_thread():
        while True:
            # The initial pass already ran in the caller, so wait first.
            time.sleep(60)
            try:
                sync()
            except Exception:
                # A transient S3/network/API failure must not end the
                # polling loop; report it and retry on the next tick.
                traceback.print_exc()

    # Initialize at launch
    sync()
    # Start the thread; daemon=True so the endless loop does not keep the
    # interpreter alive when the main thread exits.
    thread = threading.Thread(target=sync_thread, daemon=True)
    thread.start()
    return thread
309+
310+
185311
def webui():
186312
launch_api = cmd_opts.api
187313
initialize()
@@ -218,7 +344,7 @@ def webui():
218344

219345
if launch_api:
220346
create_api(app)
221-
347+
222348
cmd_sd_models_path = cmd_opts.ckpt_dir
223349
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
224350
if cmd_sd_models_path is not None:
@@ -274,7 +400,6 @@ def webui():
274400
if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
275401
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
276402
print(response)
277-
278403
modules.script_callbacks.app_started_callback(shared.demo, app)
279404

280405
wait_on_server(shared.demo)

0 commit comments

Comments
 (0)