Add consistency decoder
This commit is contained in:
@@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
|
||||
|
||||
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
elif approximation == 3:
|
||||
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
elif approximation == 4:
|
||||
with devices.autocast(), torch.no_grad():
|
||||
x_sample = sd_vae_consistency.decoder_model()(
|
||||
sample.to(devices.device, devices.dtype)/0.18215,
|
||||
schedule=[1.0]
|
||||
)
|
||||
sd_vae_consistency.unload()
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
|
||||
Reference in New Issue
Block a user