Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions echoshot/echoshot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import torch
from comfy.model_management import get_autocast_device, get_torch_device

def _rope_real_dtype(device):
return torch.float32 if device.type == "mps" else torch.float64

def _rope_freqs_to(freqs, target):
dtype = target.dtype if target.is_complex() else None
return freqs.to(device=target.device, dtype=dtype)

@torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False)
@torch.compiler.disable()
def rope_apply_z(x, grid_sizes, freqs, inner_t, shift=6):
Expand All @@ -13,7 +20,7 @@ def rope_apply_z(x, grid_sizes, freqs, inner_t, shift=6):

# precompute multipliers
x_i = torch.view_as_complex(
x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
x[i, :seq_len].to(_rope_real_dtype(x.device)).reshape(seq_len, n, -1, 2)
)
start_ind = [sum(inner_t[i][:_]) for _ in range(len(inner_t[i]))]
end_ind = [sum(inner_t[i][:_+1]) for _ in range(len(inner_t[i]))]
Expand All @@ -26,6 +33,7 @@ def rope_apply_z(x, grid_sizes, freqs, inner_t, shift=6):
freqs_i = shot_freqs.view(f, 1, 1, -1).expand(f, h, w, -1).reshape(seq_len, 1, -1)

# apply rotary embedding
freqs_i = _rope_freqs_to(freqs_i, x_i)
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])

Expand All @@ -46,7 +54,7 @@ def rope_apply_c(x, freqs, inner_c, shift=6):

# precompute multipliers
x_i = torch.view_as_complex(
x[i].to(torch.float64).reshape(s, n, -1, 2)
x[i].to(_rope_real_dtype(x.device)).reshape(s, n, -1, 2)
)

freq_select = []
Expand All @@ -58,6 +66,7 @@ def rope_apply_c(x, freqs, inner_c, shift=6):
freqs_i = shot_freqs.view(s, 1, -1)

# apply rotary embedding
freqs_i = _rope_freqs_to(freqs_i, x_i)
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)

# append to collection
Expand All @@ -79,7 +88,7 @@ def rope_apply_echoshot(x, grid_sizes, freqs, inner_t, shift=4):

# precompute multipliers
x_i = torch.view_as_complex(
x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
x[i, :seq_len].to(_rope_real_dtype(x.device)).reshape(seq_len, n, -1, 2)
)
start_ind = [sum(inner_t[i][:_]) for _ in range(len(inner_t[i]))]
end_ind = [sum(inner_t[i][:_+1]) for _ in range(len(inner_t[i]))]
Expand All @@ -96,9 +105,10 @@ def rope_apply_echoshot(x, grid_sizes, freqs, inner_t, shift=4):
], dim=-1).reshape(seq_len, 1, -1)

# apply rotary embedding
freqs_i = _rope_freqs_to(freqs_i, x_i)
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])

# append to collection
output.append(x_i)
return torch.stack(output).float()
return torch.stack(output).float()
27 changes: 21 additions & 6 deletions wanvideo/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@

__all__ = ['WanModel']

def _rope_real_dtype(device):
return torch.float32 if mm.is_device_mps(device) else torch.float64

def _rope_freqs_to(freqs, target):
dtype = target.dtype if target.is_complex() else None
return freqs.to(device=target.device, dtype=dtype)

def apply_rotary_emb_split(hidden_states, freqs_cis, t_dim):
"""Apply rotary embedding only to the spatial (H/W) dimensions, leaving temporal (T) unchanged."""
t_part, hw_part = torch.split(hidden_states, [t_dim, hidden_states.shape[-1] - t_dim], dim=-1)
Expand Down Expand Up @@ -231,7 +238,7 @@ def rope_apply_3d(x, grid_sizes, freqs, reverse_time=False):
seq_len = f * h * w

# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
x_i = torch.view_as_complex(x[i, :seq_len].to(_rope_real_dtype(x.device)).reshape(
seq_len, n, -1, 2))
if reverse_time:
time_freqs = freqs[0][:f].view(f, 1, 1, -1)
Expand All @@ -253,6 +260,7 @@ def rope_apply_3d(x, grid_sizes, freqs, reverse_time=False):
dim=-1).reshape(seq_len, 1, -1)

# apply rotary embedding
freqs_i = _rope_freqs_to(freqs_i, x_i)
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])

Expand All @@ -271,9 +279,9 @@ def rope_apply_1d(x, grid_sizes, freqs):
for i, (l, ) in enumerate(grid_sizes.tolist()):
seq_len = l
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
x_i = torch.view_as_complex(x[i, :seq_len].to(_rope_real_dtype(x.device)).reshape(
seq_len, n, -1, 2)) # [l n d//2]
x_i_rope = x_i[:, :, :c_rope] * freqs[:seq_len, None, :] # [L, N, c_rope]
x_i_rope = x_i[:, :, :c_rope] * _rope_freqs_to(freqs[:seq_len, None, :], x_i) # [L, N, c_rope]
x_i_passthrough = x_i[:, :, c_rope:] # untouched dims
x_i = torch.cat([x_i_rope, x_i_passthrough], dim=2)

Expand Down Expand Up @@ -2429,8 +2437,11 @@ def forward(
# params
device = self.main_device

if freqs is not None and freqs.device != device:
freqs = freqs.to(device)
if freqs is not None:
if mm.is_device_mps(device) and freqs.is_complex():
freqs = freqs.to(device=device, dtype=torch.complex64)
elif freqs.device != device:
freqs = freqs.to(device)

_, F, H, W = x[0].shape
ref_frame_shape = pose_frame_shape = None
Expand Down Expand Up @@ -2726,7 +2737,11 @@ def forward(
inner_c = None
if inner_t is not None:
d = self.dim // self.num_heads
self.cross_freqs = rope_params(100, d).to(device=x.device)
self.cross_freqs = rope_params(100, d)
if mm.is_device_mps(x.device):
self.cross_freqs = self.cross_freqs.to(device=x.device, dtype=torch.complex64)
else:
self.cross_freqs = self.cross_freqs.to(device=x.device)

if s2v_ref_motion is not None:
motion_encoded, freqs_motion = self.frame_packer(s2v_ref_motion, self)
Expand Down