diff --git a/nodes_sampler.py b/nodes_sampler.py index 4dc208fe..e10d8332 100644 --- a/nodes_sampler.py +++ b/nodes_sampler.py @@ -1176,7 +1176,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype) if autocast_enabled else nullcontext(): if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init: - return z*0, None + return z*0, None, cache_state nonlocal patcher current_step_percentage = idx / len(timesteps)