
Commit 4293e4d

Add WAN ATI support (Comfy-Org#8874)
* Add WAN ATI support
* Fixes
* Fix length
* Remove extra functions
* Fix
* Fix
* Ruff fix
* Remove torch.no_grad
* Add batch trajectory logic
* Scale inputs before and after motion patch
* Batch image/trajectory
* Ruff fix
* Clean up
1 parent 69cb57b commit 4293e4d

File tree

  comfy/utils.py
  comfy_extras/nodes_wan.py

2 files changed: +324 −1 lines changed

comfy/utils.py

Lines changed: 20 additions & 0 deletions
@@ -698,6 +698,26 @@ def resize_to_batch_size(tensor, batch_size):
 
     return output
 
+def resize_list_to_batch_size(l, batch_size):
+    in_batch_size = len(l)
+    if in_batch_size == batch_size or in_batch_size == 0:
+        return l
+
+    if batch_size <= 1:
+        return l[:batch_size]
+
+    output = []
+    if batch_size < in_batch_size:
+        scale = (in_batch_size - 1) / (batch_size - 1)
+        for i in range(batch_size):
+            output.append(l[min(round(i * scale), in_batch_size - 1)])
+    else:
+        scale = in_batch_size / batch_size
+        for i in range(batch_size):
+            output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
+
+    return output
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
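As an illustration (not part of the commit), here is a minimal sketch of how the new helper resamples a list, assuming a ComfyUI environment where comfy.utils is importable; the expected outputs follow the rounding logic above.

    import comfy.utils

    tracks = ["traj_a", "traj_b", "traj_c"]
    # Upsampling 3 -> 5 repeats nearest entries (indices 0, 0, 1, 2, 2)
    print(comfy.utils.resize_list_to_batch_size(tracks, 5))
    # ['traj_a', 'traj_a', 'traj_b', 'traj_c', 'traj_c']
    # Downsampling 3 -> 2 keeps the endpoints (indices 0, 2)
    print(comfy.utils.resize_list_to_batch_size(tracks, 2))
    # ['traj_a', 'traj_c']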

comfy_extras/nodes_wan.py

Lines changed: 304 additions & 1 deletion
@@ -1,11 +1,14 @@
+import math
 import nodes
 import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
 import comfy.latent_formats
 import comfy.clip_vision
-
+import json
+import numpy as np
+from typing import Tuple
 
 class WanImageToVideo:
     @classmethod
@@ -383,7 +386,307 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, ima
         out_latent["samples"] = latent
         return (positive, cond2, negative, out_latent)
 
+def parse_json_tracks(tracks):
+    """Parse JSON track data into a standardized format"""
+    tracks_data = []
+    try:
+        # If tracks is a string, try to parse it as JSON
+        if isinstance(tracks, str):
+            parsed = json.loads(tracks.replace("'", '"'))
+            tracks_data.extend(parsed)
+        else:
+            # If tracks is a list of strings, parse each one
+            for track_str in tracks:
+                parsed = json.loads(track_str.replace("'", '"'))
+                tracks_data.append(parsed)
+
+        # Check if we have a single track (dict with x,y) or a list of tracks
+        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
+            # Single track detected, wrap it in a list
+            tracks_data = [tracks_data]
+        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
+            # Already a list of tracks, nothing to do
+            pass
+        else:
+            # Unexpected format
+            pass
+
+    except json.JSONDecodeError:
+        tracks_data = []
+    return tracks_data
+
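For illustration only (not part of the diff): the two track formats parse_json_tracks accepts, assuming the function is in scope (e.g. imported from comfy_extras.nodes_wan). Each point is a dict with pixel-space "x"/"y" keys.

    single = '[{"x": 100, "y": 200}, {"x": 110, "y": 205}, {"x": 120, "y": 210}]'
    multi = '[[{"x": 100, "y": 200}, {"x": 110, "y": 205}], [{"x": 300, "y": 40}, {"x": 310, "y": 45}]]'

    print(len(parse_json_tracks(single)))  # 1 -> a single track is wrapped into a list of tracks
    print(len(parse_json_tracks(multi)))   # 2 -> already a list of tracks, returned as-is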
+def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
+    # tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
+    # frame_size: tuple (W, H)
+    tracks = torch.from_numpy(tracks_np).float()
+
+    if tracks.shape[1] == 121:
+        tracks = torch.permute(tracks, (1, 0, 2, 3))
+
+    tracks, visibles = tracks[..., :2], tracks[..., 2:3]
+
+    short_edge = min(*frame_size)
+
+    frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
+    tracks = tracks - frame_center
+
+    tracks = tracks / short_edge * 2
+
+    visibles = visibles * 2 - 1
+
+    trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
+
+    out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
+
+    out_0 = out_[:1]
+
+    out_l = out_[1:]  # 121 => 120 | 1
+    a = 120 // math.gcd(120, num_frames)
+    b = num_frames // math.gcd(120, num_frames)
+    out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a]  # 120 => 120 * b => 120 * b / a == F
+
+    final_result = torch.cat([out_0, out_l], dim=0)
+
+    return final_result
+
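A worked example of the resampling arithmetic above (illustrative, not in the commit): the node calls process_tracks(..., num_frames=length - 1), so with the default length of 81 the 120 trailing samples map onto 80 output rows.

    import math
    num_frames = 80                              # length - 1 for the default length of 81
    a = 120 // math.gcd(120, num_frames)         # 3
    b = num_frames // math.gcd(120, num_frames)  # 2
    # out_l: 120 rows -> repeat_interleave by b -> 240 rows -> [1::a] keeps 80 rows,
    # so out_0 (1 row) + out_l (80 rows) = 81 rows, one per frame of the input video.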
+FIXED_LENGTH = 121
+def pad_pts(tr):
+    """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
+    pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
+    n = pts.shape[0]
+    if n < FIXED_LENGTH:
+        pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
+        pts = np.vstack((pts, pad))
+    else:
+        pts = pts[:FIXED_LENGTH]
+    return pts.reshape(FIXED_LENGTH, 1, 3)
+
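A small sketch of pad_pts (illustrative, not in the commit; assumes pad_pts and numpy are in scope): short tracks are zero-padded to the fixed 121-point length, and the third column doubles as a visibility flag.

    tr = [{"x": 100, "y": 200}, {"x": 110, "y": 205}]
    pts = pad_pts(tr)
    print(pts.shape)   # (121, 1, 3)
    print(pts[0])      # [[100. 200.   1.]]  -> recorded point, marked visible
    print(pts[5])      # [[0. 0. 0.]]        -> zero padding beyond the recorded points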
+def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
+    """Index selection utility function"""
+    assert (
+        len(ind.shape) > dim
+    ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
+
+    target = target.expand(
+        *tuple(
+            [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+            + [
+                -1,
+            ]
+            * (len(target.shape) - dim)
+        )
+    )
+
+    ind_pad = ind
+
+    if len(target.shape) > dim + 1:
+        for _ in range(len(target.shape) - (dim + 1)):
+            ind_pad = ind_pad.unsqueeze(-1)
+        ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
+
+    return torch.gather(target, dim=dim, index=ind_pad)
+
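A minimal sketch of ind_sel (illustrative, not in the commit; assumes ind_sel and torch are in scope): it broadcasts target against the leading dims of ind, then gathers along dim.

    import torch
    target = torch.arange(12.0).view(1, 4, 3)   # (1, 4, 3): 4 rows of 3 features
    ind = torch.tensor([[2, 0]])                # (1, 2): pick rows 2 and 0
    print(ind_sel(target, ind, dim=1))
    # tensor([[[6., 7., 8.],
    #          [0., 1., 2.]]])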
+def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
+    """Merge vertex attributes with weights"""
+    target_dim = len(vert_assign.shape) - 1
+    if len(vert_attr.shape) == 2:
+        assert vert_attr.shape[0] > vert_assign.max()
+        new_shape = [1] * target_dim + list(vert_attr.shape)
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+    else:
+        assert vert_attr.shape[1] > vert_assign.max()
+        new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+
+    final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
+    return final_attr
+
+
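An illustrative sketch of merge_final's 2-D vert_attr branch (not part of the commit; assumes merge_final and torch are in scope): each output position blends the attributes of its assigned vertices by the given weights.

    import torch
    vert_attr = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])   # 3 vertices, 2 channels
    vert_assign = torch.tensor([[0, 1], [2, 0]])               # top-2 vertex ids per position
    weight = torch.tensor([[0.5, 0.5], [1.0, 0.0]])            # matching blend weights
    print(merge_final(vert_attr, weight, vert_assign))
    # tensor([[0.5000, 0.5000],
    #         [1.0000, 1.0000]])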
+def _patch_motion_single(
+    tracks: torch.FloatTensor,  # (B, T, N, 4)
+    vid: torch.FloatTensor,  # (C, T, H, W)
+    temperature: float,
+    vae_divide: tuple,
+    topk: int,
+):
+    """Apply motion patching based on tracks"""
+    _, T, H, W = vid.shape
+    N = tracks.shape[2]
+    _, tracks_xy, visible = torch.split(
+        tracks, [1, 2, 1], dim=-1
+    )  # (B, T, N, 2) | (B, T, N, 1)
+    tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
+    tracks_n = tracks_n.clamp(-1, 1)
+    visible = visible.clamp(0, 1)
+
+    xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
+    yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
+
+    grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
+        tracks_xy.device
+    )
+
+    tracks_pad = tracks_xy[:, 1:]
+    visible_pad = visible[:, 1:]
+
+    visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
+    tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
+        1
+    ) / (visible_align + 1e-5)
+    dist_ = (
+        (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
+    )  # T, H, W, N
+    weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
+        T - 1, 1, 1, N
+    )
+    vert_weight, vert_index = torch.topk(
+        weight, k=min(topk, weight.shape[-1]), dim=-1
+    )
+
+    grid_mode = "bilinear"
+    point_feature = torch.nn.functional.grid_sample(
+        vid.permute(1, 0, 2, 3)[:1],
+        tracks_n[:, :1].type(vid.dtype),
+        mode=grid_mode,
+        padding_mode="zeros",
+        align_corners=False,
+    )
+    point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0)  # N, C=16
+
+    out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2)  # T - 1, H, W, C => C, T - 1, H, W
+    out_weight = vert_weight.sum(-1)  # T - 1, H, W
+
+    # out feature -> already soft weighted
+    mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
+
+    out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1)  # C, T, H, W
+    out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0)  # T, H, W
+
+    return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
+
+
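A shape walk-through with synthetic inputs (illustrative, not in the commit; assumes _patch_motion_single and torch are in scope): the latent video passed by the node is (16, T, H/8, W/8), and the tracks span 4 * (T - 1) + 1 pixel-space frames, i.e. the original video length.

    import torch
    C, T, H, W, N = 16, 21, 60, 104, 5               # 21 latent frames for an 81-frame 480x832 video
    vid = torch.zeros(C, T, H, W)
    tracks = torch.zeros(1, 4 * (T - 1) + 1, N, 4)   # (trange, x, y, visibility) per point per frame
    mask, feat = _patch_motion_single(tracks, vid, temperature=220.0, vae_divide=(4, 16), topk=2)
    print(mask.shape, feat.shape)
    # torch.Size([4, 21, 60, 104]) torch.Size([16, 21, 60, 104])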
+def patch_motion(
+    tracks: torch.FloatTensor,  # (B, TB, T, N, 4)
+    vid: torch.FloatTensor,  # (C, T, H, W)
+    temperature: float = 220.0,
+    vae_divide: tuple = (4, 16),
+    topk: int = 2,
+):
+    B = len(tracks)
+
+    # Process each batch separately
+    out_masks = []
+    out_features = []
+
+    for b in range(B):
+        mask, feature = _patch_motion_single(
+            tracks[b],  # (T, N, 4)
+            vid[b],  # (C, T, H, W)
+            temperature,
+            vae_divide,
+            topk
+        )
+        out_masks.append(mask)
+        out_features.append(feature)
+
+    # Stack results: (B, C, T, H, W)
+    out_mask_full = torch.stack(out_masks, dim=0)
+    out_feature_full = torch.stack(out_features, dim=0)
+
+    return out_mask_full, out_feature_full
+
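Continuing the synthetic example above (illustrative, not in the commit): the batched wrapper takes a list of per-sample track tensors plus a batched latent and stacks the per-sample results.

    batched_tracks = [torch.zeros(1, 4 * (T - 1) + 1, N, 4) for _ in range(2)]
    batched_vid = torch.zeros(2, C, T, H, W)
    masks, feats = patch_motion(batched_tracks, batched_vid, temperature=220.0, topk=2, vae_divide=(4, 16))
    print(masks.shape, feats.shape)
    # torch.Size([2, 4, 21, 60, 104]) torch.Size([2, 16, 21, 60, 104])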
+class WanTrackToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "positive": ("CONDITIONING", ),
+                    "negative": ("CONDITIONING", ),
+                    "vae": ("VAE", ),
+                    "tracks": ("STRING", {"multiline": True, "default": "[]"}),
+                    "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                    "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                    "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
+                    "topk": ("INT", {"default": 2, "min": 1, "max": 10}),
+                    "start_image": ("IMAGE", ),
+                },
+                "optional": {
+                    "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
+               temperature, topk, start_image=None, clip_vision_output=None):
+
+        tracks_data = parse_json_tracks(tracks)
+
+        if not tracks_data:
+            return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
+
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+                             device=comfy.model_management.intermediate_device())
+
+        if isinstance(tracks_data[0][0], dict):
+            tracks_data = [tracks_data]
+
+        processed_tracks = []
+        for batch in tracks_data:
+            arrs = []
+            for track in batch:
+                pts = pad_pts(track)
+                arrs.append(pts)
+
+            tracks_np = np.stack(arrs, axis=0)
+            processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            for i in range(start_image.shape[0]):
+                videos[i, 0] = start_image[i]
+
+            latent_videos = []
+            videos = comfy.utils.resize_to_batch_size(videos, batch_size)
+            for i in range(batch_size):
+                latent_videos += [vae.encode(videos[i, :, :, :, :3])]
+            y = torch.cat(latent_videos, dim=0)
+
+            # Scale latent since patch_motion is non-linear
+            y = comfy.latent_formats.Wan21().process_in(y)
+
+            processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
+            res = patch_motion(
+                processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
+            )
+
+            mask, concat_latent_image = res
+            concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
+            mask = -mask + 1.0  # Invert mask to match expected format
+            positive = node_helpers.conditioning_set_values(positive,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+            negative = node_helpers.conditioning_set_values(negative,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
+    "WanTrackToVideo": WanTrackToVideo,
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
     "WanFunInpaintToVideo": WanFunInpaintToVideo,
