Compare commits

..

3 Commits

Author SHA1 Message Date
w-e-w 6aa7925d36 lint 2024-11-23 19:12:13 +09:00
w-e-w a57adde5ae sort TI hashes 2024-11-23 19:09:09 +09:00
w-e-w e72a6c411a fix missing infotext cased by conda cache
some generation params such as TI hashes or Emphasis is added in sd_hijack / sd_hijack_clip
if conda are fetche from cache sd_hijack_clip will not be executed and it won't have a chance to to add generation params

the generation params will also be missing if in non low-vram mode because the hijack.extra_generation_params was never read after calculate_hr_conds
2024-11-23 19:09:09 +09:00
15 changed files with 113 additions and 126 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ jobs:
- name: Install Ruff - name: Install Ruff
run: pip install ruff==0.3.3 run: pip install ruff==0.3.3
- name: Run Ruff - name: Run Ruff
run: ruff check . run: ruff .
lint-js: lint-js:
name: eslint name: eslint
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -1,69 +1,36 @@
// Stable Diffusion WebUI - Bracket Checker // Stable Diffusion WebUI - Bracket checker
// By @Bwin4L, @akx, @w-e-w, @Haoming02 // By Hingashi no Florin/Bwin4L & @akx
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red, and if you hover on it, a tooltip tells you what's wrong. // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElem) {
const pairs = [
['(', ')', 'round brackets'],
['[', ']', 'square brackets'],
['{', '}', 'curly brackets']
];
function checkBrackets(textArea, counterElt) {
const counts = {}; const counts = {};
const errors = new Set(); textArea.value.matchAll(/(?<!\\)(?:\\\\)*?([(){}[\]])/g).forEach(bracket => {
let i = 0; counts[bracket[1]] = (counts[bracket[1]] || 0) + 1;
});
const errors = [];
while (i < textArea.value.length) { function checkPair(open, close, kind) {
let char = textArea.value[i]; if (counts[open] !== counts[close]) {
let escaped = false; errors.push(
while (char === '\\' && i + 1 < textArea.value.length) { `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
escaped = !escaped; );
i++;
char = textArea.value[i];
}
if (escaped) {
i++;
continue;
}
for (const [open, close, label] of pairs) {
if (char === open) {
counts[label] = (counts[label] || 0) + 1;
} else if (char === close) {
counts[label] = (counts[label] || 0) - 1;
if (counts[label] < 0) {
errors.add(`Incorrect order of ${label}.`);
}
}
}
i++;
}
for (const [open, close, label] of pairs) {
if (counts[label] == undefined) {
continue;
}
if (counts[label] > 0) {
errors.add(`${open} ... ${close} - Detected ${counts[label]} more opening than closing ${label}.`);
} else if (counts[label] < 0) {
errors.add(`${open} ... ${close} - Detected ${-counts[label]} more closing than opening ${label}.`);
} }
} }
counterElem.title = [...errors].join('\n'); checkPair('(', ')', 'round brackets');
counterElem.classList.toggle('error', errors.size !== 0); checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0);
} }
function setupBracketChecking(id_prompt, id_counter) { function setupBracketChecking(id_prompt, id_counter) {
const textarea = gradioApp().querySelector(`#${id_prompt} > label > textarea`); var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
const counter = gradioApp().getElementById(id_counter); var counter = gradioApp().getElementById(id_counter);
if (textarea && counter) { if (textarea && counter) {
onEdit(`${id_prompt}_BracketChecking`, textarea, 400, () => checkBrackets(textarea, counter)); textarea.addEventListener("input", () => checkBrackets(textarea, counter));
} }
} }
+1 -1
View File
@@ -1,5 +1,5 @@
<div> <div>
<a href="{api_docs}" target="_blank">API</a> <a href="{api_docs}">API</a>
 •   • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">GitHub</a> <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">GitHub</a>
 •   • 
-1
View File
@@ -50,7 +50,6 @@ def check_versions():
def initialize(): def initialize():
from modules import initialize_util from modules import initialize_util
initialize_util.allow_add_middleware_after_start()
initialize_util.fix_torch_version() initialize_util.fix_torch_version()
initialize_util.fix_pytorch_lightning() initialize_util.fix_pytorch_lightning()
initialize_util.fix_asyncio_event_loop_policy() initialize_util.fix_asyncio_event_loop_policy()
+3 -37
View File
@@ -5,8 +5,6 @@ import sys
import re import re
from modules.timer import startup_timer from modules.timer import startup_timer
from modules import patches
from functools import wraps
def gradio_server_name(): def gradio_server_name():
@@ -193,8 +191,11 @@ def configure_opts_onchange():
def setup_middleware(app): def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app) configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app): def configure_cors_middleware(app):
@@ -212,38 +213,3 @@ def configure_cors_middleware(app):
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options) app.add_middleware(CORSMiddleware, **cors_options)
def allow_add_middleware_after_start():
from starlette.applications import Starlette
def add_middleware_wrapper(func):
"""Patch Starlette.add_middleware to allow for middleware to be added after the app has started
Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).
If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
the first way is to just set middleware_stack to None and then retry
the second manually insert the middleware into the user_middleware list without calling add_middleware
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
res = None
try:
res = func(self, *args, **kwargs)
except RuntimeError as _:
try:
self.middleware_stack = None
res = func(self, *args, **kwargs)
except RuntimeError as e:
print(f'Warning: "{e}", Retrying...')
from starlette.middleware import Middleware
self.user_middleware.insert(0, Middleware(*args, **kwargs))
self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests
return res
return wrapper
patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))
+3 -1
View File
@@ -43,7 +43,9 @@ def check_python_version():
supported_minors = [7, 8, 9, 10, 11] supported_minors = [7, 8, 9, 10, 11]
if not (major == 3 and minor in supported_minors): if not (major == 3 and minor in supported_minors):
errors.print_error_explanation(f""" import modules.errors
modules.errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION INCOMPATIBLE PYTHON VERSION
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+25 -2
View File
@@ -187,6 +187,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None] cached_uc = [None, None]
cached_c = [None, None] cached_c = [None, None]
hijack_generation_params_state_list = []
comments: dict = None comments: dict = None
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False) sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
@@ -480,6 +481,10 @@ class StableDiffusionProcessing:
for cache in caches: for cache in caches:
if cache[0] is not None and cached_params == cache[0]: if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_state, cached_params_2 = cache[2]
if cached_params == cached_params_2:
self.hijack_generation_params_state_list.extend(generation_params_state)
return cache[1] return cache[1]
cache = caches[0] cache = caches[0]
@@ -487,9 +492,25 @@ class StableDiffusionProcessing:
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
generation_params_state = model_hijack.capture_generation_params_state()
self.hijack_generation_params_state_list.extend(generation_params_state)
if len(cache) == 2:
cache.append((generation_params_state, cached_params))
else:
cache[2] = (generation_params_state, cached_params)
cache[0] = cached_params cache[0] = cached_params
return cache[1] return cache[1]
def apply_hijack_generation_params(self):
self.extra_generation_params.update(model_hijack.extra_generation_params)
for func in self.hijack_generation_params_state_list:
try:
func(self.extra_generation_params)
except Exception:
errors.report('Failed to apply hijack generation params state', exc_info=True)
self.hijack_generation_params_state_list.clear()
def setup_conds(self): def setup_conds(self):
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
@@ -502,6 +523,8 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
self.apply_hijack_generation_params()
def get_conds(self): def get_conds(self):
return self.c, self.uc return self.c, self.uc
@@ -965,8 +988,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds() p.setup_conds()
p.extra_generation_params.update(model_hijack.extra_generation_params)
# params.txt should be saved after scripts.process_batch, since the # params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback # infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model # Example: a wildcard processed by process_batch sets an extra model
@@ -1513,6 +1534,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
self.apply_hijack_generation_params()
def setup_conds(self): def setup_conds(self):
if self.is_hr_pass: if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
+8
View File
@@ -6,6 +6,7 @@ from modules import devices, sd_hijack_optimizations, shared, script_callbacks,
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
from modules.util import GenerationParamsState
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
@@ -321,6 +322,13 @@ class StableDiffusionModelHijack:
self.comments = [] self.comments = []
self.extra_generation_params = {} self.extra_generation_params = {}
def capture_generation_params_state(self):
state = []
for key in list(self.extra_generation_params):
if isinstance(self.extra_generation_params[key], GenerationParamsState):
state.append(self.extra_generation_params.pop(key))
return state
def get_prompt_lengths(self, text): def get_prompt_lengths(self, text):
if self.clip is None: if self.clip is None:
return "-", "-" return "-", "-"
+29 -6
View File
@@ -3,8 +3,9 @@ from collections import namedtuple
import torch import torch
from modules import prompt_parser, devices, sd_hijack, sd_emphasis from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
from modules.shared import opts from modules.shared import opts
from modules.util import GenerationParamsState
class PromptChunk: class PromptChunk:
@@ -27,6 +28,31 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class EmbeddingHashes(GenerationParamsState):
def __init__(self, hashes: list):
super().__init__()
self.hashes = hashes
def __call__(self, extra_generation_params):
unique_hashes = dict.fromkeys(self.hashes)
if existing_ti_hashes := extra_generation_params.get('TI hashes'):
unique_hashes.update(dict.fromkeys(existing_ti_hashes.split(', ')))
extra_generation_params['TI hashes'] = ', '.join(sorted(unique_hashes, key=util.natural_sort_key))
class EmphasisMode(GenerationParamsState):
def __init__(self, texts):
super().__init__()
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
self.emphasis = opts.emphasis
else:
self.emphasis = None
def __call__(self, extra_generation_params):
if self.emphasis:
extra_generation_params['Emphasis'] = self.emphasis
class TextConditionalModel(torch.nn.Module): class TextConditionalModel(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -238,12 +264,9 @@ class TextConditionalModel(torch.nn.Module):
hashes.append(f"{name}: {shorthash}") hashes.append(f"{name}: {shorthash}")
if hashes: if hashes:
if self.hijack.extra_generation_params.get("TI hashes"): self.hijack.extra_generation_params["TI hashes"] = EmbeddingHashes(hashes)
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(texts)
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if self.return_pooled: if self.return_pooled:
return torch.hstack(zs), zs[0].pooled return torch.hstack(zs), zs[0].pooled
+1 -1
View File
@@ -125,7 +125,7 @@ def ui_reorder_categories():
def callbacks_order_settings(): def callbacks_order_settings():
options = { options = {
"callbacks_order_explanation": OptionHTML(""" "sd_vae_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed. For categories below, callbacks added to dropdowns happen before others, in order listed.
"""), """),
+2 -3
View File
@@ -33,12 +33,12 @@ categories.register_category("training", "Training")
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), { options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
"samples_save": OptionInfo(True, "Always save all generated images"), "samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images', ui_components.DropdownEditable, {"choices": ("png", "jpg", "jpeg", "webp", "avif")}).info("manual input of <a href='https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html' target='_blank'>other formats</a> is possible, but compatibility is not guaranteed"), "samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}), "save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
"grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids', ui_components.DropdownEditable, {"choices": ("png", "jpg", "jpeg", "webp", "avif")}).info("manual input of <a href='https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html' target='_blank'>other formats</a> is possible, but compatibility is not guaranteed"), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
@@ -128,7 +128,6 @@ options_templates.update(options_section(('system', "System", "system"), {
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
"concurrent_git_fetch_limit": OptionInfo(16, "Number of simultaneous extension update checks ", gr.Slider, {"step": 1, "minimum": 1, "maximum": 100}).info("reduce extension update check time"),
})) }))
options_templates.update(options_section(('profiler', "Profiler", "system"), { options_templates.update(options_section(('profiler', "Profiler", "system"), {
+4 -11
View File
@@ -1,6 +1,5 @@
import json import json
import os import os
from concurrent.futures import ThreadPoolExecutor
import threading import threading
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -107,24 +106,18 @@ def check_updates(id_task, disable_list):
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
shared.state.job_count = len(exts) shared.state.job_count = len(exts)
lock = threading.Lock() for ext in exts:
shared.state.textinfo = ext.name
def _check_update(ext):
try: try:
ext.check_updates() ext.check_updates()
except FileNotFoundError as e: except FileNotFoundError as e:
if 'FETCH_HEAD' not in str(e): if 'FETCH_HEAD' not in str(e):
raise raise
except Exception: except Exception:
with lock: errors.report(f"Error checking updates for {ext.name}", exc_info=True)
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
with lock:
shared.state.textinfo = ext.name
shared.state.nextjob()
with ThreadPoolExecutor(max_workers=max(1, int(shared.opts.concurrent_git_fetch_limit))) as executor: shared.state.nextjob()
for ext in exts:
executor.submit(_check_update, ext)
return extension_table(), "" return extension_table(), ""
+15
View File
@@ -288,3 +288,18 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
for chunk in iter(lambda: f.read(blksize), b""): for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk) hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower()) return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
class GenerationParamsState:
"""A custom class used in StableDiffusionModelHijack for assigning extra_generation_params
generation_params assigned using this class will work properly with StableDiffusionProcessing.get_conds_with_caching()
if assigned directly the generation_params will not be populated if conda cache is used
Generation_params of this class will be captured (see StableDiffusionModelHijack.capture_generation_params_state) and stored with conda cache, and will be extracted in StableDiffusionProcessing.apply_hijack_generation_params()
To use this class, create a subclass with a __call__ method that takes extra_generation_params: dict as input
Example usage: sd_hijack_clip.EmbeddingHashes, sd_hijack_clip.EmphasisMode
"""
def __call__(self, extra_generation_params: dict):
raise NotImplementedError
-4
View File
@@ -29,10 +29,6 @@ class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if codeformer_visibility < 1.0: if codeformer_visibility < 1.0:
if pp.image.size != res.size:
res = res.resize(pp.image.size)
if pp.image.mode != res.mode:
res = res.convert(pp.image.mode)
res = Image.blend(pp.image, res, codeformer_visibility) res = Image.blend(pp.image, res, codeformer_visibility)
pp.image = res pp.image = res
-4
View File
@@ -26,10 +26,6 @@ class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if gfpgan_visibility < 1.0: if gfpgan_visibility < 1.0:
if pp.image.size != res.size:
res = res.resize(pp.image.size)
if pp.image.mode != res.mode:
res = res.convert(pp.image.mode)
res = Image.blend(pp.image, res, gfpgan_visibility) res = Image.blend(pp.image, res, gfpgan_visibility)
pp.image = res pp.image = res