Skip to content

Commit 16b9aab

Browse files
authored
Support Multi/InfiniteTalk (Comfy-Org#10179)
* re-init
* Update model_multitalk.py
* whitespace...
* Update model_multitalk.py
* remove print
* this is redundant
* remove import
* Restore preview functionality
* Move block_idx to transformer_options
* Remove LoopingSamplerCustomAdvanced
* Remove looping functionality, keep extension functionality
* Update model_multitalk.py
* Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn
* Chunk attention map calculation for multiple speakers to reduce peak VRAM usage
* Update model_multitalk.py
* Add ModelPatch type back
* Fix for latest upstream
* Use DynamicCombo for cleaner node. Basically just so that single_speaker mode hides mask inputs and 2nd audio input
* Update nodes_wan.py
1 parent 245f613 commit 16b9aab

File tree

5 files changed

+727
-3
lines changed

5 files changed

+727
-3
lines changed

comfy/ldm/wan/model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def forward(self, x, freqs, transformer_options={}):
6262
x(Tensor): Shape [B, L, num_heads, C / num_heads]
6363
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
6464
"""
65+
patches = transformer_options.get("patches", {})
66+
6567
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
6668

6769
def qkv_fn_q(x):
@@ -86,6 +88,10 @@ def qkv_fn_k(x):
8688
transformer_options=transformer_options,
8789
)
8890

91+
if "attn1_patch" in patches:
92+
for p in patches["attn1_patch"]:
93+
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
94+
8995
x = self.o(x)
9096
return x
9197

@@ -225,6 +231,8 @@ def forward(
225231
"""
226232
# assert e.dtype == torch.float32
227233

234+
patches = transformer_options.get("patches", {})
235+
228236
if e.ndim < 4:
229237
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
230238
else:
@@ -242,6 +250,11 @@ def forward(
242250

243251
# cross-attention & ffn
244252
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
253+
254+
if "attn2_patch" in patches:
255+
for p in patches["attn2_patch"]:
256+
x = p({"x": x, "transformer_options": transformer_options})
257+
245258
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
246259
x = torch.addcmul(x, y, repeat_e(e[5], x))
247260
return x
@@ -488,7 +501,7 @@ def __init__(self,
488501
self.blocks = nn.ModuleList([
489502
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
490503
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
491-
for _ in range(num_layers)
504+
for i in range(num_layers)
492505
])
493506

494507
# head
@@ -541,6 +554,7 @@ def forward_orig(
541554
# embeddings
542555
x = self.patch_embedding(x.float()).to(x.dtype)
543556
grid_sizes = x.shape[2:]
557+
transformer_options["grid_sizes"] = grid_sizes
544558
x = x.flatten(2).transpose(1, 2)
545559

546560
# time embeddings
@@ -738,6 +752,7 @@ def forward_orig(
738752
# embeddings
739753
x = self.patch_embedding(x.float()).to(x.dtype)
740754
grid_sizes = x.shape[2:]
755+
transformer_options["grid_sizes"] = grid_sizes
741756
x = x.flatten(2).transpose(1, 2)
742757

743758
# time embeddings

0 commit comments

Comments (0)