@@ -710,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
710710 # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
711711 return x
712712
713+
713714@torch .no_grad ()
714715def sample_dpmpp_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r = 1 / 2 ):
715716 """DPM-Solver++ (stochastic)."""
@@ -721,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
721722 seed = extra_args .get ("seed" , None )
722723 noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = seed , cpu = True ) if noise_sampler is None else noise_sampler
723724 s_in = x .new_ones ([x .shape [0 ]])
724- sigma_fn = lambda t : t .neg ().exp ()
725- t_fn = lambda sigma : sigma .log ().neg ()
725+
726+ model_sampling = model .inner_model .model_patcher .get_model_object ('model_sampling' )
727+ sigma_fn = partial (half_log_snr_to_sigma , model_sampling = model_sampling )
728+ lambda_fn = partial (sigma_to_half_log_snr , model_sampling = model_sampling )
729+ sigmas = offset_first_sigma_for_snr (sigmas , model_sampling )
726730
727731 for i in trange (len (sigmas ) - 1 , disable = disable ):
728732 denoised = model (x , sigmas [i ] * s_in , ** extra_args )
729733 if callback is not None :
730734 callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
731735 if sigmas [i + 1 ] == 0 :
732- # Euler method
733- d = to_d (x , sigmas [i ], denoised )
734- dt = sigmas [i + 1 ] - sigmas [i ]
735- x = x + d * dt
736+ # Denoising step
737+ x = denoised
736738 else :
737739 # DPM-Solver++
738- t , t_next = t_fn (sigmas [i ]), t_fn (sigmas [i + 1 ])
739- h = t_next - t
740- s = t + h * r
740+ lambda_s , lambda_t = lambda_fn (sigmas [i ]), lambda_fn (sigmas [i + 1 ])
741+ h = lambda_t - lambda_s
742+ lambda_s_1 = lambda_s + r * h
741743 fac = 1 / (2 * r )
742744
745+ sigma_s_1 = sigma_fn (lambda_s_1 )
746+
747+ alpha_s = sigmas [i ] * lambda_s .exp ()
748+ alpha_s_1 = sigma_s_1 * lambda_s_1 .exp ()
749+ alpha_t = sigmas [i + 1 ] * lambda_t .exp ()
750+
743751 # Step 1
744- sd , su = get_ancestral_step (sigma_fn (t ), sigma_fn (s ), eta )
745- s_ = t_fn (sd )
746- x_2 = (sigma_fn (s_ ) / sigma_fn (t )) * x - (t - s_ ).expm1 () * denoised
747- x_2 = x_2 + noise_sampler (sigma_fn (t ), sigma_fn (s )) * s_noise * su
748- denoised_2 = model (x_2 , sigma_fn (s ) * s_in , ** extra_args )
752+ sd , su = get_ancestral_step (lambda_s .neg ().exp (), lambda_s_1 .neg ().exp (), eta )
753+ lambda_s_1_ = sd .log ().neg ()
754+ h_ = lambda_s_1_ - lambda_s
755+ x_2 = (alpha_s_1 / alpha_s ) * (- h_ ).exp () * x - alpha_s_1 * (- h_ ).expm1 () * denoised
756+ if eta > 0 and s_noise > 0 :
757+ x_2 = x_2 + alpha_s_1 * noise_sampler (sigmas [i ], sigma_s_1 ) * s_noise * su
758+ denoised_2 = model (x_2 , sigma_s_1 * s_in , ** extra_args )
749759
750760 # Step 2
751- sd , su = get_ancestral_step (sigma_fn (t ), sigma_fn (t_next ), eta )
752- t_next_ = t_fn (sd )
761+ sd , su = get_ancestral_step (lambda_s .neg ().exp (), lambda_t .neg ().exp (), eta )
762+ lambda_t_ = sd .log ().neg ()
763+ h_ = lambda_t_ - lambda_s
753764 denoised_d = (1 - fac ) * denoised + fac * denoised_2
754- x = (sigma_fn (t_next_ ) / sigma_fn (t )) * x - (t - t_next_ ).expm1 () * denoised_d
755- x = x + noise_sampler (sigma_fn (t ), sigma_fn (t_next )) * s_noise * su
765+ x = (alpha_t / alpha_s ) * (- h_ ).exp () * x - alpha_t * (- h_ ).expm1 () * denoised_d
766+ if eta > 0 and s_noise > 0 :
767+ x = x + alpha_t * noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * su
756768 return x
757769
758770
0 commit comments