Compare commits

..

1 Commits

Author SHA1 Message Date
w-e-w 920a3a4dce SD3 Lora page filter - detection not implemented 2024-07-31 02:48:47 +09:00
5 changed files with 26 additions and 34 deletions
+2
View File
@@ -19,6 +19,7 @@ class SdVersion(enum.Enum):
SD1 = 2
SD2 = 3
SDXL = 4
SD3 = 5
class NetworkOnDisk:
@@ -59,6 +60,7 @@ class NetworkOnDisk:
self.sd_version = self.detect_version()
def detect_version(self):
# TODO: SdVersion.SD3 detection
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', "")) == "True":
@@ -38,7 +38,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "SD3"]}),
"TEMP_setting_sd3_lora_filter": shared.OptionInfo(["SD1", "Unknown"], "For SD3 model also show Lora of other sd version", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "Unknown"]}).info('Temporary setting until SD3 Lora detection is implemented'),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
@@ -160,7 +160,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
def create_extra_default_items_in_left_column(self):
# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'SD3', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
def create_editor(self):
self.create_default_editor_elems()
@@ -3,7 +3,7 @@ import os
import network
import networks
from modules import shared, ui_extra_networks
from modules import shared, ui_extra_networks, sd_models_types
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor
@@ -62,8 +62,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
if shared.opts.lora_show_all or not enable_filter or not shared.sd_model:
pass
elif shared.sd_model.is_sd3:
# TODO: add proper SD3 filtering when detection is implemented
# TODO: move after Unknown block when implemented
if sd_version is network.SdVersion.SD3 or sd_version.name in shared.opts.TEMP_setting_sd3_lora_filter:
return item
return None
elif sd_version == network.SdVersion.Unknown:
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
model_version = self.sd_to_lora_version(shared.sd_model)
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
return None
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
@@ -88,3 +94,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def create_user_metadata_editor(self, ui, tabname):
return LoraUserMetadataEditor(ui, tabname, self)
@staticmethod
def sd_to_lora_version(sd_model: sd_models_types.WebuiSdModel):
if sd_model.is_sd1:
return network.SdVersion.SD1
elif sd_model.is_sd2:
return network.SdVersion.SD2
elif sd_model.is_sdxl:
return network.SdVersion.SDXL
elif sd_model.is_sd3:
return network.SdVersion.SD3
+2 -30
View File
@@ -11,30 +11,25 @@ from modules.shared import state
def process_model_tag(tag):
"""\"mode-name\""""
info = sd_models.get_closet_checkpoint_match(tag)
assert info is not None, f'Unknown checkpoint: {tag}'
return info.name
def process_string_tag(tag):
"""\"str\""""
return tag
def process_int_tag(tag):
"""int-number"""
return int(tag)
def process_float_tag(tag):
"""float-number"""
return float(tag)
def process_boolean_tag(tag):
"""true|false"""
return True if (tag.lower() == "true") else False
return True if (tag == "true") else False
prompt_tags = {
@@ -65,27 +60,6 @@ prompt_tags = {
}
def doc_md():
md = '<details><summary>Usage Syntax</summary><p>\n\n'
for key, func in prompt_tags.items():
md += f'`--{key}` `{func.__doc__}`\n'
md += '''
<details><summary>Example</summary><p>
```shell
--prompt "photo of sunset"
--prompt "photo of sunset" --negative_prompt "orange, pink, red, sea, water, lake" --width 1024 --height 768 --sampler_name "DPM++ 2M Karras" --steps 10 --batch_size 2 --cfg_scale 3 --seed 9
--prompt "photo of winter mountains" --steps 7 --sampler_name "DDIM"
--prompt "photo of winter mountains" --width 1024
```
</p></details>
'''
md += '</p></details>'
return md
def cmdargs(line):
args = shlex.split(line)
pos = 0
@@ -110,6 +84,7 @@ def cmdargs(line):
res[tag] = prompt
continue
func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}'
@@ -150,9 +125,6 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
gr.Markdown(doc_md())
return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):