@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
171171 return sigmas
172172
173173
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
    """Return ``h * phi_1(h)`` as used by exponential integrator methods.

    With ``phi_1(h) = (exp(h) - 1) / h`` the product reduces to
    ``exp(h) - 1``, which is evaluated element-wise through ``expm1``.

    Args:
        h: Tensor of step sizes.

    Returns:
        Tensor holding ``expm1(h)`` for every element of ``h``.
    """
    return h.expm1()
177+
178+
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
    """Compute the result of h*phi_2(h) in exponential integrator methods.

    With ``phi_2(h) = (exp(h) - 1 - h) / h**2`` the product is
    ``(exp(h) - 1 - h) / h``. Evaluating that quotient directly returns NaN
    at ``h == 0`` (0 / 0) and loses most significant digits for small
    ``|h|`` because ``expm1(h)`` and ``h`` nearly cancel. Below a
    dtype-dependent threshold the Taylor expansion
    ``h/2 + h**2/6 + h**3/24 + h**4/120`` is used instead, which also gives
    the correct limit of 0 at ``h == 0``.

    Args:
        h: Tensor of step sizes.

    Returns:
        Tensor with the same shape as ``h`` holding ``h * phi_2(h)``.
    """
    # eps ** (1/5) roughly balances the series truncation error against the
    # cancellation error of the closed form for the given precision.
    ref_dtype = h.dtype if h.is_floating_point() else torch.get_default_dtype()
    threshold = torch.finfo(ref_dtype).eps ** 0.2
    use_series = h.abs() < threshold

    # Replace near-zero entries before dividing so the closed-form branch
    # never evaluates 0 / 0, even for elements that torch.where discards.
    safe_h = torch.where(use_series, torch.ones_like(h), h)
    closed_form = (torch.expm1(safe_h) - safe_h) / safe_h

    series = h * (1 / 2 + h * (1 / 6 + h * (1 / 24 + h / 120)))
    return torch.where(use_series, series, closed_form)
