Add consistency decoder

This commit is contained in:
Kohaku-Blueleaf
2023-11-07 10:52:29 +08:00
parent 9c1c0da026
commit 64fd916334
5 changed files with 48 additions and 3 deletions
+9 -2
View File
@@ -3,7 +3,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
def samples_to_images_tensor(sample, approximation=None, model=None):
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
elif approximation == 3:
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
elif approximation == 4:
with devices.autocast(), torch.no_grad():
x_sample = sd_vae_consistency.decoder_model()(
sample.to(devices.device, devices.dtype)/0.18215,
schedule=[1.0]
)
sd_vae_consistency.unload()
else:
if model is None:
model = shared.sd_model