diff --git a/nodes_sampler.py b/nodes_sampler.py index dfe17ade..59c6e490 100644 --- a/nodes_sampler.py +++ b/nodes_sampler.py @@ -1258,7 +1258,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i latent_end = latent_frames mask_end = latent_end + 1 partial_latents = mocha_embeds[:, context_window] # windowed latents - mask_frame = mocha_embeds[:, latent_end:mask_end] # single mask frame + mask_frame = mocha_embeds[:, -mocha_num_refs-1:-mocha_num_refs] # single mask frame ref_frames = mocha_embeds[:, -mocha_num_refs:] # reference frames partial_mocha_embeds = torch.cat([partial_latents, mask_frame, ref_frames], dim=1)