Skip to content

Commit 32a627b

Browse files
authored
SEEDS: update noise decomposition and refactor (Comfy-Org#9633)
- Update the decomposition to reflect interval dependency
- Extract phi computations into functions
- Use torch.lerp for interpolation
1 parent fe442fa commit 32a627b

File tree

1 file changed

+82
-71
lines changed

1 file changed

+82
-71
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 82 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
171171
return sigmas
172172

173173

174+
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
    """Compute the result of h*phi_1(h) in exponential integrator methods.

    With phi_1(h) = (exp(h) - 1) / h, the product h*phi_1(h) reduces to
    exp(h) - 1, which is evaluated via ``torch.expm1`` so that small ``h``
    does not lose precision to cancellation.

    Args:
        h: Step size(s) as a tensor.

    Returns:
        Tensor of the same shape as ``h`` holding ``exp(h) - 1`` element-wise.
    """
    return torch.expm1(h)
177+
178+
179+
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
    """Compute the result of h*phi_2(h) in exponential integrator methods.

    With phi_2(h) = (exp(h) - 1 - h) / h**2, the product h*phi_2(h) equals
    (exp(h) - 1 - h) / h.  The expression is 0/0 at ``h == 0``; its limit
    there is 0, so that value is returned instead of NaN.

    Args:
        h: Step size(s) as a tensor.

    Returns:
        Tensor of the same shape as ``h`` holding ``(exp(h) - 1 - h) / h``
        element-wise, with 0 where ``h == 0``.
    """
    is_zero = h == 0
    # Substitute a harmless denominator at the singular points so that no
    # NaN/Inf is ever produced (keeps the unused branch finite as well).
    safe_h = torch.where(is_zero, torch.ones_like(h), h)
    result = (torch.expm1(safe_h) - safe_h) / safe_h
    return torch.where(is_zero, torch.zeros_like(result), result)
182+
183+
174184
@torch.no_grad()
175185
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
176186
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1550,69 +1560,66 @@ def default_er_sde_noise_scaler(x):
15501560
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
    arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)

    Args:
        model: Denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: Current latent; ``x.shape[0]`` is used to broadcast sigma per sample.
        sigmas: Noise schedule; a step whose next sigma is 0 returns the denoised prediction directly.
        extra_args: Optional kwargs forwarded to ``model``; its ``"seed"`` entry seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict of ``x``, ``i``, ``sigma``, ``sigma_hat``, ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Stochasticity strength; noise is injected only when ``eta > 0`` and ``s_noise > 0``.
        s_noise: Multiplier applied to the injected noise.
        noise_sampler: Optional callable ``(sigma, sigma_next) -> noise``; defaults to ``default_noise_sampler``.
        r: Interpolation weight in half-log-SNR space for the intermediate stage.

    Returns:
        The latent after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    # Weight of the second model evaluation in the final combination; loop-invariant.
    fac = 1 / (2 * r)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigmas[i + 1] == 0:
            x = denoised
            continue

        lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
        h = lambda_t - lambda_s
        h_eta = h * (eta + 1)
        lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
        sigma_s_1 = sigma_fn(lambda_s_1)

        # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
        alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
        alpha_t = sigmas[i + 1] * lambda_t.exp()

        # Step 1
        x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
        if inject_noise:
            sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
            x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
        denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

        # Step 2
        denoised_d = torch.lerp(denoised, denoised_2, fac)
        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
        if inject_noise:
            # Carry the first sub-interval's noise forward (decayed over the
            # remaining sub-interval) and add fresh noise for that sub-interval.
            segment_factor = (r - 1) * h * eta
            sde_noise = sde_noise * segment_factor.exp()
            sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
            x = x + sde_noise * sigmas[i + 1] * s_noise
    return x
16041612

16051613

16061614
@torch.no_grad()
16071615
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
16081616
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
1609-
arXiv: https://arxiv.org/abs/2305.14267
1617+
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
16101618
"""
16111619
extra_args = {} if extra_args is None else extra_args
16121620
seed = extra_args.get("seed", None)
16131621
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
16141622
s_in = x.new_ones([x.shape[0]])
1615-
16161623
inject_noise = eta > 0 and s_noise > 0
16171624

16181625
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
16241631
denoised = model(x, sigmas[i] * s_in, **extra_args)
16251632
if callback is not None:
16261633
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1634+
16271635
if sigmas[i + 1] == 0:
16281636
x = denoised
1629-
else:
1630-
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
1631-
h = lambda_t - lambda_s
1632-
h_eta = h * (eta + 1)
1633-
lambda_s_1 = lambda_s + r_1 * h
1634-
lambda_s_2 = lambda_s + r_2 * h
1635-
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
1636-
1637-
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
1638-
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
1639-
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
1640-
alpha_t = sigmas[i + 1] * lambda_t.exp()
1641-
1642-
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
1643-
if inject_noise:
1644-
# 0 < r_1 < r_2 < 1
1645-
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
1646-
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
1647-
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
1648-
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
1649-
1650-
# Step 1
1651-
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
1652-
if inject_noise:
1653-
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
1654-
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
1655-
1656-
# Step 2
1657-
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
1658-
if inject_noise:
1659-
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
1660-
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
1661-
1662-
# Step 3
1663-
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
1664-
if inject_noise:
1665-
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
1637+
continue
1638+
1639+
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
1640+
h = lambda_t - lambda_s
1641+
h_eta = h * (eta + 1)
1642+
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
1643+
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
1644+
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
1645+
1646+
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
1647+
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
1648+
alpha_t = sigmas[i + 1] * lambda_t.exp()
1649+
1650+
# Step 1
1651+
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
1652+
if inject_noise:
1653+
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
1654+
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
1655+
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
1656+
1657+
# Step 2
1658+
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
1659+
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
1660+
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
1661+
if inject_noise:
1662+
segment_factor = (r_1 - r_2) * h * eta
1663+
sde_noise = sde_noise * segment_factor.exp()
1664+
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
1665+
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
1666+
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
1667+
1668+
# Step 3
1669+
b3 = ei_h_phi_2(-h_eta) / r_2
1670+
b1 = ei_h_phi_1(-h_eta) - b3
1671+
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
1672+
if inject_noise:
1673+
segment_factor = (r_2 - 1) * h * eta
1674+
sde_noise = sde_noise * segment_factor.exp()
1675+
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
1676+
x = x + sde_noise * sigmas[i + 1] * s_noise
16661677
return x
16671678

16681679

0 commit comments

Comments
 (0)