|
36 | 36 | from modules.shared import cmd_opts, opts
|
37 | 37 | import modules.hypernetworks.hypernetwork
|
38 | 38 | import boto3
|
| 39 | +import threading |
| 40 | +import time |
| 41 | + |
39 | 42 | import traceback
|
40 | 43 | from botocore.exceptions import ClientError
|
41 | 44 | import requests
|
@@ -64,6 +67,21 @@ def initialize():
|
64 | 67 | modules.scripts.load_scripts()
|
65 | 68 | return
|
66 | 69 |
|
| 70 | + ## auto reload new models from s3 add by River |
| 71 | + sd_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/" |
| 72 | + cn_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/ControlNet/" |
| 73 | + session = boto3.Session() |
| 74 | + region_name = session.region_name |
| 75 | + sts_client = session.client('sts') |
| 76 | + account_id = sts_client.get_caller_identity()['Account'] |
| 77 | + sg_defaul_bucket_name = f"sagemaker-{region_name}-{account_id}" |
| 78 | + s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion" |
| 79 | + s3_folder_cn = "stable-diffusion-webui/models/ControlNet" |
| 80 | + |
| 81 | + sync_s3_folder(sg_defaul_bucket_name,s3_folder_sd,sd_models_tmp_dir,'sd') |
| 82 | + sync_s3_folder(sg_defaul_bucket_name,s3_folder_cn,cn_models_tmp_dir,'cn') |
| 83 | + ## end |
| 84 | + |
67 | 85 | modelloader.cleanup_models()
|
68 | 86 | modules.sd_models.setup_model()
|
69 | 87 | codeformer.setup_model(cmd_opts.codeformer_models_path)
|
@@ -182,6 +200,114 @@ def user_auth(username, password):
|
182 | 200 |
|
183 | 201 | return response.status_code == 200
|
184 | 202 |
|
| 203 | + |
| 204 | +def register_sd_models(sd_models_dir): |
| 205 | + print ('---register_sd_models()----') |
| 206 | + if 'endpoint_name' in os.environ: |
| 207 | + items = [] |
| 208 | + api_endpoint = os.environ['api_endpoint'] |
| 209 | + endpoint_name = os.environ['endpoint_name'] |
| 210 | + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') |
| 211 | + for file in os.listdir(sd_models_dir): |
| 212 | + if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): |
| 213 | + hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) |
| 214 | + item = {} |
| 215 | + item['model_name'] = file |
| 216 | + item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' |
| 217 | + item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) |
| 218 | + item['hash'] = hash |
| 219 | + item['title'] = '{0} [{1}]'.format(file, hash) |
| 220 | + item['endpoint_name'] = endpoint_name |
| 221 | + items.append(item) |
| 222 | + inputs = { |
| 223 | + 'items': items |
| 224 | + } |
| 225 | + params = { |
| 226 | + 'module': 'Stable-diffusion' |
| 227 | + } |
| 228 | + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): |
| 229 | + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) |
| 230 | + print(response) |
| 231 | + |
| 232 | +def register_cn_models(cn_models_dir): |
| 233 | + print ('---register_cn_models()----') |
| 234 | + if 'endpoint_name' in os.environ: |
| 235 | + items = [] |
| 236 | + api_endpoint = os.environ['api_endpoint'] |
| 237 | + endpoint_name = os.environ['endpoint_name'] |
| 238 | + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') |
| 239 | + |
| 240 | + inputs = { |
| 241 | + 'items': items |
| 242 | + } |
| 243 | + params = { |
| 244 | + 'module': 'ControlNet' |
| 245 | + } |
| 246 | + for file in os.listdir(cn_models_dir): |
| 247 | + if os.path.isfile(os.path.join(cn_models_dir, file)) and \ |
| 248 | + (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): |
| 249 | + hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) |
| 250 | + item = {} |
| 251 | + item['model_name'] = file |
| 252 | + item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) |
| 253 | + item['endpoint_name'] = endpoint_name |
| 254 | + items.append(item) |
| 255 | + |
| 256 | + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): |
| 257 | + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) |
| 258 | + print(response) |
| 259 | + |
| 260 | + |
| 261 | +def sync_s3_folder(bucket_name, s3_folder, local_folder,mode): |
| 262 | + print(f"sync S3 bucket '{bucket_name}', folder '{s3_folder}' for new files...") |
| 263 | + # Create tmp folders |
| 264 | + os.makedirs(os.path.dirname(local_folder), exist_ok=True) |
| 265 | + print(f'create dir: {os.path.dirname(local_folder)}') |
| 266 | + # Create an S3 client |
| 267 | + s3 = boto3.client('s3') |
| 268 | + def sync(): |
| 269 | + # List all objects in the S3 folder |
| 270 | + response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder) |
| 271 | + # Check if there are any new or deleted files |
| 272 | + s3_files = set() |
| 273 | + for obj in response.get('Contents', []): |
| 274 | + s3_files.add(obj['Key'].replace(s3_folder, '').lstrip('/')) |
| 275 | + |
| 276 | + local_files = set(os.listdir(local_folder)) |
| 277 | + |
| 278 | + new_files = s3_files - local_files |
| 279 | + del_files = local_files - s3_files |
| 280 | + |
| 281 | + # Copy new files to local folder |
| 282 | + for file in new_files: |
| 283 | + s3.download_file(bucket_name, s3_folder + '/' + file, os.path.join(local_folder, file)) |
| 284 | + print(f'download_file:from {bucket_name}/{s3_folder}/{file} to {os.path.join(local_folder, file)}') |
| 285 | + |
| 286 | + # Delete vanished files from local folder |
| 287 | + for file in del_files: |
| 288 | + os.remove(os.path.join(local_folder, file)) |
| 289 | + print(f'remove file {os.path.join(local_folder, file)}') |
| 290 | + # If there are changes |
| 291 | + if len(new_files) | len(del_files): |
| 292 | + if mode == 'sd': |
| 293 | + register_sd_models(local_folder) |
| 294 | + elif mode == 'cn': |
| 295 | + register_cn_models(local_folder) |
| 296 | + else: |
| 297 | + print(f'unsupported mode:{mode}') |
| 298 | + # Create a thread function to keep syncing with the S3 folder |
| 299 | + def sync_thread(): |
| 300 | + while True: |
| 301 | + sync() |
| 302 | + time.sleep(60) |
| 303 | + # Initialize at launch |
| 304 | + sync() |
| 305 | + # Start the thread |
| 306 | + thread = threading.Thread(target=sync_thread) |
| 307 | + thread.start() |
| 308 | + return thread |
| 309 | + |
| 310 | + |
185 | 311 | def webui():
|
186 | 312 | launch_api = cmd_opts.api
|
187 | 313 | initialize()
|
@@ -218,7 +344,7 @@ def webui():
|
218 | 344 |
|
219 | 345 | if launch_api:
|
220 | 346 | create_api(app)
|
221 |
| - |
| 347 | + |
222 | 348 | cmd_sd_models_path = cmd_opts.ckpt_dir
|
223 | 349 | sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
|
224 | 350 | if cmd_sd_models_path is not None:
|
@@ -274,7 +400,6 @@ def webui():
|
274 | 400 | if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
|
275 | 401 | response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
|
276 | 402 | print(response)
|
277 |
| - |
278 | 403 | modules.script_callbacks.app_started_callback(shared.demo, app)
|
279 | 404 |
|
280 | 405 | wait_on_server(shared.demo)
|
|
0 commit comments