diff --git a/javascript/event_handlers.js b/javascript/event_handlers.js index 66fea0d..c7710ac 100644 --- a/javascript/event_handlers.js +++ b/javascript/event_handlers.js @@ -113,7 +113,7 @@ onUiUpdate(function() { /* ################### ^ DEPRICATED ^ ############################ */ // Sync the refresh main model list button - registerClickEvents(gradioApp().querySelector('#refresh_sd_model_checkpoint'), ['#cp_modelpreview_xd_refresh_sd_model']); + registerClickEvents(gradioApp().querySelector('#refresh_sd_model_checkpoint'), ['#cp_modelpreview_xd_refresh_sd_model','#lo_modelpreview_xd_refresh_sd_model','#hn_modelpreview_xd_refresh_sd_model','#em_modelpreview_xd_refresh_sd_model']); // Sync the new refresh extra network buttons to this extension registerClickEvents(gradioApp().querySelector('#txt2img_extra_refresh'), ['#cp_modelpreview_xd_refresh_sd_model','#lo_modelpreview_xd_refresh_sd_model','#hn_modelpreview_xd_refresh_sd_model','#em_modelpreview_xd_refresh_sd_model']); diff --git a/scripts/modelpreview.py b/scripts/modelpreview.py index c0c9487..63fc49c 100644 --- a/scripts/modelpreview.py +++ b/scripts/modelpreview.py @@ -80,7 +80,7 @@ def natural_order_number(s): } def search_for_tags(model_names, model_tags, paths): - + model_tags.clear() general_tag_pattern = re.compile(r'^.*(?i:\.tags)$') # support the ability to check multiple paths @@ -120,7 +120,6 @@ def search_for_tags(model_names, model_tags, paths): def list_all_models(): global checkpoint_choices # gets the list of checkpoints - print("refresh models") model_list = sd_models.checkpoint_tiles() checkpoint_choices = sorted(model_list, key=natural_order_number) search_for_tags(checkpoint_choices, tags["checkpoints"], get_checkpoints_dirs()) @@ -173,29 +172,29 @@ def list_all_loras(): search_for_tags(lora_choices, tags["loras"], get_lora_dirs()) return lora_choices -def refresh_models(): +def refresh_models(choice = None, filter = None): global checkpoint_choices # update the choices for the checkpoint list checkpoint_choices = list_all_models() - return gr.Dropdown.update(choices=checkpoint_choices) + return filter_models(filter), *show_model_preview(choice) -def refresh_embeddings(): +def refresh_embeddings(choice = None, filter = None): global embedding_choices # update the choices for the embeddings list embedding_choices = list_all_embeddings() - return gr.Dropdown.update(choices=embedding_choices) + return filter_embeddings(filter), *show_embedding_preview(choice) -def refresh_hypernetworks(): +def refresh_hypernetworks(choice = None, filter = None): global hypernetwork_choices # update the choices for the hypernetworks list hypernetwork_choices = list_all_hypernetworks() - return gr.Dropdown.update(choices=hypernetwork_choices) + return filter_hypernetworks(filter), *show_hypernetwork_preview(choice) -def refresh_loras(): +def refresh_loras(choice = None, filter = None): global lora_choices # update the choices for the lora list lora_choices = list_all_loras() - return gr.Dropdown.update(choices=lora_choices) + return filter_loras(filter), *show_lora_preview(choice) def filter_choices(choices, filter, tags_obj): filtered_choices = choices @@ -440,30 +439,27 @@ def get_lora_dirs(): return directories def show_model_preview(modelname=None): - # remove the hash if exists, the extension, and if the string is a path just return the file name - modelname = clean_modelname(modelname) # get preview for the model - return show_preview(modelname, get_checkpoints_dirs()) + return show_preview(modelname, get_checkpoints_dirs(), "checkpoints") def show_embedding_preview(modelname=None): - # remove the hash if exists, the extension, and if the string is a path just return the file name - modelname = clean_modelname(modelname) # get preview for the model - return show_preview(modelname, get_embedding_dirs()) + return show_preview(modelname, get_embedding_dirs(), "embeddings") def show_hypernetwork_preview(modelname=None): - # remove the hash if exists, the extension, and if the string is a path just return the file name - modelname = clean_modelname(modelname) # get preview for the model - return show_preview(modelname, get_hypernetwork_dirs()) + return show_preview(modelname, get_hypernetwork_dirs(), "hypernetworks") def show_lora_preview(modelname=None): - # remove the hash if exists, the extension, and if the string is a path just return the file name - modelname = clean_modelname(modelname) # get preview for the model - return show_preview(modelname, get_lora_dirs()) + return show_preview(modelname, get_lora_dirs(), "loras") -def show_preview(name, paths): +def show_preview(modelname, paths, tags_key): + if modelname is None or len(modelname) == 0 or paths is None or len(paths) == 0: + return None, None, None, None + + # remove the hash if exists, the extension, and if the string is a path just return the file name + name = clean_modelname(modelname) # get the preview data html_code, found_md_file, found_txt_file = search_and_display_previews(name, paths) preview_html = '' if html_code is None else html_code @@ -496,7 +492,14 @@ def show_preview(name, paths): # if nothing was found display a message that nothing was found if found_txt_file is None and found_md_file is None and (html_code is None or len(html_code) == 0): html_update = gr.HTML.update(value="No Preview Found", visible=True) - return txt_update, md_update, html_update + + # get the tags from the tags object and create a span for them + found_tags = tags[tags_key].get(modelname, None) + if found_tags is not None: + tags_html = gr.HTML.update(value=f'', visible=True) + else: + tags_html = gr.HTML.update(value='', visible=False) + return txt_update, md_update, html_update, tags_html def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, refresh_fn, update_selected_fn): # create a tab for model previews @@ -514,6 +517,8 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, with gr.Row(elem_id=f"{tab_id_key}_modelpreview_xd_flexcolumn_row"): preview_html = gr.HTML(elem_id=f"{tab_id_key}_modelpreview_xd_html_div", visible=False) preview_md = gr.Markdown(elem_id=f"{tab_id_key}_modelpreview_xd_markdown_div", visible=False) + with gr.Row(elem_id=f"{tab_id_key}_modelpreview_xd_tags_row"): + preview_tags = gr.HTML(elem_id=f"{tab_id_key}_modelpreview_xd_tags_div", visible=False) list.change( fn=show_preview_fn, @@ -524,6 +529,7 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, notes_text_area, preview_md, preview_html, + preview_tags ] ) @@ -539,9 +545,16 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, refresh_list.click( fn=refresh_fn, - inputs=[], + inputs=[ + list, + filter_input, + ], outputs=[ list, + notes_text_area, + preview_md, + preview_html, + preview_tags ] ) @@ -555,6 +568,7 @@ def create_tab(tab_label, tab_id_key, list_choices, show_preview_fn, filter_fn, notes_text_area, preview_md, preview_html, + preview_tags ] ) diff --git a/style.css b/style.css index 7cdfe52..8e01367 100644 --- a/style.css +++ b/style.css @@ -88,6 +88,12 @@ flex-direction: column; } +#tab_modelpreview_xd_interface .footer-tags { + font-size: 85%; + opacity: 0.85; + border-top: 2px solid #e5e7eb; +} + /* Meta Copy Styling */ #cp_modelpreview_xd_html_div .img-meta, #em_modelpreview_xd_html_div .img-meta,