Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ac8c05398b | |||
| 025080218f | |||
| 023454b49e | |||
| cd869bb7a3 |
@@ -226,8 +226,6 @@ onUiLoaded(async() => {
|
|||||||
canvas_show_tooltip: true,
|
canvas_show_tooltip: true,
|
||||||
canvas_auto_expand: true,
|
canvas_auto_expand: true,
|
||||||
canvas_blur_prompt: false,
|
canvas_blur_prompt: false,
|
||||||
canvas_hotkey_undo: "KeyZ",
|
|
||||||
canvas_hotkey_clear: "KeyC",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const functionMap = {
|
const functionMap = {
|
||||||
@@ -238,9 +236,7 @@ onUiLoaded(async() => {
|
|||||||
"Moving canvas": "canvas_hotkey_move",
|
"Moving canvas": "canvas_hotkey_move",
|
||||||
"Fullscreen": "canvas_hotkey_fullscreen",
|
"Fullscreen": "canvas_hotkey_fullscreen",
|
||||||
"Reset Zoom": "canvas_hotkey_reset",
|
"Reset Zoom": "canvas_hotkey_reset",
|
||||||
"Overlap": "canvas_hotkey_overlap",
|
"Overlap": "canvas_hotkey_overlap"
|
||||||
"Undo": "canvas_hotkey_undo",
|
|
||||||
"Clear": "canvas_hotkey_clear"
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Loading the configuration from opts
|
// Loading the configuration from opts
|
||||||
@@ -325,8 +321,6 @@ onUiLoaded(async() => {
|
|||||||
action: "Adjust brush size",
|
action: "Adjust brush size",
|
||||||
keySuffix: " + wheel"
|
keySuffix: " + wheel"
|
||||||
},
|
},
|
||||||
{configKey: "canvas_hotkey_undo", action: "Undo brush stroke"},
|
|
||||||
{configKey: "canvas_hotkey_clear", action: "Clear canvas"},
|
|
||||||
{configKey: "canvas_hotkey_reset", action: "Reset zoom"},
|
{configKey: "canvas_hotkey_reset", action: "Reset zoom"},
|
||||||
{
|
{
|
||||||
configKey: "canvas_hotkey_fullscreen",
|
configKey: "canvas_hotkey_fullscreen",
|
||||||
@@ -470,45 +464,22 @@ onUiLoaded(async() => {
|
|||||||
gradioApp().querySelector(
|
gradioApp().querySelector(
|
||||||
`${elemId} button[aria-label="Use brush"]`
|
`${elemId} button[aria-label="Use brush"]`
|
||||||
);
|
);
|
||||||
|
|
||||||
if (input) {
|
if (input) {
|
||||||
input.click();
|
input.click();
|
||||||
if (!withoutValue) {
|
if (!withoutValue) {
|
||||||
const maxValue = parseFloat(input.getAttribute("max")) || 100;
|
const maxValue =
|
||||||
const minValue = parseFloat(input.getAttribute("min")) || 1;
|
parseFloat(input.getAttribute("max")) || 100;
|
||||||
// allow brush size up to 1/2 diagonal of the image, beyond gradio's arbitrary limit
|
const changeAmount = maxValue * (percentage / 100);
|
||||||
const canvasImg = gradioApp().querySelector(`${elemId} img`);
|
const newValue =
|
||||||
if (canvasImg) {
|
parseFloat(input.value) +
|
||||||
const maxDiameter = Math.sqrt(canvasImg.naturalWidth ** 2 + canvasImg.naturalHeight ** 2) / 2;
|
(deltaY > 0 ? -changeAmount : changeAmount);
|
||||||
if (maxDiameter > maxValue) {
|
input.value = Math.min(Math.max(newValue, 0), maxValue);
|
||||||
input.setAttribute("max", maxDiameter);
|
|
||||||
}
|
|
||||||
if (minValue > 1) {
|
|
||||||
input.setAttribute("min", '1');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const brush_factor = deltaY > 0 ? 1 - opts.canvas_hotkey_brush_factor : 1 + opts.canvas_hotkey_brush_factor;
|
|
||||||
const currentRadius = parseFloat(input.value);
|
|
||||||
let delta = Math.sqrt(currentRadius ** 2 * brush_factor) - currentRadius;
|
|
||||||
// minimum brush size step of 1
|
|
||||||
if (Math.abs(delta) < 1) {
|
|
||||||
delta = deltaY > 0 ? -1 : 1;
|
|
||||||
}
|
|
||||||
const newValue = currentRadius + delta;
|
|
||||||
input.value = Math.max(newValue, 1);
|
|
||||||
input.dispatchEvent(new Event("change"));
|
input.dispatchEvent(new Event("change"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Undo the last brush stroke by clicking the undo button
|
|
||||||
function undoBrushStroke() {
|
|
||||||
gradioApp().querySelector(`${elemId} button[aria-label='Undo']`).click();
|
|
||||||
}
|
|
||||||
|
|
||||||
function clearCanvas() {
|
|
||||||
gradioApp().querySelector(`${elemId} button[aria-label='Clear']`).click();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset zoom when uploading a new image
|
// Reset zoom when uploading a new image
|
||||||
const fileInput = gradioApp().querySelector(
|
const fileInput = gradioApp().querySelector(
|
||||||
`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
|
`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
|
||||||
@@ -728,9 +699,7 @@ onUiLoaded(async() => {
|
|||||||
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
||||||
[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
|
[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
|
||||||
[hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
|
[hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
|
||||||
[hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10),
|
[hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
|
||||||
[hotkeysConfig.canvas_hotkey_undo]: undoBrushStroke,
|
|
||||||
[hotkeysConfig.canvas_hotkey_clear]: clearCanvas
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const action = hotkeyActions[event.code];
|
const action = hotkeyActions[event.code];
|
||||||
|
|||||||
@@ -7,14 +7,11 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
|||||||
"canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
|
"canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
|
||||||
"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
|
"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
|
||||||
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
||||||
"canvas_hotkey_undo": shared.OptionInfo("Z", "Undo brush stroke"),
|
|
||||||
"canvas_hotkey_clear": shared.OptionInfo("C", "Clear canvas"),
|
|
||||||
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
||||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom", "Adjust brush size", "Hotkey enlarge brush", "Hotkey shrink brush", "Undo", "Clear", "Moving canvas", "Fullscreen", "Reset Zoom", "Overlap"]}),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
"canvas_hotkey_brush_factor": shared.OptionInfo(0.1, "Brush size change rate", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info('controls how much the brush size is changed when using hotkeys or scroll wheel'),
|
|
||||||
}))
|
}))
|
||||||
|
|||||||
@@ -4,11 +4,11 @@
|
|||||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||||
|
|
||||||
function checkBrackets(textArea, counterElt) {
|
function checkBrackets(textArea, counterElt) {
|
||||||
var counts = {};
|
const counts = {};
|
||||||
(textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
|
textArea.value.matchAll(/(?<!\\)(?:\\\\)*?([(){}[\]])/g).forEach(bracket => {
|
||||||
counts[bracket] = (counts[bracket] || 0) + 1;
|
counts[bracket[1]] = (counts[bracket[1]] || 0) + 1;
|
||||||
});
|
});
|
||||||
var errors = [];
|
const errors = [];
|
||||||
|
|
||||||
function checkPair(open, close, kind) {
|
function checkPair(open, close, kind) {
|
||||||
if (counts[open] !== counts[close]) {
|
if (counts[open] !== counts[close]) {
|
||||||
|
|||||||
+34
-6
@@ -16,7 +16,7 @@ from skimage import exposure
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
|
||||||
from modules.rng import slerp # noqa: F401
|
from modules.rng import slerp # noqa: F401
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||||
@@ -457,6 +457,20 @@ class StableDiffusionProcessing:
|
|||||||
opts.emphasis,
|
opts.emphasis,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def apply_generation_params_list(self, generation_params_states):
|
||||||
|
"""add and apply generation_params_states to self.extra_generation_params"""
|
||||||
|
for key, value in generation_params_states.items():
|
||||||
|
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
|
||||||
|
self.extra_generation_params[key] = current_value + value
|
||||||
|
else:
|
||||||
|
self.extra_generation_params[key] = value
|
||||||
|
|
||||||
|
def clear_marked_generation_params(self):
|
||||||
|
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
|
||||||
|
for key, value in list(self.extra_generation_params.items()):
|
||||||
|
if getattr(value, 'to_be_clear_before_batch', False):
|
||||||
|
self.extra_generation_params.pop(key)
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
@@ -480,6 +494,10 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
for cache in caches:
|
for cache in caches:
|
||||||
if cache[0] is not None and cached_params == cache[0]:
|
if cache[0] is not None and cached_params == cache[0]:
|
||||||
|
if len(cache) == 3:
|
||||||
|
generation_params_states, cached_cached_params = cache[2]
|
||||||
|
if cached_params == cached_cached_params:
|
||||||
|
self.apply_generation_params_list(generation_params_states)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
cache = caches[0]
|
cache = caches[0]
|
||||||
@@ -487,6 +505,13 @@ class StableDiffusionProcessing:
|
|||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
|
generation_params_states = model_hijack.extract_generation_params_states()
|
||||||
|
self.apply_generation_params_list(generation_params_states)
|
||||||
|
if len(cache) == 2:
|
||||||
|
cache.append((generation_params_states, cached_params))
|
||||||
|
else:
|
||||||
|
cache[2] = (generation_params_states, cached_params)
|
||||||
|
|
||||||
cache[0] = cached_params
|
cache[0] = cached_params
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
@@ -502,6 +527,8 @@ class StableDiffusionProcessing:
|
|||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
||||||
|
|
||||||
|
self.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
def get_conds(self):
|
def get_conds(self):
|
||||||
return self.c, self.uc
|
return self.c, self.uc
|
||||||
|
|
||||||
@@ -801,10 +828,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
|
|
||||||
for key, value in generation_params.items():
|
for key, value in generation_params.items():
|
||||||
try:
|
try:
|
||||||
if isinstance(value, list):
|
if callable(value):
|
||||||
generation_params[key] = value[index]
|
|
||||||
elif callable(value):
|
|
||||||
generation_params[key] = value(**locals())
|
generation_params[key] = value(**locals())
|
||||||
|
elif isinstance(value, list):
|
||||||
|
generation_params[key] = value[index]
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
|
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
|
||||||
generation_params[key] = None
|
generation_params[key] = None
|
||||||
@@ -938,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.interrupted or state.stopping_generation:
|
if state.interrupted or state.stopping_generation:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
|
||||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||||
|
|
||||||
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
@@ -965,8 +993,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
|
||||||
|
|
||||||
# params.txt should be saved after scripts.process_batch, since the
|
# params.txt should be saved after scripts.process_batch, since the
|
||||||
# infotext could be modified by that callback
|
# infotext could be modified by that callback
|
||||||
# Example: a wildcard processed by process_batch sets an extra model
|
# Example: a wildcard processed by process_batch sets an extra model
|
||||||
@@ -1513,6 +1539,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
||||||
|
|
||||||
|
self.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
if self.is_hr_pass:
|
if self.is_hr_pass:
|
||||||
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
|
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
@@ -321,6 +321,14 @@ class StableDiffusionModelHijack:
|
|||||||
self.comments = []
|
self.comments = []
|
||||||
self.extra_generation_params = {}
|
self.extra_generation_params = {}
|
||||||
|
|
||||||
|
def extract_generation_params_states(self):
|
||||||
|
"""Extracts GenerationParametersList so that they can be cached and restored later"""
|
||||||
|
states = {}
|
||||||
|
for key in list(self.extra_generation_params):
|
||||||
|
if isinstance(self.extra_generation_params[key], util.GenerationParametersList):
|
||||||
|
states[key] = self.extra_generation_params.pop(key)
|
||||||
|
return states
|
||||||
|
|
||||||
def get_prompt_lengths(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
if self.clip is None:
|
if self.clip is None:
|
||||||
return "-", "-"
|
return "-", "-"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from collections import namedtuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
|
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +27,30 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
|
|||||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||||
|
|
||||||
|
|
||||||
|
class EmphasisMode(util.GenerationParametersList):
|
||||||
|
def __init__(self, emphasis_mode:str = None):
|
||||||
|
super().__init__()
|
||||||
|
self.emphasis_mode = emphasis_mode
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.emphasis_mode
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
if isinstance(other, EmphasisMode):
|
||||||
|
return self if self.emphasis_mode else other
|
||||||
|
elif isinstance(other, str):
|
||||||
|
return self.__str__() + other
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __radd__(self, other):
|
||||||
|
if isinstance(other, str):
|
||||||
|
return other + self.__str__()
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.emphasis_mode if self.emphasis_mode else ''
|
||||||
|
|
||||||
|
|
||||||
class TextConditionalModel(torch.nn.Module):
|
class TextConditionalModel(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -238,12 +262,10 @@ class TextConditionalModel(torch.nn.Module):
|
|||||||
hashes.append(f"{name}: {shorthash}")
|
hashes.append(f"{name}: {shorthash}")
|
||||||
|
|
||||||
if hashes:
|
if hashes:
|
||||||
if self.hijack.extra_generation_params.get("TI hashes"):
|
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)
|
||||||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
|
||||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
|
||||||
|
|
||||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
|
||||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)
|
||||||
|
|
||||||
if self.return_pooled:
|
if self.return_pooled:
|
||||||
return torch.hstack(zs), zs[0].pooled
|
return torch.hstack(zs), zs[0].pooled
|
||||||
|
|||||||
@@ -288,3 +288,49 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
|
|||||||
for chunk in iter(lambda: f.read(blksize), b""):
|
for chunk in iter(lambda: f.read(blksize), b""):
|
||||||
hash_sha256.update(chunk)
|
hash_sha256.update(chunk)
|
||||||
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
|
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationParametersList(list):
|
||||||
|
"""A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
|
||||||
|
due to StableDiffusionProcessing.get_conds_with_caching
|
||||||
|
extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used
|
||||||
|
|
||||||
|
When an extra_generation_params is set in StableDiffusionModelHijack using this object,
|
||||||
|
the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
|
||||||
|
the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
|
||||||
|
and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states
|
||||||
|
|
||||||
|
Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'
|
||||||
|
|
||||||
|
Depending on the use case the methods can be overwritten.
|
||||||
|
In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
|
||||||
|
When called by create_infotext it will access to the locals() of the caller,
|
||||||
|
if return str, the value will be written to infotext, if return None will be ignored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._to_be_clear_before_batch = to_be_clear_before_batch
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return ', '.join(sorted(set(self), key=natural_sort_key))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_be_clear_before_batch(self):
|
||||||
|
return self._to_be_clear_before_batch
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
if isinstance(other, GenerationParametersList):
|
||||||
|
return self.__class__([*self, *other])
|
||||||
|
elif isinstance(other, str):
|
||||||
|
return self.__str__() + other
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __radd__(self, other):
|
||||||
|
if isinstance(other, str):
|
||||||
|
return other + self.__str__()
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__call__()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user