Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def qkv_fn_k(x):
)

x = self.o(x)
return x
return x, q, k
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some uncertainty about whether returning this will in general increase the memory peak of WAN within native ComfyUI. Instead, comfy suggests that you add a patch to replace the x = optimized_attention(...) call on line 81 byreusing the ModelPatcher.set_model_attn1_replace functionality (in unet, attn1 is self, attn2 is cross), which can then do the optimized_attention call + the partial attention thing that happens inside the cross_attn patch. To get the q + k for the cross_attn patch, you can store the q and k values in transformer_options instead and then pop them out after usage.

The transformer_index can stay None (not given) since that was something unique to unet models.

It would probably be more optimal to not call optimized_attention anymore and just reuse the logic of hte slower partial attention thingy in this code, but comfy said he would be fine if you didn't go that far and just kept both within that attention replacement function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one I'm a bit unsure of, is there an example of such a patch?



class WanT2VCrossAttention(WanSelfAttention):
Expand Down Expand Up @@ -225,14 +225,16 @@ def forward(
"""
# assert e.dtype == torch.float32

patches = transformer_options.get("patches", {})

if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32

# self-attention
y = self.self_attn(
y, q, k = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)

Expand All @@ -241,6 +243,11 @@ def forward(

# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)

if "cross_attn" in patches:
for p in patches["cross_attn"]:
x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})

y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
Expand Down Expand Up @@ -487,7 +494,7 @@ def __init__(self,
self.blocks = nn.ModuleList([
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
for i in range(num_layers)
])

# head
Expand Down Expand Up @@ -540,6 +547,7 @@ def forward_orig(
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)

# time embeddings
Expand Down Expand Up @@ -568,6 +576,7 @@ def forward_orig(
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
transformer_options["block_idx"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
Expand Down Expand Up @@ -734,6 +743,7 @@ def forward_orig(
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)

# time embeddings
Expand Down Expand Up @@ -763,6 +773,7 @@ def forward_orig(
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
transformer_options["block_idx"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
Expand Down
Loading