182+
183+
174184@torch .no_grad ()
175185def sample_euler (model , x , sigmas , extra_args = None , callback = None , disable = None , s_churn = 0. , s_tmin = 0. , s_tmax = float ('inf' ), s_noise = 1. ):
176186 """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1550,69 +1560,66 @@ def default_er_sde_noise_scaler(x):
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
    arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)

    Every step calls the model at the current sigma and a second time at an
    intermediate point located a fraction ``r`` of the way between the
    half-log-SNR values of the current and the next sigma. The two
    predictions are blended with weight ``1 / (2 * r)`` for the final update.
    Noise is only added when both ``eta`` and ``s_noise`` are positive. When
    the next sigma is exactly zero the step returns the model prediction.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None))
    s_in = x.new_ones([x.shape[0]])
    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    to_sigma = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    to_lambda = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    # Weight of the intermediate prediction in the final blend; loop-invariant.
    blend = 1 / (2 * r)

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur, sigma_next = sigmas[i], sigmas[i + 1]

        denoised = model(x, sigma_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})

        if sigma_next == 0:
            # Nothing left to integrate: take the prediction as the sample.
            x = denoised
            continue

        lam_cur, lam_next = to_lambda(sigma_cur), to_lambda(sigma_next)
        h = lam_next - lam_cur
        h_eta = h * (eta + 1)
        lam_mid = torch.lerp(lam_cur, lam_next, r)
        sigma_mid = to_sigma(lam_mid)

        # alpha = sigma * exp(lambda), since lambda = log(alpha / sigma).
        alpha_mid = sigma_mid * lam_mid.exp()
        alpha_next = sigma_next * lam_next.exp()

        # Step 1: advance to the intermediate point.
        carried = sigma_mid / sigma_cur * (-r * h * eta).exp() * x
        drift = alpha_mid * ei_h_phi_1(-r * h_eta) * denoised
        x_mid = carried - drift
        if inject_noise:
            sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigma_cur, sigma_mid)
            x_mid = x_mid + sde_noise * sigma_mid * s_noise
        denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)

        # Step 2: full step using the blended prediction.
        denoised_mix = torch.lerp(denoised, denoised_mid, blend)
        carried = sigma_next / sigma_cur * (-h * eta).exp() * x
        drift = alpha_next * ei_h_phi_1(-h_eta) * denoised_mix
        x = carried - drift
        if inject_noise:
            # Scale the noise from step 1 over the remaining segment, then add
            # a fresh draw for that segment.
            decay = (r - 1) * h * eta
            sde_noise = sde_noise * decay.exp()
            sde_noise = sde_noise + decay.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_mid, sigma_next)
            x = x + sde_noise * sigma_next * s_noise
    return x
16041612
16051613
16061614@torch .no_grad ()
16071615def sample_seeds_3 (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r_1 = 1. / 3 , r_2 = 2. / 3 ):
16081616 """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
1609- arXiv: https://arxiv.org/abs/2305.14267
1617+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
16101618 """
16111619 extra_args = {} if extra_args is None else extra_args
16121620 seed = extra_args .get ("seed" , None )
16131621 noise_sampler = default_noise_sampler (x , seed = seed ) if noise_sampler is None else noise_sampler
16141622 s_in = x .new_ones ([x .shape [0 ]])
1615-
16161623 inject_noise = eta > 0 and s_noise > 0
16171624
16181625 model_sampling = model .inner_model .model_patcher .get_model_object ('model_sampling' )
@@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
16241631 denoised = model (x , sigmas [i ] * s_in , ** extra_args )
16251632 if callback is not None :
16261633 callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
1634+
16271635 if sigmas [i + 1 ] == 0 :
16281636 x = denoised
1629- else :
1630- lambda_s , lambda_t = lambda_fn (sigmas [i ]), lambda_fn (sigmas [i + 1 ])
1631- h = lambda_t - lambda_s
1632- h_eta = h * (eta + 1 )
1633- lambda_s_1 = lambda_s + r_1 * h
1634- lambda_s_2 = lambda_s + r_2 * h
1635- sigma_s_1 , sigma_s_2 = sigma_fn (lambda_s_1 ), sigma_fn (lambda_s_2 )
1636-
1637- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
1638- alpha_s_1 = sigma_s_1 * lambda_s_1 .exp ()
1639- alpha_s_2 = sigma_s_2 * lambda_s_2 .exp ()
1640- alpha_t = sigmas [i + 1 ] * lambda_t .exp ()
1641-
1642- coeff_1 , coeff_2 , coeff_3 = (- r_1 * h_eta ).expm1 (), (- r_2 * h_eta ).expm1 (), (- h_eta ).expm1 ()
1643- if inject_noise :
1644- # 0 < r_1 < r_2 < 1
1645- noise_coeff_1 = (- 2 * r_1 * h * eta ).expm1 ().neg ().sqrt ()
1646- noise_coeff_2 = (- r_1 * h * eta ).exp () * (- 2 * (r_2 - r_1 ) * h * eta ).expm1 ().neg ().sqrt ()
1647- noise_coeff_3 = (- r_2 * h * eta ).exp () * (- 2 * (1 - r_2 ) * h * eta ).expm1 ().neg ().sqrt ()
1648- noise_1 , noise_2 , noise_3 = noise_sampler (sigmas [i ], sigma_s_1 ), noise_sampler (sigma_s_1 , sigma_s_2 ), noise_sampler (sigma_s_2 , sigmas [i + 1 ])
1649-
1650- # Step 1
1651- x_2 = sigma_s_1 / sigmas [i ] * (- r_1 * h * eta ).exp () * x - alpha_s_1 * coeff_1 * denoised
1652- if inject_noise :
1653- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1 ) * s_noise
1654- denoised_2 = model (x_2 , sigma_s_1 * s_in , ** extra_args )
1655-
1656- # Step 2
1657- x_3 = sigma_s_2 / sigmas [i ] * (- r_2 * h * eta ).exp () * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1 ) * alpha_s_2 * (coeff_2 / (r_2 * h_eta ) + 1 ) * (denoised_2 - denoised )
1658- if inject_noise :
1659- x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2 ) * s_noise
1660- denoised_3 = model (x_3 , sigma_s_2 * s_in , ** extra_args )
1661-
1662- # Step 3
1663- x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x - alpha_t * coeff_3 * denoised + (1. / r_2 ) * alpha_t * (coeff_3 / h_eta + 1 ) * (denoised_3 - denoised )
1664- if inject_noise :
1665- x = x + sigmas [i + 1 ] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3 ) * s_noise
1637+ continue
1638+
1639+ lambda_s , lambda_t = lambda_fn (sigmas [i ]), lambda_fn (sigmas [i + 1 ])
1640+ h = lambda_t - lambda_s
1641+ h_eta = h * (eta + 1 )
1642+ lambda_s_1 = torch .lerp (lambda_s , lambda_t , r_1 )
1643+ lambda_s_2 = torch .lerp (lambda_s , lambda_t , r_2 )
1644+ sigma_s_1 , sigma_s_2 = sigma_fn (lambda_s_1 ), sigma_fn (lambda_s_2 )
1645+
1646+ alpha_s_1 = sigma_s_1 * lambda_s_1 .exp ()
1647+ alpha_s_2 = sigma_s_2 * lambda_s_2 .exp ()
1648+ alpha_t = sigmas [i + 1 ] * lambda_t .exp ()
1649+
1650+ # Step 1
1651+ x_2 = sigma_s_1 / sigmas [i ] * (- r_1 * h * eta ).exp () * x - alpha_s_1 * ei_h_phi_1 (- r_1 * h_eta ) * denoised
1652+ if inject_noise :
1653+ sde_noise = (- 2 * r_1 * h * eta ).expm1 ().neg ().sqrt () * noise_sampler (sigmas [i ], sigma_s_1 )
1654+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
1655+ denoised_2 = model (x_2 , sigma_s_1 * s_in , ** extra_args )
1656+
1657+ # Step 2
1658+ a3_2 = r_2 / r_1 * ei_h_phi_2 (- r_2 * h_eta )
1659+ a3_1 = ei_h_phi_1 (- r_2 * h_eta ) - a3_2
1660+ x_3 = sigma_s_2 / sigmas [i ] * (- r_2 * h * eta ).exp () * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2 )
1661+ if inject_noise :
1662+ segment_factor = (r_1 - r_2 ) * h * eta
1663+ sde_noise = sde_noise * segment_factor .exp ()
1664+ sde_noise = sde_noise + segment_factor .mul (2 ).expm1 ().neg ().sqrt () * noise_sampler (sigma_s_1 , sigma_s_2 )
1665+ x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
1666+ denoised_3 = model (x_3 , sigma_s_2 * s_in , ** extra_args )
1667+
1668+ # Step 3
1669+ b3 = ei_h_phi_2 (- h_eta ) / r_2
1670+ b1 = ei_h_phi_1 (- h_eta ) - b3
1671+ x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x - alpha_t * (b1 * denoised + b3 * denoised_3 )
1672+ if inject_noise :
1673+ segment_factor = (r_2 - 1 ) * h * eta
1674+ sde_noise = sde_noise * segment_factor .exp ()
1675+ sde_noise = sde_noise + segment_factor .mul (2 ).expm1 ().neg ().sqrt () * noise_sampler (sigma_s_2 , sigmas [i + 1 ])
1676+ x = x + sde_noise * sigmas [i + 1 ] * s_noise
16661677 return x
16671678
16681679
0 commit comments