Compare commits

...

15 Commits

Author SHA1 Message Date
AUTOMATIC1111 cf2772fab0 Merge branch 'release_candidate' 2023-12-16 09:58:07 +03:00
AUTOMATIC1111 0dfffe53ec Merge pull request #14307 from AUTOMATIC1111/default-Falst-js_live_preview_in_modal_lightbox
default False js_live_preview_in_modal_lightbox
2023-12-16 09:25:33 +03:00
AUTOMATIC1111 2be85f8fe0 Merge pull request #14237 from ReneKroon/dev
#13354 : solve lora loading issue
2023-12-14 10:15:36 +03:00
AUTOMATIC1111 eb52c803b8 Merge pull request #14216 from wfjsw/state-dict-ref-comparison
change state dict comparison to ref compare
2023-12-14 10:15:22 +03:00
AUTOMATIC1111 f8871dedcf Merge pull request #14230 from AUTOMATIC1111/add-option-Live-preview-in-full-page-image-viewer
add option: Live preview in full page image viewer
2023-12-14 10:15:18 +03:00
AUTOMATIC1111 b7e0d4a7e1 Merge pull request #14229 from Nuullll/ipex-embedding
[IPEX] Fix embedding and ControlNet
2023-12-14 10:14:59 +03:00
AUTOMATIC1111 5cb1ce470d Merge pull request #14266 from kaalibro/dev
Re-add setting lost as part of e294e46
2023-12-14 10:14:54 +03:00
AUTOMATIC1111 888b928f0d Merge pull request #14276 from AUTOMATIC1111/fix-styles
Fix styles
2023-12-14 10:14:50 +03:00
AUTOMATIC1111 b55f09c4e1 Merge pull request #14270 from kaalibro/extra-options-elem-id
Assign id for "extra_options". Replace numeric field with slider.
2023-12-14 10:14:46 +03:00
AUTOMATIC1111 c7cd9b441d Merge pull request #14296 from akx/paste-resolution
Allow pasting in WIDTHxHEIGHT strings into the width/height fields
2023-12-14 10:14:41 +03:00
AUTOMATIC1111 6ef0ff39f2 Merge pull request #14300 from AUTOMATIC1111/oft_fixes
Fix wrong implementation in network_oft
2023-12-14 10:14:19 +03:00
AUTOMATIC1111 120a84bd2f Merge pull request #14203 from AUTOMATIC1111/remove-clean_text()
remove clean_text()
2023-12-05 07:15:54 +03:00
AUTOMATIC1111 368d66c9cc add hypertile infotext 2023-12-04 15:56:11 +03:00
AUTOMATIC1111 81105ee013 repair old handler for postprocessing API in a way that doesn't break interface 2023-12-04 13:11:12 +03:00
AUTOMATIC1111 24dae9bc4c repair old handler for postprocessing API 2023-12-04 12:36:56 +03:00
12 changed files with 111 additions and 88 deletions
+11 -26
View File
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
self.lin_module = None self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module] self.org_module: list[torch.Module] = [self.sd_module]
self.scale = 1.0
# kohya-ss # kohya-ss
if "oft_blocks" in weights.w.keys(): if "oft_blocks" in weights.w.keys():
self.is_kohya = True self.is_kohya = True
@@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule):
self.constraint = None self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
def calc_updown_kb(self, orig_weight, multiplier): def calc_updown(self, orig_weight):
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix eye = torch.eye(self.block_size, device=self.oft_blocks.device)
if self.is_kohya:
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
# This errors out for MultiheadAttention, might need to be handled up-stream # This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape) return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight):
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
multiplier = self.multiplier()
return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
return updown, ex_bias
+2 -1
View File
@@ -159,7 +159,8 @@ def load_network(name, network_on_disk):
bundle_embeddings = {} bundle_embeddings = {}
for key_network, weight in sd.items(): for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1) key_network_without_network_parts, _, network_part = key_network.partition(".")
if key_network_without_network_parts == "bundle_emb": if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1) emb_name, vec_name = network_part.split(".", 1)
emb_dict = bundle_embeddings.get(emb_name, {}) emb_dict = bundle_embeddings.get(emb_name, {})
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
self.setting_names = [] self.setting_names = []
self.infotext_fields = [] self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface: with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i
"""), """),
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
})) }))
@@ -17,11 +17,42 @@ class ScriptHypertile(scripts.Script):
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
self.add_infotext(p)
def before_hr(self, p, *args): def before_hr(self, p, *args):
enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
# exclusive hypertile seed for the second pass # exclusive hypertile seed for the second pass
if not shared.opts.hypertile_enable_unet: if enable:
hypertile.set_hypertile_seed(p.all_seeds[0]) hypertile.set_hypertile_seed(p.all_seeds[0])
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass)
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
if enable and not shared.opts.hypertile_enable_unet:
p.extra_generation_params["Hypertile U-Net second pass"] = True
self.add_infotext(p, add_unet_params=True)
def add_infotext(self, p, add_unet_params=False):
def option(name):
value = getattr(shared.opts, name)
default_value = shared.opts.get_default(name)
return None if value == default_value else value
if shared.opts.hypertile_enable_unet:
p.extra_generation_params["Hypertile U-Net"] = True
if shared.opts.hypertile_enable_unet or add_unet_params:
p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
if shared.opts.hypertile_enable_vae:
p.extra_generation_params["Hypertile VAE"] = True
p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
def configure_hypertile(width, height, enable_unet=True): def configure_hypertile(width, height, enable_unet=True):
@@ -57,16 +88,16 @@ def on_ui_settings():
benefit. benefit.
"""), """),
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
} }
for name, opt in options.items(): for name, opt in options.items():
+1 -1
View File
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
if (modalImage && modalImage.offsetParent) { if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button(); let currentButton = selected_gallery_button();
let preview = gradioApp().querySelectorAll('.livePreview > img'); let preview = gradioApp().querySelectorAll('.livePreview > img');
if (preview.length > 0) { if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
// show preview image if available // show preview image if available
modalImage.src = preview[preview.length - 1].src; modalImage.src = preview[preview.length - 1].src;
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+24
View File
@@ -215,9 +215,33 @@ function restoreProgressImg2img() {
} }
/**
* Configure the width and height elements on `tabname` to accept
* pasting of resolutions in the form of "width x height".
*/
function setupResolutionPasting(tabname) {
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
for (const el of [width, height]) {
el.addEventListener('paste', function(event) {
var pasteData = event.clipboardData.getData('text/plain');
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
if (parsed) {
width.value = parsed[1];
height.value = parsed[2];
updateInput(width);
updateInput(height);
event.preventDefault();
}
});
}
}
onUiLoaded(function() { onUiLoaded(function() {
showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
showRestoreProgressButton('img2img', localGet("img2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id"));
setupResolutionPasting('txt2img');
setupResolutionPasting('img2img');
}); });
+5 -1
View File
@@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui
from modules.shared import opts from modules.shared import opts
def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc() devices.torch_gc()
shared.state.begin(job="extras") shared.state.begin(job="extras")
@@ -128,6 +128,10 @@ def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, out
return outputs, ui_common.plaintext_to_html(infotext), '' return outputs, ui_common.plaintext_to_html(infotext), ''
def run_postprocessing_webui(id_task, *args, **kwargs):
return run_postprocessing(*args, **kwargs)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
"""old handler for API""" """old handler for API"""
+1 -1
View File
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device. would be on the meta device.
""" """
if state_dict == sd: if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict) original(module, state_dict, strict=strict)
+2
View File
@@ -256,6 +256,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
})) }))
@@ -330,6 +331,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
"js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
+10 -44
View File
@@ -2,7 +2,6 @@ import csv
import fnmatch import fnmatch
import os import os
import os.path import os.path
import re
import typing import typing
import shutil import shutil
@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple):
path: str = None path: str = None
def clean_text(text: str) -> str:
"""
Iterating through a list of regular expressions and replacement strings, we
clean up the prompt and style text to make it easier to match against each
other.
"""
re_list = [
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
("multiple spaces", re.compile("\s{2,}"), " "),
]
for _, regex, replace in re_list:
text = regex.sub(replace, text)
return text.strip(", ")
def merge_prompts(style_prompt: str, prompt: str) -> str: def merge_prompts(style_prompt: str, prompt: str) -> str:
if "{prompt}" in style_prompt: if "{prompt}" in style_prompt:
res = style_prompt.replace("{prompt}", prompt) res = style_prompt.replace("{prompt}", prompt)
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles):
for style in styles: for style in styles:
prompt = merge_prompts(style, prompt) prompt = merge_prompts(style, prompt)
return clean_text(prompt) return prompt
def unwrap_style_text_from_prompt(style_text, prompt): def unwrap_style_text_from_prompt(style_text, prompt):
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt):
Note that the "cleaned" version of the style text is only used for matching Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified. purposes here. It isn't returned; the original style text is not modified.
""" """
stripped_prompt = clean_text(prompt) stripped_prompt = prompt
stripped_style_text = clean_text(style_text) stripped_style_text = style_text
if "{prompt}" in stripped_style_text: if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we # Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style. # return True and the "inner" prompt text that isn't part of the style.
@@ -115,10 +98,8 @@ class StyleDatabase:
self.path = path self.path = path
folder, file = os.path.split(self.path) folder, file = os.path.split(self.path)
self.default_file = file.split("*")[0] + ".csv" filename, _, ext = file.partition('*')
if self.default_file == ".csv": self.default_path = os.path.join(folder, filename + ext)
self.default_file = "styles.csv"
self.default_path = os.path.join(folder, self.default_file)
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -172,10 +153,8 @@ class StyleDatabase:
row["name"], prompt, negative_prompt, path row["name"], prompt, negative_prompt, path
) )
def get_style_paths(self) -> list(): def get_style_paths(self) -> set:
""" """Returns a set of all distinct paths of files that styles are loaded from."""
Returns a list of all distinct paths, including the default path, of
files that styles are loaded from."""
# Update any styles without a path to the default path # Update any styles without a path to the default path
for style in list(self.styles.values()): for style in list(self.styles.values()):
if not style.path: if not style.path:
@@ -189,9 +168,9 @@ class StyleDatabase:
style_paths.add(style.path) style_paths.add(style.path)
# Remove any paths for styles that are just list dividers # Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save") style_paths.discard("do_not_save")
return list(style_paths) return style_paths
def get_style_prompts(self, styles): def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles] return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -213,20 +192,7 @@ class StyleDatabase:
# The path argument is deprecated, but kept for backwards compatibility # The path argument is deprecated, but kept for backwards compatibility
_ = path _ = path
# Update any styles without a path to the default path style_paths = self.get_style_paths()
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)
# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
csv_names = [os.path.split(path)[1].lower() for path in style_paths] csv_names = [os.path.split(path)[1].lower() for path in style_paths]
+1 -1
View File
@@ -35,7 +35,7 @@ def create_ui():
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
submit.click( submit.click(
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
_js="submit_extras", _js="submit_extras",
inputs=[ inputs=[
dummy_component, dummy_component,
+9
View File
@@ -48,3 +48,12 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward', CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.bmm',
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
CondFunc('torch.cat',
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
CondFunc('torch.nn.functional.scaled_dot_product_attention',
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)