Skip to content

Commit f95f89c

Browse files
committed
cleanup
1 parent 71169e0 commit f95f89c

File tree

4 files changed

+57
-94
lines changed

4 files changed

+57
-94
lines changed

modules/api/api.py

+6-46
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@
3737
import uuid
3838
import os
3939
import json
40-
import boto3
41-
cache = dict()
42-
s3_client = boto3.client('s3')
43-
s3_resource= boto3.resource('s3')
44-
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)
4540

4641
def upscaler_to_index(name: str):
4742
try:
@@ -710,8 +705,8 @@ def get_memory(self):
710705
return MemoryResponse(ram = ram, cuda = cuda)
711706

712707
def post_invocations(self, b64images, quality):
713-
if generated_images_s3uri:
714-
bucket, key = self.get_bucket_and_key(generated_images_s3uri)
708+
if shared.generated_images_s3uri:
709+
bucket, key = shared.get_bucket_and_key(shared.generated_images_s3uri)
715710
images = []
716711
for b64image in b64images:
717712
image = decode_base64_to_image(b64image).convert('RGB')
@@ -726,7 +721,7 @@ def post_invocations(self, b64images, quality):
726721
image.save(output, format='PNG', quality=95)
727722

728723
image_id = str(uuid.uuid4())
729-
s3_client.put_object(
724+
shared.s3_client.put_object(
730725
Body=output.getvalue(),
731726
Bucket=bucket,
732727
Key=f'{key}/{image_id}.png'
@@ -759,7 +754,7 @@ def invocations(self, req: InvocationsRequest):
759754
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri
760755

761756
if hypernetwork_s3uri !='':
762-
self.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
757+
shared.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
763758
shared.reload_hypernetworks()
764759

765760
if req.options != None:
@@ -769,14 +764,14 @@ def invocations(self, req: InvocationsRequest):
769764

770765
if req.task == 'text-to-image':
771766
if embeddings_s3uri != '':
772-
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
767+
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
773768
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
774769
response = self.text2imgapi(req.txt2img_payload)
775770
response.images = self.post_invocations(response.images, quality)
776771
return response
777772
elif req.task == 'image-to-image':
778773
if embeddings_s3uri != '':
779-
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
774+
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
780775
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
781776
response = self.img2imgapi(req.img2img_payload)
782777
response.images = self.post_invocations(response.images, quality)
@@ -803,38 +798,3 @@ def ping(self):
803798
def launch(self, server_name, port):
804799
self.app.include_router(self.router)
805800
uvicorn.run(self.app, host=server_name, port=port)
806-
807-
def get_bucket_and_key(self, s3uri):
808-
pos = s3uri.find('/', 5)
809-
bucket = s3uri[5 : pos]
810-
key = s3uri[pos + 1 : ]
811-
return bucket, key
812-
813-
def download_s3files(self, s3uri, path):
814-
global cache
815-
816-
pos = s3uri.find('/', 5)
817-
bucket = s3uri[5 : pos]
818-
key = s3uri[pos + 1 : ]
819-
820-
s3_bucket = s3_resource.Bucket(bucket)
821-
objs = list(s3_bucket.objects.filter(Prefix=key))
822-
823-
if os.path.isfile('cache'):
824-
cache = json.load(open('cache', 'r'))
825-
826-
for obj in objs:
827-
if obj.key == key:
828-
continue
829-
response = s3_client.head_object(
830-
Bucket = bucket,
831-
Key = obj.key
832-
)
833-
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
834-
if obj_key not in cache or cache[obj_key] != response['ETag']:
835-
filename = obj.key[obj.key.rfind('/') + 1 : ]
836-
837-
s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
838-
cache[obj_key] = response['ETag']
839-
840-
json.dump(cache, open('cache', 'w'))

modules/cmd_args.py

+1
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,4 @@
109109
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')
110110
parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
111111
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
112+
parser.add_argument('--region-name', default='', type=str, help='Region name')

modules/shared.py

+46
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,49 @@ def html(filename):
657657
return file.read()
658658

659659
return ""
660+
661+
import boto3
662+
import requests
663+
664+
cache = dict()
665+
region_name = boto3.session.Session().region_name if not cmd_opts.train else cmd_opts.region_name
666+
s3_client = boto3.client('s3', region_name=region_name)
667+
endpointUrl = s3_client.meta.endpoint_url
668+
s3_client = boto3.client('s3', endpoint_url=endpointUrl, region_name=region_name)
669+
s3_resource= boto3.resource('s3')
670+
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)
671+
672+
def get_bucket_and_key(s3uri):
673+
pos = s3uri.find('/', 5)
674+
bucket = s3uri[5 : pos]
675+
key = s3uri[pos + 1 : ]
676+
return bucket, key
677+
678+
def s3_download(s3uri, path):
679+
global cache
680+
681+
pos = s3uri.find('/', 5)
682+
bucket = s3uri[5 : pos]
683+
key = s3uri[pos + 1 : ]
684+
685+
if os.path.isfile('cache'):
686+
cache = json.load(open('cache', 'r'))
687+
688+
response = s3_client.head_object(
689+
Bucket=bucket,
690+
Key=key
691+
)
692+
if key not in cache or cache[key] != response['ETag']:
693+
filename = key[key.rfind('/') + 1 : ]
694+
695+
s3_client.download_file(bucket, key, os.path.join(path, filename))
696+
cache[key] = response['ETag']
697+
698+
json.dump(cache, open('cache', 'w'))
699+
700+
def http_download(httpuri, path):
701+
with requests.get(httpuri, stream=True) as r:
702+
r.raise_for_status()
703+
with open(path, 'wb') as f:
704+
for chunk in r.iter_content(chunk_size=8192):
705+
f.write(chunk)

webui.py

+4-48
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,6 @@
7272
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
7373
from modules import paths
7474
import glob
75-
else:
76-
import requests
77-
cache = dict()
78-
region_name = boto3.session.Session().region_name
79-
s3_client = boto3.client('s3', region_name=region_name)
80-
endpointUrl = s3_client.meta.endpoint_url
81-
s3_client = boto3.client('s3', endpoint_url=endpointUrl, region_name=region_name)
82-
s3_resource= boto3.resource('s3')
8375

8476
startup_timer.record("other imports")
8577

@@ -258,8 +250,8 @@ def webui():
258250
if launch_api:
259251
models_config_s3uri = os.environ.get('models_config_s3uri', None)
260252
if models_config_s3uri:
261-
bucket, key = get_bucket_and_key(models_config_s3uri)
262-
s3_object = s3_client.get_object(Bucket=bucket, Key=key)
253+
bucket, key = shared.get_bucket_and_key(models_config_s3uri)
254+
s3_object = shared.s3_client.get_object(Bucket=bucket, Key=key)
263255
bytes = s3_object["Body"].read()
264256
payload = bytes.decode('utf8')
265257
huggingface_models = json.loads(payload).get('huggingface_models', None)
@@ -290,14 +282,14 @@ def webui():
290282
for s3_model in s3_models:
291283
uri = s3_model['uri']
292284
name = s3_model['name']
293-
s3_download(uri, f'/tmp/models/{name}')
285+
shared.s3_download(uri, f'/tmp/models/{name}')
294286

295287
if http_models:
296288
for http_model in http_models:
297289
uri = http_model['uri']
298290
filename = http_model['filename']
299291
name = http_model['name']
300-
http_download(uri, f'/tmp/models/{name}/{filename}')
292+
shared.http_download(uri, f'/tmp/models/{name}/{filename}')
301293

302294
initialize()
303295

@@ -620,7 +612,6 @@ def train():
620612
)
621613
os.makedirs(os.path.dirname("/opt/ml/model/"), exist_ok=True)
622614
os.makedirs(os.path.dirname("/opt/ml/model/Stable-diffusion/"), exist_ok=True)
623-
os.makedirs(os.path.dirname("/opt/ml/model/ControlNet/"), exist_ok=True)
624615
train_steps=int(db_config.revision)
625616
model_file_basename = f'{db_model_name}_{train_steps}_lora' if db_config.use_lora else f'{db_model_name}_{train_steps}'
626617
f1=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.yaml')
@@ -637,41 +628,6 @@ def train():
637628
except Exception as e:
638629
traceback.print_exc()
639630
print(e)
640-
else:
641-
def get_bucket_and_key(s3uri):
642-
pos = s3uri.find('/', 5)
643-
bucket = s3uri[5 : pos]
644-
key = s3uri[pos + 1 : ]
645-
return bucket, key
646-
647-
def s3_download(s3uri, path):
648-
global cache
649-
650-
pos = s3uri.find('/', 5)
651-
bucket = s3uri[5 : pos]
652-
key = s3uri[pos + 1 : ]
653-
654-
if os.path.isfile('cache'):
655-
cache = json.load(open('cache', 'r'))
656-
657-
response = s3_client.head_object(
658-
Bucket=bucket,
659-
Key=key
660-
)
661-
if key not in cache or cache[key] != response['ETag']:
662-
filename = key[key.rfind('/') + 1 : ]
663-
664-
s3_client.download_file(bucket, key, os.path.join(path, filename))
665-
cache[key] = response['ETag']
666-
667-
json.dump(cache, open('cache', 'w'))
668-
669-
def http_download(httpuri, path):
670-
with requests.get(httpuri, stream=True) as r:
671-
r.raise_for_status()
672-
with open(path, 'wb') as f:
673-
for chunk in r.iter_content(chunk_size=8192):
674-
f.write(chunk)
675631

676632
if __name__ == "__main__":
677633
if cmd_opts.train:

0 commit comments

Comments
 (0)