Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cf2772fab0 | |||
| 0dfffe53ec | |||
| 2be85f8fe0 | |||
| eb52c803b8 | |||
| f8871dedcf | |||
| b7e0d4a7e1 | |||
| 5cb1ce470d | |||
| 888b928f0d | |||
| b55f09c4e1 | |||
| c7cd9b441d | |||
| 6ef0ff39f2 | |||
| 120a84bd2f | |||
| 368d66c9cc | |||
| 81105ee013 | |||
| 24dae9bc4c |
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.lin_module = None
|
self.lin_module = None
|
||||||
self.org_module: list[torch.Module] = [self.sd_module]
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
|
|
||||||
|
self.scale = 1.0
|
||||||
|
|
||||||
# kohya-ss
|
# kohya-ss
|
||||||
if "oft_blocks" in weights.w.keys():
|
if "oft_blocks" in weights.w.keys():
|
||||||
self.is_kohya = True
|
self.is_kohya = True
|
||||||
@@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.constraint = None
|
self.constraint = None
|
||||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
||||||
def calc_updown_kb(self, orig_weight, multiplier):
|
def calc_updown(self, orig_weight):
|
||||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||||
|
|
||||||
|
if self.is_kohya:
|
||||||
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||||
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
|
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||||
|
|
||||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
|
|
||||||
|
|
||||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
@@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
output_shape = orig_weight.shape
|
output_shape = orig_weight.shape
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
|
||||||
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
|
|
||||||
multiplier = self.multiplier()
|
|
||||||
return self.calc_updown_kb(orig_weight, multiplier)
|
|
||||||
|
|
||||||
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
|
||||||
if self.bias is not None:
|
|
||||||
updown = updown.reshape(self.bias.shape)
|
|
||||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
|
||||||
updown = updown.reshape(output_shape)
|
|
||||||
|
|
||||||
if len(output_shape) == 4:
|
|
||||||
updown = updown.reshape(output_shape)
|
|
||||||
|
|
||||||
if orig_weight.size().numel() == updown.size().numel():
|
|
||||||
updown = updown.reshape(orig_weight.shape)
|
|
||||||
|
|
||||||
if ex_bias is not None:
|
|
||||||
ex_bias = ex_bias * self.multiplier()
|
|
||||||
|
|
||||||
return updown, ex_bias
|
|
||||||
|
|||||||
@@ -159,7 +159,8 @@ def load_network(name, network_on_disk):
|
|||||||
bundle_embeddings = {}
|
bundle_embeddings = {}
|
||||||
|
|
||||||
for key_network, weight in sd.items():
|
for key_network, weight in sd.items():
|
||||||
key_network_without_network_parts, network_part = key_network.split(".", 1)
|
key_network_without_network_parts, _, network_part = key_network.partition(".")
|
||||||
|
|
||||||
if key_network_without_network_parts == "bundle_emb":
|
if key_network_without_network_parts == "bundle_emb":
|
||||||
emb_name, vec_name = network_part.split(".", 1)
|
emb_name, vec_name = network_part.split(".", 1)
|
||||||
emb_dict = bundle_embeddings.get(emb_name, {})
|
emb_dict = bundle_embeddings.get(emb_name, {})
|
||||||
|
|||||||
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
self.setting_names = []
|
self.setting_names = []
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||||
|
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
|
||||||
|
|
||||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||||
|
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
|
||||||
|
|
||||||
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i
|
|||||||
"""),
|
"""),
|
||||||
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||||
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||||
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
|
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
|
||||||
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
|
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,42 @@ class ScriptHypertile(scripts.Script):
|
|||||||
|
|
||||||
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
||||||
|
|
||||||
|
self.add_infotext(p)
|
||||||
|
|
||||||
def before_hr(self, p, *args):
|
def before_hr(self, p, *args):
|
||||||
|
|
||||||
|
enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
|
||||||
|
|
||||||
# exclusive hypertile seed for the second pass
|
# exclusive hypertile seed for the second pass
|
||||||
if not shared.opts.hypertile_enable_unet:
|
if enable:
|
||||||
hypertile.set_hypertile_seed(p.all_seeds[0])
|
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||||
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass)
|
|
||||||
|
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
|
||||||
|
|
||||||
|
if enable and not shared.opts.hypertile_enable_unet:
|
||||||
|
p.extra_generation_params["Hypertile U-Net second pass"] = True
|
||||||
|
|
||||||
|
self.add_infotext(p, add_unet_params=True)
|
||||||
|
|
||||||
|
def add_infotext(self, p, add_unet_params=False):
|
||||||
|
def option(name):
|
||||||
|
value = getattr(shared.opts, name)
|
||||||
|
default_value = shared.opts.get_default(name)
|
||||||
|
return None if value == default_value else value
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_unet:
|
||||||
|
p.extra_generation_params["Hypertile U-Net"] = True
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_unet or add_unet_params:
|
||||||
|
p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
|
||||||
|
p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
|
||||||
|
p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_vae:
|
||||||
|
p.extra_generation_params["Hypertile VAE"] = True
|
||||||
|
p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
|
||||||
|
p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
|
||||||
|
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
|
||||||
|
|
||||||
|
|
||||||
def configure_hypertile(width, height, enable_unet=True):
|
def configure_hypertile(width, height, enable_unet=True):
|
||||||
@@ -57,16 +88,16 @@ def on_ui_settings():
|
|||||||
benefit.
|
benefit.
|
||||||
"""),
|
"""),
|
||||||
|
|
||||||
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
|
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
|
||||||
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
|
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
|
||||||
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
|
||||||
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
|
||||||
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
|
||||||
|
|
||||||
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
|
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
|
||||||
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
|
||||||
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
|
||||||
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, opt in options.items():
|
for name, opt in options.items():
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
|
|||||||
if (modalImage && modalImage.offsetParent) {
|
if (modalImage && modalImage.offsetParent) {
|
||||||
let currentButton = selected_gallery_button();
|
let currentButton = selected_gallery_button();
|
||||||
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
||||||
if (preview.length > 0) {
|
if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
|
||||||
// show preview image if available
|
// show preview image if available
|
||||||
modalImage.src = preview[preview.length - 1].src;
|
modalImage.src = preview[preview.length - 1].src;
|
||||||
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||||
|
|||||||
@@ -215,9 +215,33 @@ function restoreProgressImg2img() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the width and height elements on `tabname` to accept
|
||||||
|
* pasting of resolutions in the form of "width x height".
|
||||||
|
*/
|
||||||
|
function setupResolutionPasting(tabname) {
|
||||||
|
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
|
||||||
|
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
|
||||||
|
for (const el of [width, height]) {
|
||||||
|
el.addEventListener('paste', function(event) {
|
||||||
|
var pasteData = event.clipboardData.getData('text/plain');
|
||||||
|
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
|
||||||
|
if (parsed) {
|
||||||
|
width.value = parsed[1];
|
||||||
|
height.value = parsed[2];
|
||||||
|
updateInput(width);
|
||||||
|
updateInput(height);
|
||||||
|
event.preventDefault();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onUiLoaded(function() {
|
onUiLoaded(function() {
|
||||||
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||||
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||||
|
setupResolutionPasting('txt2img');
|
||||||
|
setupResolutionPasting('img2img');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui
|
|||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
shared.state.begin(job="extras")
|
shared.state.begin(job="extras")
|
||||||
@@ -128,6 +128,10 @@ def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, out
|
|||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
||||||
|
|
||||||
|
def run_postprocessing_webui(id_task, *args, **kwargs):
|
||||||
|
return run_postprocessing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||||
"""old handler for API"""
|
"""old handler for API"""
|
||||||
|
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
would be on the meta device.
|
would be on the meta device.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if state_dict == sd:
|
if state_dict is sd:
|
||||||
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||||
|
|
||||||
original(module, state_dict, strict=strict)
|
original(module, state_dict, strict=strict)
|
||||||
|
|||||||
@@ -256,6 +256,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
|
|||||||
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
|
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
|
||||||
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
||||||
|
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -330,6 +331,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
|||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||||
|
"js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
||||||
|
|||||||
+10
-44
@@ -2,7 +2,6 @@ import csv
|
|||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
|
||||||
import typing
|
import typing
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple):
|
|||||||
path: str = None
|
path: str = None
|
||||||
|
|
||||||
|
|
||||||
def clean_text(text: str) -> str:
|
|
||||||
"""
|
|
||||||
Iterating through a list of regular expressions and replacement strings, we
|
|
||||||
clean up the prompt and style text to make it easier to match against each
|
|
||||||
other.
|
|
||||||
"""
|
|
||||||
re_list = [
|
|
||||||
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
|
|
||||||
("multiple spaces", re.compile("\s{2,}"), " "),
|
|
||||||
]
|
|
||||||
for _, regex, replace in re_list:
|
|
||||||
text = regex.sub(replace, text)
|
|
||||||
|
|
||||||
return text.strip(", ")
|
|
||||||
|
|
||||||
|
|
||||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||||
if "{prompt}" in style_prompt:
|
if "{prompt}" in style_prompt:
|
||||||
res = style_prompt.replace("{prompt}", prompt)
|
res = style_prompt.replace("{prompt}", prompt)
|
||||||
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles):
|
|||||||
for style in styles:
|
for style in styles:
|
||||||
prompt = merge_prompts(style, prompt)
|
prompt = merge_prompts(style, prompt)
|
||||||
|
|
||||||
return clean_text(prompt)
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def unwrap_style_text_from_prompt(style_text, prompt):
|
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||||
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt):
|
|||||||
Note that the "cleaned" version of the style text is only used for matching
|
Note that the "cleaned" version of the style text is only used for matching
|
||||||
purposes here. It isn't returned; the original style text is not modified.
|
purposes here. It isn't returned; the original style text is not modified.
|
||||||
"""
|
"""
|
||||||
stripped_prompt = clean_text(prompt)
|
stripped_prompt = prompt
|
||||||
stripped_style_text = clean_text(style_text)
|
stripped_style_text = style_text
|
||||||
if "{prompt}" in stripped_style_text:
|
if "{prompt}" in stripped_style_text:
|
||||||
# Work out whether the prompt is wrapped in the style text. If so, we
|
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||||
# return True and the "inner" prompt text that isn't part of the style.
|
# return True and the "inner" prompt text that isn't part of the style.
|
||||||
@@ -115,10 +98,8 @@ class StyleDatabase:
|
|||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
folder, file = os.path.split(self.path)
|
folder, file = os.path.split(self.path)
|
||||||
self.default_file = file.split("*")[0] + ".csv"
|
filename, _, ext = file.partition('*')
|
||||||
if self.default_file == ".csv":
|
self.default_path = os.path.join(folder, filename + ext)
|
||||||
self.default_file = "styles.csv"
|
|
||||||
self.default_path = os.path.join(folder, self.default_file)
|
|
||||||
|
|
||||||
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||||
|
|
||||||
@@ -172,10 +153,8 @@ class StyleDatabase:
|
|||||||
row["name"], prompt, negative_prompt, path
|
row["name"], prompt, negative_prompt, path
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_style_paths(self) -> list():
|
def get_style_paths(self) -> set:
|
||||||
"""
|
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
||||||
Returns a list of all distinct paths, including the default path, of
|
|
||||||
files that styles are loaded from."""
|
|
||||||
# Update any styles without a path to the default path
|
# Update any styles without a path to the default path
|
||||||
for style in list(self.styles.values()):
|
for style in list(self.styles.values()):
|
||||||
if not style.path:
|
if not style.path:
|
||||||
@@ -189,9 +168,9 @@ class StyleDatabase:
|
|||||||
style_paths.add(style.path)
|
style_paths.add(style.path)
|
||||||
|
|
||||||
# Remove any paths for styles that are just list dividers
|
# Remove any paths for styles that are just list dividers
|
||||||
style_paths.remove("do_not_save")
|
style_paths.discard("do_not_save")
|
||||||
|
|
||||||
return list(style_paths)
|
return style_paths
|
||||||
|
|
||||||
def get_style_prompts(self, styles):
|
def get_style_prompts(self, styles):
|
||||||
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||||
@@ -213,20 +192,7 @@ class StyleDatabase:
|
|||||||
# The path argument is deprecated, but kept for backwards compatibility
|
# The path argument is deprecated, but kept for backwards compatibility
|
||||||
_ = path
|
_ = path
|
||||||
|
|
||||||
# Update any styles without a path to the default path
|
style_paths = self.get_style_paths()
|
||||||
for style in list(self.styles.values()):
|
|
||||||
if not style.path:
|
|
||||||
self.styles[style.name] = style._replace(path=self.default_path)
|
|
||||||
|
|
||||||
# Create a list of all distinct paths, including the default path
|
|
||||||
style_paths = set()
|
|
||||||
style_paths.add(self.default_path)
|
|
||||||
for _, style in self.styles.items():
|
|
||||||
if style.path:
|
|
||||||
style_paths.add(style.path)
|
|
||||||
|
|
||||||
# Remove any paths for styles that are just list dividers
|
|
||||||
style_paths.remove("do_not_save")
|
|
||||||
|
|
||||||
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def create_ui():
|
|||||||
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
|
||||||
_js="submit_extras",
|
_js="submit_extras",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
|
|||||||
@@ -48,3 +48,12 @@ if has_xpu:
|
|||||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||||
|
CondFunc('torch.bmm',
|
||||||
|
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
|
||||||
|
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
|
||||||
|
CondFunc('torch.cat',
|
||||||
|
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||||
|
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||||
|
CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||||
|
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
|
||||||
|
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user