Fix Beta sampling to match the paper (#16926)

This commit is contained in:
Jan Alexander Steffens
2025-05-03 08:17:03 +02:00
committed by GitHub
parent 82bf9a3730
commit 2174ce5afe
2 changed files with 9 additions and 6 deletions
+7 -4
View File
@@ -117,12 +117,15 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)
alpha = shared.opts.beta_dist_alpha
beta = shared.opts.beta_dist_beta
timesteps = 1 - np.linspace(0, 1, n)
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
curve = [stats.beta.ppf(x, alpha, beta) for x in np.linspace(1, 0, n)]
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
timesteps = [end + x * (start - end) for x in curve]
sigmas = [inner_model.t_to_sigma(ts) for ts in timesteps]
sigmas += [0.0]
return torch.FloatTensor(sigmas).to(device)