@@ -69,8 +69,17 @@ def predict_noise(self, x, timestep, model_options={}, seed=None):
6969 negative_cond = self .conds .get ("negative" , None )
7070 empty_cond = self .conds .get ("empty_negative_prompt" , None )
7171
72- (noise_pred_pos , noise_pred_neg , noise_pred_empty ) = \
73- comfy .samplers .calc_cond_batch (self .inner_model , [positive_cond , negative_cond , empty_cond ], x , timestep , model_options )
72+ conds = [positive_cond , negative_cond , empty_cond ]
73+
74+ out = comfy .samplers .calc_cond_batch (self .inner_model , conds , x , timestep , model_options )
75+
76+ # Apply pre_cfg_functions since sampling_function() is skipped
77+ for fn in model_options .get ("sampler_pre_cfg_function" , []):
78+ args = {"conds" :conds , "conds_out" : out , "cond_scale" : self .cfg , "timestep" : timestep ,
79+ "input" : x , "sigma" : timestep , "model" : self .inner_model , "model_options" : model_options }
80+ out = fn (args )
81+
82+ noise_pred_pos , noise_pred_neg , noise_pred_empty = out
7483 cfg_result = perp_neg (x , noise_pred_pos , noise_pred_neg , noise_pred_empty , self .neg_scale , self .cfg )
7584
7685 # normally this would be done in cfg_function, but we skipped
@@ -82,6 +91,7 @@ def predict_noise(self, x, timestep, model_options={}, seed=None):
8291 "denoised" : cfg_result ,
8392 "cond" : positive_cond ,
8493 "uncond" : negative_cond ,
94+ "cond_scale" : self .cfg ,
8595 "model" : self .inner_model ,
8696 "uncond_denoised" : noise_pred_neg ,
8797 "cond_denoised" : noise_pred_pos ,
0 commit comments