|
43 | 43 | import requests
|
44 | 44 | import json
|
45 | 45 | import uuid
|
| 46 | + |
| 47 | +from huggingface_hub import hf_hub_download |
| 48 | +import shutil |
| 49 | +import glob |
| 50 | + |
46 | 51 | if not cmd_opts.api:
|
47 | 52 | from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
|
48 | 53 | from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
|
49 | 54 | from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
|
50 | 55 | from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
|
51 | 56 | from modules import paths
|
52 |
| -import glob |
| 57 | +elif not cmd_opts.pureui |
| 58 | + import requests |
| 59 | + cache = dict() |
| 60 | + s3_client = boto3.client('s3') |
| 61 | + s3_resource= boto3.resource('s3') |
| 62 | + |
| 63 | + def s3_download(s3uri, path): |
| 64 | + pos = s3uri.find('/', 5) |
| 65 | + bucket = s3uri[5 : pos] |
| 66 | + key = s3uri[pos + 1 : ] |
| 67 | + |
| 68 | + s3_bucket = s3_resource.Bucket(bucket) |
| 69 | + objs = list(s3_bucket.objects.filter(Prefix=key)) |
| 70 | + |
| 71 | + if os.path.isfile('cache'): |
| 72 | + cache = json.load(open('cache', 'r')) |
| 73 | + |
| 74 | + for obj in objs: |
| 75 | + if obj.key == key: |
| 76 | + continue |
| 77 | + response = s3_client.head_object( |
| 78 | + Bucket = bucket, |
| 79 | + Key = obj.key |
| 80 | + ) |
| 81 | + obj_key = 's3://{0}/{1}'.format(bucket, obj.key) |
| 82 | + if obj_key not in cache or cache[obj_key] != response['ETag']: |
| 83 | + filename = obj.key[obj.key.rfind('/') + 1 : ] |
| 84 | + |
| 85 | + s3_client.download_file(bucket, obj.key, os.path.join(path, filename)) |
| 86 | + cache[obj_key] = response['ETag'] |
| 87 | + |
| 88 | + json.dump(cache, open('cache', 'w')) |
| 89 | + |
| 90 | + def http_download(httpuri, path): |
| 91 | + with requests.get(httpuri, stream=True) as r: |
| 92 | + r.raise_for_status() |
| 93 | + with open(path, 'wb') as f: |
| 94 | + for chunk in r.iter_content(chunk_size=8192): |
| 95 | + f.write(chunk) |
53 | 96 |
|
54 | 97 | if cmd_opts.server_name:
|
55 | 98 | server_name = cmd_opts.server_name
|
@@ -194,6 +237,54 @@ def user_auth(username, password):
|
194 | 237 |
|
195 | 238 | def webui():
|
196 | 239 | launch_api = cmd_opts.api
|
| 240 | + |
| 241 | + if launch_api: |
| 242 | + models_config_s3uri = os.environ.get('models_config_s3uri', None) |
| 243 | + if models_config_s3uri: |
| 244 | + bucket, key = get_bucket_and_key(models_config_s3uri) |
| 245 | + s3_object = s3_client.get_object(Bucket=bucket, Key=key) |
| 246 | + bytes = s3_object["Body"].read() |
| 247 | + payload = bytes.decode('utf8') |
| 248 | + huggingface_models = json.loads(payload).get('huggingface_models', None) |
| 249 | + s3_models = json.loads(payload).get('s3_models', None) |
| 250 | + http_models = json.loads(payload).get('http_models', None) |
| 251 | + else: |
| 252 | + huggingface_models = os.environ.get('huggingface_models', None) |
| 253 | + s3_models = os.environ.get('s3_models', None) |
| 254 | + http_models = os.environ.get('http_models', None) |
| 255 | + |
| 256 | + if huggingface_models: |
| 257 | + huggingface_models = json.loads(huggingface_models) |
| 258 | + huggingface_token = huggingface_models['token'] |
| 259 | + os.system(f'huggingface-cli login --token {huggingface_token}') |
| 260 | + hf_hub_models = huggingface_models['models'] |
| 261 | + for huggingface_model in hf_hub_models: |
| 262 | + repo_id = huggingface_model['repo_id'] |
| 263 | + filename = huggingface_model['filename'] |
| 264 | + name = huggingface_model['name'] |
| 265 | + |
| 266 | + hf_hub_download( |
| 267 | + repo_id=repo_id, |
| 268 | + filename=filename, |
| 269 | + local_dir=f'/tmp/models/{name}', |
| 270 | + cache_dir='/tmp/cache/huggingface' |
| 271 | + ) |
| 272 | + |
| 273 | + if s3_models: |
| 274 | + s3_models = json.loads(s3_models) |
| 275 | + for s3_model in s3_models: |
| 276 | + uri = s3_model['uri'] |
| 277 | + name = s3_model['name'] |
| 278 | + s3_download(uri, f'/tmp/models/{name}') |
| 279 | + |
| 280 | + if http_models: |
| 281 | + http_models = json.loads(http_models) |
| 282 | + for http_model in http_models: |
| 283 | + uri = http_model['uri'] |
| 284 | + filename = http_model['filename'] |
| 285 | + name = http_model['name'] |
| 286 | + http_download(uri, f'/tmp/models/{name}/{filename}') |
| 287 | + |
197 | 288 | initialize()
|
198 | 289 |
|
199 | 290 | while 1:
|
|
0 commit comments