Skip to content

Commit 0cdb815

Browse files
committed
update webui.py
1 parent 91afed1 commit 0cdb815

File tree

1 file changed

+92
-1
lines changed

1 file changed

+92
-1
lines changed

webui.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,56 @@
4343
import requests
4444
import json
4545
import uuid
46+
47+
from huggingface_hub import hf_hub_download
48+
import shutil
49+
import glob
50+
4651
if not cmd_opts.api:
4752
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
4853
from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
4954
from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
5055
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
5156
from modules import paths
52-
import glob
57+
elif not cmd_opts.pureui
58+
import requests
59+
cache = dict()
60+
s3_client = boto3.client('s3')
61+
s3_resource= boto3.resource('s3')
62+
63+
def s3_download(s3uri, path):
64+
pos = s3uri.find('/', 5)
65+
bucket = s3uri[5 : pos]
66+
key = s3uri[pos + 1 : ]
67+
68+
s3_bucket = s3_resource.Bucket(bucket)
69+
objs = list(s3_bucket.objects.filter(Prefix=key))
70+
71+
if os.path.isfile('cache'):
72+
cache = json.load(open('cache', 'r'))
73+
74+
for obj in objs:
75+
if obj.key == key:
76+
continue
77+
response = s3_client.head_object(
78+
Bucket = bucket,
79+
Key = obj.key
80+
)
81+
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
82+
if obj_key not in cache or cache[obj_key] != response['ETag']:
83+
filename = obj.key[obj.key.rfind('/') + 1 : ]
84+
85+
s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
86+
cache[obj_key] = response['ETag']
87+
88+
json.dump(cache, open('cache', 'w'))
89+
90+
def http_download(httpuri, path):
91+
with requests.get(httpuri, stream=True) as r:
92+
r.raise_for_status()
93+
with open(path, 'wb') as f:
94+
for chunk in r.iter_content(chunk_size=8192):
95+
f.write(chunk)
5396

5497
if cmd_opts.server_name:
5598
server_name = cmd_opts.server_name
@@ -194,6 +237,54 @@ def user_auth(username, password):
194237

195238
def webui():
196239
launch_api = cmd_opts.api
240+
241+
if launch_api:
242+
models_config_s3uri = os.environ.get('models_config_s3uri', None)
243+
if models_config_s3uri:
244+
bucket, key = get_bucket_and_key(models_config_s3uri)
245+
s3_object = s3_client.get_object(Bucket=bucket, Key=key)
246+
bytes = s3_object["Body"].read()
247+
payload = bytes.decode('utf8')
248+
huggingface_models = json.loads(payload).get('huggingface_models', None)
249+
s3_models = json.loads(payload).get('s3_models', None)
250+
http_models = json.loads(payload).get('http_models', None)
251+
else:
252+
huggingface_models = os.environ.get('huggingface_models', None)
253+
s3_models = os.environ.get('s3_models', None)
254+
http_models = os.environ.get('http_models', None)
255+
256+
if huggingface_models:
257+
huggingface_models = json.loads(huggingface_models)
258+
huggingface_token = huggingface_models['token']
259+
os.system(f'huggingface-cli login --token {huggingface_token}')
260+
hf_hub_models = huggingface_models['models']
261+
for huggingface_model in hf_hub_models:
262+
repo_id = huggingface_model['repo_id']
263+
filename = huggingface_model['filename']
264+
name = huggingface_model['name']
265+
266+
hf_hub_download(
267+
repo_id=repo_id,
268+
filename=filename,
269+
local_dir=f'/tmp/models/{name}',
270+
cache_dir='/tmp/cache/huggingface'
271+
)
272+
273+
if s3_models:
274+
s3_models = json.loads(s3_models)
275+
for s3_model in s3_models:
276+
uri = s3_model['uri']
277+
name = s3_model['name']
278+
s3_download(uri, f'/tmp/models/{name}')
279+
280+
if http_models:
281+
http_models = json.loads(http_models)
282+
for http_model in http_models:
283+
uri = http_model['uri']
284+
filename = http_model['filename']
285+
name = http_model['name']
286+
http_download(uri, f'/tmp/models/{name}/{filename}')
287+
197288
initialize()
198289

199290
while 1:

0 commit comments

Comments
 (0)