81 changes: 72 additions & 9 deletions liveavatar/models/wan/causal_s2v_pipeline_tpp.py
@@ -608,7 +608,36 @@ def _initialize_kv_cache(self, batch_size, dtype, device, kv_cache_size=13500):
             })
 
         self.kv_cache1 = kv_cache1  # always store the clean cache
 
+    def _sync_kv_cache(self, src_rank, num_gpus_dit):
+        """
+        Broadcast the KV cache from src_rank to all other DiT GPUs.
+
+        Called after each denoising step to ensure all GPUs have a consistent
+        temporal-attention state for modulo cycling.
+        """
+        if self.kv_cache1 is None:
+            return
+
+        my_rank = dist.get_rank()
+        if my_rank >= num_gpus_dit:
+            return  # the VAE rank does not participate
+
+        for layer_cache in self.kv_cache1:
+            for key in ["k", "v", "cond_k", "cond_v", "cond_end"]:
+                tensor = layer_cache[key]
+
+                if my_rank == src_rank:
+                    # Source: send to every other DiT GPU
+                    for dst in range(num_gpus_dit):
+                        if dst != src_rank:
+                            dist.send(tensor.contiguous(), dst)
+                else:
+                    # Destination: receive from the source
+                    recv_buf = torch.empty_like(tensor)
+                    dist.recv(recv_buf, src_rank)
+                    tensor.copy_(recv_buf)
 
     def _move_kv_cache_to_working_gpu(self, moved_id, gpu_id=0):
        """
        Move the KV cache to the working GPU.
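
Note on the sync strategy: the point-to-point fan-out in `_sync_kv_cache` is correct but serializes the num_gpus_dit - 1 sends on the source rank. Assuming every DiT rank reaches the sync point at the same step, the same exchange could be a single collective per tensor. A minimal sketch, illustrative and not part of this diff; `dit_group` is a hypothetical process group built once at startup with dist.new_group(ranks=list(range(num_gpus_dit))) so the VAE rank never enters the collective:

import torch
import torch.distributed as dist

def sync_kv_cache_broadcast(kv_cache, src_rank, dit_group):
    # One broadcast per cached tensor instead of num_gpus_dit - 1 sends.
    if kv_cache is None:
        return
    for layer_cache in kv_cache:
        for key in ["k", "v", "cond_k", "cond_v", "cond_end"]:
            buf = layer_cache[key].contiguous()  # collectives want contiguous buffers
            dist.broadcast(buf, src_rank, group=dit_group)
            if buf.data_ptr() != layer_cache[key].data_ptr():
                # contiguous() made a copy; write the received data back into the cache
                layer_cache[key].copy_(buf)

A broadcast lets the backend use its own tree or ring schedule instead of the source looping over destinations, which matters as num_gpus_dit grows.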
@@ -1048,13 +1077,31 @@ def dilate_mask_by_ratio(mask_tensor: torch.Tensor, ratio: float = 0.3, thr: flo
             }
 
             for i, t in enumerate(tqdm(timesteps)):
-                if i != dist.get_rank():
+                # Modulo cycling: step i runs on rank (i % num_gpus_dit)
+                my_rank = dist.get_rank()
+                if my_rank >= num_gpus_dit:
+                    # VAE rank skips all denoising steps
                     continue
-                if self.src_gpu is None:
-                    latent_model_input = block_latents  # [16, num_frames_per_block, h, w]
-                else:
-                    latent_model_input = torch.empty_like(block_latents)  # create an empty tensor to receive into
+                if i % num_gpus_dit != my_rank:
+                    continue
 
+                # Dynamic source: the first step of each cycle starts fresh or receives from the last DiT rank
+                is_first_in_cycle = (i % num_gpus_dit == 0)
+                is_very_first_step = (i == 0)
+
+                if is_very_first_step:
+                    # Very first step: use the initial block_latents
+                    latent_model_input = block_latents
+                elif is_first_in_cycle:
+                    # Start of a new cycle: receive from the last DiT rank (rank num_gpus_dit - 1)
+                    latent_model_input = torch.empty_like(block_latents)
+                    dist.recv(latent_model_input, num_gpus_dit - 1)
+                elif self.src_gpu is not None:
+                    # Normal step in the cycle: receive from the previous rank
+                    latent_model_input = torch.empty_like(block_latents)
+                    dist.recv(latent_model_input, self.src_gpu)
+                else:
+                    latent_model_input = block_latents
 
                 timestep = [t] * self.num_frames_per_block
                 timestep = torch.tensor(timestep).to(self.device).unsqueeze(0)
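
For reference, the routing above implies a fixed step-to-rank schedule. A tiny standalone sketch, illustrative only and not part of this diff, that prints which rank computes each denoising step and where the latents flow, assuming 3 DiT ranks plus one VAE rank:

def describe_schedule(num_steps, num_gpus_dit):
    # Mirrors the modulo-cycling routing: step i runs on rank i % num_gpus_dit.
    for i in range(num_steps):
        rank = i % num_gpus_dit
        if i == 0:
            src = "initial block_latents"
        elif rank == 0:
            src = f"recv from rank {num_gpus_dit - 1}"  # new cycle starts at rank 0
        else:
            src = f"recv from rank {rank - 1}"
        if i == num_steps - 1:
            dst = "send to VAE rank"
        elif rank == num_gpus_dit - 1:
            dst = "send to rank 0"  # wrap around for the next cycle
        else:
            dst = f"send to rank {rank + 1}"
        print(f"step {i}: rank {rank} | {src} -> {dst}")

describe_schedule(num_steps=8, num_gpus_dit=3)

So with 8 steps and 3 DiT ranks, step 0 starts fresh on rank 0, each step hands its latents to the next rank, rank 2 wraps back to rank 0 between cycles, and step 7 (on rank 1) ships the result to the VAE rank.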
@@ -1069,16 +1116,32 @@ def dilate_mask_by_ratio(mask_tensor: torch.Tensor, ratio: float = 0.3, thr: flo

                 noise_pred = [torch.cat(noise_pred_cond, dim=0)]
 
+                # Update the scheduler step index for modulo cycling
+                sample_scheduler._step_index = i
                 temp_x0 = sample_scheduler.step(
                     noise_pred[0].unsqueeze(0),  # [16, f, h, w]
                     t,
                     latent_model_input.unsqueeze(0),  # [1, 16, f, h, w]
                     return_dict=False,
                     generator=seed_g)[0]
                 block_latents = temp_x0.squeeze(0)  # [16, num_frames_per_block, h, w]
-                if self.tgt_gpu is None:
-                    pass
-                else:
-
+                # Sync the KV cache after each step for modulo cycling
+                self._sync_kv_cache(i % num_gpus_dit, num_gpus_dit)
+
+                # Dynamic target: the last rank in a cycle sends to rank 0, unless this is the final step
+                is_last_in_cycle = (my_rank == num_gpus_dit - 1)
+                is_final_step = (i == len(timesteps) - 1)
+
+                if is_final_step:
+                    # Final step: send to the VAE rank
+                    if self.tgt_gpu is not None:
+                        dist.send(block_latents.contiguous(), self.tgt_gpu)
+                elif is_last_in_cycle:
+                    # End of cycle but not final: send to rank 0 for the next cycle
+                    dist.send(block_latents.contiguous(), 0)
+                elif self.tgt_gpu is not None and self.tgt_gpu < num_gpus_dit:
+                    # Normal step: send to the next rank (but not to the VAE rank)
+                    dist.send(block_latents.contiguous(), self.tgt_gpu)
 
             if enable_vae_parallel and dist.get_rank() == num_gpus_dit - 1 + int(enable_vae_parallel):
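
The rank guard on the last context line picks out the decode rank. Assuming DiT occupies ranks 0..num_gpus_dit - 1 and, when enable_vae_parallel is set, one extra rank runs the VAE, the expression resolves as in this small illustrative check (not part of the diff):

def decode_rank(num_gpus_dit: int, enable_vae_parallel: bool) -> int:
    # With a dedicated VAE worker the decoder sits just past the DiT ranks;
    # otherwise the last DiT rank (num_gpus_dit - 1) also decodes.
    return num_gpus_dit - 1 + int(enable_vae_parallel)

assert decode_rank(3, True) == 3   # dedicated VAE rank
assert decode_rank(3, False) == 2  # last DiT rank doubles as decoder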