Skip to content

Commit

Permalink
xd v1.1.3
Browse files Browse the repository at this point in the history
- added tags to bottom of preview tab closes #12
- will now refresh tags for models when you hit the refresh models button
  • Loading branch information
CurtisDS committed Mar 1, 2023
1 parent f3d38ea commit b1cfec3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 26 deletions.
2 changes: 1 addition & 1 deletion javascript/event_handlers.js
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ onUiUpdate(function() {
/* ################### ^ DEPRICATED ^ ############################ */

// Sync the refresh main model list button
registerClickEvents(gradioApp().querySelector('#refresh_sd_model_checkpoint'), ['#cp_modelpreview_xd_refresh_sd_model']);
registerClickEvents(gradioApp().querySelector('#refresh_sd_model_checkpoint'), ['#cp_modelpreview_xd_refresh_sd_model','#lo_modelpreview_xd_refresh_sd_model','#hn_modelpreview_xd_refresh_sd_model','#em_modelpreview_xd_refresh_sd_model']);

// Sync the new refresh extra network buttons to this extension
registerClickEvents(gradioApp().querySelector('#txt2img_extra_refresh'), ['#cp_modelpreview_xd_refresh_sd_model','#lo_modelpreview_xd_refresh_sd_model','#hn_modelpreview_xd_refresh_sd_model','#em_modelpreview_xd_refresh_sd_model']);
Expand Down
64 changes: 39 additions & 25 deletions scripts/modelpreview.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def natural_order_number(s):
}

def search_for_tags(model_names, model_tags, paths):

model_tags.clear()
general_tag_pattern = re.compile(r'^.*(?i:\.tags)$')

# support the ability to check multiple paths
Expand Down Expand Up @@ -120,7 +120,6 @@ def search_for_tags(model_names, model_tags, paths):
def list_all_models():
global checkpoint_choices
# gets the list of checkpoints
print("refresh models")
model_list = sd_models.checkpoint_tiles()
checkpoint_choices = sorted(model_list, key=natural_order_number)
search_for_tags(checkpoint_choices, tags["checkpoints"], get_checkpoints_dirs())
Expand Down Expand Up @@ -173,29 +172,29 @@ def list_all_loras():
search_for_tags(lora_choices, tags["loras"], get_lora_dirs())
return lora_choices

def refresh_models():
def refresh_models(choice = None, filter = None):
global checkpoint_choices
# update the choices for the checkpoint list
checkpoint_choices = list_all_models()
return gr.Dropdown.update(choices=checkpoint_choices)
return filter_models(filter), *show_model_preview(choice)

def refresh_embeddings():
def refresh_embeddings(choice = None, filter = None):
global embedding_choices
# update the choices for the embeddings list
embedding_choices = list_all_embeddings()
return gr.Dropdown.update(choices=embedding_choices)
return filter_embeddings(filter), *show_embedding_preview(choice)

def refresh_hypernetworks():
def refresh_hypernetworks(choice = None, filter = None):
global hypernetwork_choices
# update the choices for the hypernetworks list
hypernetwork_choices = list_all_hypernetworks()
return gr.Dropdown.update(choices=hypernetwork_choices)
return filter_hypernetworks(filter), *show_hypernetwork_preview(choice)

def refresh_loras():
def refresh_loras(choice = None, filter = None):
global lora_choices
# update the choices for the lora list
lora_choices = list_all_loras()
return gr.Dropdown.update(choices=lora_choices)
return filter_loras(filter), *show_lora_preview(choice)

def filter_choices(choices, filter, tags_obj):
filtered_choices = choices
Expand Down Expand Up @@ -440,30 +439,27 @@ def get_lora_dirs():
return directories

def show_model_preview(modelname=None):
# remove the hash if exists, the extension, and if the string is a path just return the file name
modelname = clean_modelname(modelname)
# get preview for the model
return show_preview(modelname, get_checkpoints_dirs())
return show_preview(modelname, get_checkpoints_dirs(), "checkpoints")

def show_embedding_preview(modelname=None):
# remove the hash if exists, the extension, and if the string is a path just return the file name
modelname = clean_modelname(modelname)
# get preview for the model
return show_preview(modelname, get_embedding_dirs())
return show_preview(modelname, get_embedding_dirs(), "embeddings")

def show_hypernetwork_preview(modelname=None):
# remove the hash if exists, the extension, and if the string is a path just return the file name
modelname = clean_modelname(modelname)
# get preview for the model
return show_preview(modelname, get_hypernetwork_dirs())
return show_preview(modelname, get_hypernetwork_dirs(), "hypernetworks")

def show_lora_preview(modelname=None):
# remove the hash if exists, the extension, and if the string is a path just return the file name
modelname = clean_modelname(modelname)
# get preview for the model
return show_preview(modelname, get_lora_dirs())
return show_preview(modelname, get_lora_dirs(), "loras")

def show_preview(name, paths):
def show_preview(modelname, paths, tags_key):
if modelname is None or len(modelname) == 0 or paths is None or len(paths) == 0:
return None, None, None, None

# remove the hash if exists, the extension, and if the string is a path just return the file name
name = clean_modelname(modelname)
# get the preview data
html_code, found_md_file, found_txt_file = search_and_display_previews(name, paths)
preview_html = '' if html_code is None else html_code
Expand Down Expand Up @@ -496,7 +492,14 @@ def show_preview(name, paths):
# if nothing was found display a message that nothing was found
if found_txt_file is None and found_md_file is None and (html_code is None or len(html_code) == 0):
html_update = gr.HTML.update(value="<span style='margin-left: 1em;'>No Preview Found</span>", visible=True)
return txt_update, md_update, html_update

# get the tags from the tags object and create a span for them
found_tags = tags[tags_key].get(modelname, None)
if found_tags is not None:
tags_html = gr.HTML.update(value=f'<div class="footer-tags">{found_tags}</div>', visible=True)
else:
tags_html = gr.HTML.update(value='', visible=False)
return txt_update, md_update, html_update, tags_html

def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, refresh_fn, update_selected_fn):
# create a tab for model previews
Expand All @@ -514,6 +517,8 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn,
with gr.Row(elem_id=f"{tab_id_key}_modelpreview_xd_flexcolumn_row"):
preview_html = gr.HTML(elem_id=f"{tab_id_key}_modelpreview_xd_html_div", visible=False)
preview_md = gr.Markdown(elem_id=f"{tab_id_key}_modelpreview_xd_markdown_div", visible=False)
with gr.Row(elem_id=f"{tab_id_key}_modelpreview_xd_tags_row"):
preview_tags = gr.HTML(elem_id=f"{tab_id_key}_modelpreview_xd_tags_div", visible=False)

list.change(
fn=show_preview_fn,
Expand All @@ -524,6 +529,7 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn,
notes_text_area,
preview_md,
preview_html,
preview_tags
]
)

Expand All @@ -539,9 +545,16 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn,

refresh_list.click(
fn=refresh_fn,
inputs=[],
inputs=[
list,
filter_input,
],
outputs=[
list,
notes_text_area,
preview_md,
preview_html,
preview_tags
]
)

Expand All @@ -555,6 +568,7 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn,
notes_text_area,
preview_md,
preview_html,
preview_tags
]
)

Expand Down
6 changes: 6 additions & 0 deletions style.css
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@
flex-direction: column;
}

#tab_modelpreview_xd_interface .footer-tags {
font-size: 85%;
opacity: 0.85;
border-top: 2px solid #e5e7eb;
}

/* Meta Copy Styling */
#cp_modelpreview_xd_html_div .img-meta,
#em_modelpreview_xd_html_div .img-meta,
Expand Down

0 comments on commit b1cfec3

Please sign in to comment.