Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 920a3a4dce |
@@ -19,6 +19,7 @@ class SdVersion(enum.Enum):
|
||||
SD1 = 2
|
||||
SD2 = 3
|
||||
SDXL = 4
|
||||
SD3 = 5
|
||||
|
||||
|
||||
class NetworkOnDisk:
|
||||
@@ -59,6 +60,7 @@ class NetworkOnDisk:
|
||||
self.sd_version = self.detect_version()
|
||||
|
||||
def detect_version(self):
|
||||
# TODO: SdVersion.SD3 detection
|
||||
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
|
||||
return SdVersion.SDXL
|
||||
elif str(self.metadata.get('ss_v2', "")) == "True":
|
||||
|
||||
@@ -38,7 +38,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
|
||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "SD3"]}),
|
||||
"TEMP_setting_sd3_lora_filter": shared.OptionInfo(["SD1", "Unknown"], "For SD3 model also show Lora of other sd version", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "Unknown"]}).info('Temporary setting until SD3 Lora detection is implemented'),
|
||||
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
|
||||
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
|
||||
|
||||
@@ -160,7 +160,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
def create_extra_default_items_in_left_column(self):
|
||||
|
||||
# this would be a lot better as gr.Radio but I can't make it work
|
||||
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
|
||||
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'SD3', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
|
||||
|
||||
def create_editor(self):
|
||||
self.create_default_editor_elems()
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import network
|
||||
import networks
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
from modules import shared, ui_extra_networks, sd_models_types
|
||||
from modules.ui_extra_networks import quote_js
|
||||
from ui_edit_user_metadata import LoraUserMetadataEditor
|
||||
|
||||
@@ -62,8 +62,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
if shared.opts.lora_show_all or not enable_filter or not shared.sd_model:
|
||||
pass
|
||||
elif shared.sd_model.is_sd3:
|
||||
# TODO: add proper SD3 filtering when detection is implemented
|
||||
# TODO: move after Unknown block when implemented
|
||||
if sd_version is network.SdVersion.SD3 or sd_version.name in shared.opts.TEMP_setting_sd3_lora_filter:
|
||||
return item
|
||||
return None
|
||||
elif sd_version == network.SdVersion.Unknown:
|
||||
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
|
||||
model_version = self.sd_to_lora_version(shared.sd_model)
|
||||
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
|
||||
return None
|
||||
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
|
||||
@@ -88,3 +94,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
def create_user_metadata_editor(self, ui, tabname):
|
||||
return LoraUserMetadataEditor(ui, tabname, self)
|
||||
|
||||
@staticmethod
|
||||
def sd_to_lora_version(sd_model: sd_models_types.WebuiSdModel):
|
||||
if sd_model.is_sd1:
|
||||
return network.SdVersion.SD1
|
||||
elif sd_model.is_sd2:
|
||||
return network.SdVersion.SD2
|
||||
elif sd_model.is_sdxl:
|
||||
return network.SdVersion.SDXL
|
||||
elif sd_model.is_sd3:
|
||||
return network.SdVersion.SD3
|
||||
|
||||
@@ -14,7 +14,6 @@ def imports():
|
||||
|
||||
import torch # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
from modules import patch_hf_hub_download # noqa: F401
|
||||
import pytorch_lightning # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from modules.patches import patch
|
||||
from modules.errors import report
|
||||
from inspect import signature
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError
|
||||
from huggingface_hub import file_download
|
||||
|
||||
def try_local_files_only(func):
|
||||
if (param := signature(func).parameters.get('local_files_only', None)) and not param.kind == param.KEYWORD_ONLY:
|
||||
raise ValueError(f'{func.__name__} does not have keyword-only parameter "local_files_only"')
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
from modules.shared import opts
|
||||
try_offline_mode = not kwargs.get('local_files_only') and opts.hd_dl_local_first
|
||||
except Exception:
|
||||
report('Error in try_local_files_only - skip try_local_files_only', exc_info=True)
|
||||
try_offline_mode = False
|
||||
|
||||
if try_offline_mode:
|
||||
try:
|
||||
return func(*args, **{**kwargs, 'local_files_only': True})
|
||||
except LocalEntryNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
report('Unexpected exception in try_local_files_only - retry without patch', exc_info=True)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
try:
|
||||
patch(__name__, file_download, 'hf_hub_download', try_local_files_only(file_download.hf_hub_download))
|
||||
except RuntimeError:
|
||||
pass # already patched
|
||||
|
||||
except Exception:
|
||||
report('Error patching hf_hub_download', exc_info=True)
|
||||
@@ -128,7 +128,6 @@ options_templates.update(options_section(('system', "System", "system"), {
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
|
||||
"hd_dl_local_first": OptionInfo(False, "Prevent connecting to huggingface for assets if cache is available").info('this will also prevent assets from being updated'),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('profiler', "Profiler", "system"), {
|
||||
|
||||
Reference in New Issue
Block a user