Compare commits

...

4 Commits

Author SHA1 Message Date
Kohaku-Blueleaf a36a30fb93 add gc after using consistency dec 2023-11-07 13:01:10 +08:00
Kohaku-Blueleaf 2ea8726597 custom schedule 2023-11-07 12:35:56 +08:00
Kohaku-Blueleaf 5dbd0355b0 Fix linting 2023-11-07 11:00:24 +08:00
Kohaku-Blueleaf 64fd916334 Add consistency decoder 2023-11-07 10:52:29 +08:00
5 changed files with 48 additions and 3 deletions
+9 -2
View File
@@ -3,7 +3,7 @@ from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
from modules.shared import opts, state from modules.shared import opts, state
import k_diffusion.sampling import k_diffusion.sampling
@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc return steps, t_enc
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
def samples_to_images_tensor(sample, approximation=None, model=None): def samples_to_images_tensor(sample, approximation=None, model=None):
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
elif approximation == 3: elif approximation == 3:
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach() x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1 x_sample = x_sample * 2 - 1
elif approximation == 4:
with devices.autocast(), torch.no_grad():
x_sample = sd_vae_consistency.decoder_model()(
sample.detach().to(devices.device, devices.dtype)/0.18215,
schedule=[float(i.strip()) for i in shared.opts.sd_vae_consistency_schedule.split(',')],
)
sd_vae_consistency.unload()
else: else:
if model is None: if model is None:
model = shared.sd_model model = shared.sd_model
+34
View File
@@ -0,0 +1,34 @@
"""
Consistency Decoder
Improved decoding for Stable Diffusion VAEs.
https://github.com/openai/consistencydecoder
"""
import os
from modules import devices, paths_internal, shared
from consistencydecoder import ConsistencyDecoder
# Module-wide cache for the decoder instance; stays None until decoder_model()
# constructs one, after which the same object is reused on every call.
sd_vae_consistency_models = None
# Directory handed to ConsistencyDecoder as its second argument.
# NOTE(review): presumably where the decoder weights are stored/downloaded — confirm.
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
def decoder_model():
    """Return the shared ConsistencyDecoder, constructing it on first use.

    The instance is kept in the module-level ``sd_vae_consistency_models``
    cache. When a cached instance already exists, its ``ckpt`` is moved to
    ``devices.device`` before being handed back.

    Returns:
        The cached (or newly built) ConsistencyDecoder instance.

    Raises:
        NotImplementedError: if ``shared.sd_model`` reports ``is_sdxl``.
    """
    global sd_vae_consistency_models

    # A model object without an ``is_sdxl`` attribute is treated as non-SDXL.
    sdxl_active = getattr(shared.sd_model, 'is_sdxl', False)
    if sdxl_active:
        raise NotImplementedError("SDXL is not supported for consistency decoder")

    decoder = sd_vae_consistency_models
    if decoder is None:
        # First request: build the decoder directly on the target device
        # and remember it for later calls.
        decoder = ConsistencyDecoder(devices.device, model_path)
        sd_vae_consistency_models = decoder
    else:
        # Reuse the cached decoder; put its weights on the active device.
        decoder.ckpt.to(devices.device)
    return decoder
def unload():
    """Offload the cached decoder to the CPU, then reclaim device memory.

    Does nothing when no decoder has been constructed yet. The cached
    instance itself is kept, so a later ``decoder_model()`` call can move it
    back to the device instead of rebuilding it.
    """
    decoder = sd_vae_consistency_models
    if decoder is None:
        return

    # Order matters: move the weights off the device first, so the
    # collection pass that follows can actually release the memory they
    # occupied. Collecting before the move leaves that memory in use.
    decoder.ckpt.to('cpu')
    devices.torch_gc()
+2 -1
View File
@@ -172,7 +172,8 @@ For img2img, VAE is used to process user's input image before the sampling, and
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
"sd_vae_consistency_schedule": OptionInfo("1.0, 0.5", "consistency schedule").info("sampling schedule for consistency decoder."),
})) }))
options_templates.update(options_section(('img2img', "img2img"), { options_templates.update(options_section(('img2img', "img2img"), {
+2
View File
@@ -32,3 +32,5 @@ torch
torchdiffeq torchdiffeq
torchsde torchsde
transformers==4.30.2 transformers==4.30.2
git+https://github.com/openai/consistencydecoder.git
+1
View File
@@ -30,3 +30,4 @@ torchdiffeq==0.2.3
torchsde==0.2.6 torchsde==0.2.6
transformers==4.30.2 transformers==4.30.2
httpx==0.24.1 httpx==0.24.1
git+https://github.com/openai/consistencydecoder.git