Fix Beta sampling to match the paper (#16926)
This commit is contained in:
committed by
GitHub
parent
82bf9a3730
commit
2174ce5afe
@@ -117,12 +117,15 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
||||
|
||||
|
||||
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
||||
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
|
||||
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)
|
||||
alpha = shared.opts.beta_dist_alpha
|
||||
beta = shared.opts.beta_dist_beta
|
||||
timesteps = 1 - np.linspace(0, 1, n)
|
||||
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
|
||||
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
|
||||
curve = [stats.beta.ppf(x, alpha, beta) for x in np.linspace(1, 0, n)]
|
||||
|
||||
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
|
||||
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
|
||||
timesteps = [end + x * (start - end) for x in curve]
|
||||
sigmas = [inner_model.t_to_sigma(ts) for ts in timesteps]
|
||||
sigmas += [0.0]
|
||||
return torch.FloatTensor(sigmas).to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user