+import math
 import nodes
 import node_helpers
 import torch
 import comfy.model_management
 import comfy.utils
 import comfy.latent_formats
 import comfy.clip_vision
-
+import json
+import numpy as np
+from typing import Tuple
 
 class WanImageToVideo:
     @classmethod
@@ -383,7 +386,307 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, ima |
         out_latent["samples"] = latent
         return (positive, cond2, negative, out_latent)
 
+def parse_json_tracks(tracks):
+    """Parse JSON track data into a standardized format: a list of tracks, each a list of point dicts."""
+    tracks_data = []
+    try:
+        # If tracks is a string, parse the whole thing as JSON
+        if isinstance(tracks, str):
+            parsed = json.loads(tracks.replace("'", '"'))
+            tracks_data.extend(parsed)
+        else:
+            # Otherwise treat tracks as an iterable of JSON strings and parse each one
+            for track_str in tracks:
+                parsed = json.loads(track_str.replace("'", '"'))
+                tracks_data.append(parsed)
+
+        # Distinguish a single track (list of {x, y} dicts) from a list of tracks
+        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
+            # Single track detected, wrap it in a list
+            tracks_data = [tracks_data]
+        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
+            # Already a list of tracks, nothing to do
+            pass
+        else:
+            # Unexpected format, leave whatever was parsed as-is
+            pass
+
+    except json.JSONDecodeError:
+        tracks_data = []
+    return tracks_data
+
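+# For illustration only: one shape of input parse_json_tracks() accepts is a JSON string
+# holding a single track of per-frame pixel coordinates, e.g.
+#   '[{"x": 100, "y": 200}, {"x": 104, "y": 202}, {"x": 108, "y": 205}]'
+# which it wraps into a one-track list; a JSON list of such lists is treated as multiple tracks.
+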
+def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
+    # tracks_np: (num_tracks, 121, 1, 3) or (121, num_tracks, 1, 3), each sample being
+    # (x, y, visibility); the 121 samples align with 24 fps, the model is trained at 16 fps.
+    # frame_size: tuple (W, H) in pixels. quant_multi and **kwargs are accepted but unused here.
+    tracks = torch.from_numpy(tracks_np).float()
+
+    if tracks.shape[1] == 121:
+        tracks = torch.permute(tracks, (1, 0, 2, 3))
+
+    tracks, visibles = tracks[..., :2], tracks[..., 2:3]
+
+    short_edge = min(*frame_size)
+
+    frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
+    tracks = tracks - frame_center
+
+    tracks = tracks / short_edge * 2
+
+    visibles = visibles * 2 - 1
+
+    trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
+
+    out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
+
+    out_0 = out_[:1]
+
+    out_l = out_[1:]  # 121 => 120 | 1
+    a = 120 // math.gcd(120, num_frames)
+    b = num_frames // math.gcd(120, num_frames)
+    out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a]  # 120 => 120 * b => 120 * b / a == F
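+    # Worked example (length=81 -> num_frames=80): gcd(120, 80)=40, so a=3 and b=2; the 120
+    # post-anchor samples are repeated to 240 and every 3rd one (offset 1) is kept,
+    # giving 80 frames, i.e. the 24 fps track is resampled to the 16 fps output rate.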
+
+    final_result = torch.cat([out_0, out_l], dim=0)
+
+    return final_result
+
+FIXED_LENGTH = 121
+def pad_pts(tr):
+    """Convert a list of {x, y} dicts into a (FIXED_LENGTH, 1, 3) array, padding or truncating."""
+    # The third channel is a visibility flag: 1 for real points, 0 for padded rows.
+    pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
+    n = pts.shape[0]
+    if n < FIXED_LENGTH:
+        pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
+        pts = np.vstack((pts, pad))
+    else:
+        pts = pts[:FIXED_LENGTH]
+    return pts.reshape(FIXED_LENGTH, 1, 3)
+
+def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
+    """Gather `target` along `dim` with `ind`, broadcasting `ind` over the trailing dims of `target`."""
+    assert (
+        len(ind.shape) > dim
+    ), "Index must have the target dim, but got dim: %d, ind shape: %s" % (dim, str(ind.shape))
+
+    target = target.expand(
+        *tuple(
+            [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+            + [
+                -1,
+            ]
+            * (len(target.shape) - dim)
+        )
+    )
+
+    ind_pad = ind
+
+    if len(target.shape) > dim + 1:
+        for _ in range(len(target.shape) - (dim + 1)):
+            ind_pad = ind_pad.unsqueeze(-1)
+        ind_pad = ind_pad.expand(*((-1,) * (dim + 1)), *target.shape[dim + 1:])
+
+    return torch.gather(target, dim=dim, index=ind_pad)
+
+def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
+    """Gather the vertex attributes selected by `vert_assign` and blend them with `weight` (soft top-k merge)."""
+    target_dim = len(vert_assign.shape) - 1
+    if len(vert_attr.shape) == 2:
+        assert vert_attr.shape[0] > vert_assign.max()
+        new_shape = [1] * target_dim + list(vert_attr.shape)
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+    else:
+        assert vert_attr.shape[1] > vert_assign.max()
+        new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+
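+    # Weighted sum over the K selected entries: (..., K, C) * (..., K, 1) summed over K.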
+    final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
+    return final_attr
+
+
+def _patch_motion_single(
+    tracks: torch.FloatTensor,  # (B, T, N, 4)
+    vid: torch.FloatTensor,  # (C, T, H, W)
+    temperature: float,
+    vae_divide: tuple,
+    topk: int,
+):
+    """Apply motion patching to a single latent video based on its point tracks."""
+    _, T, H, W = vid.shape
+    N = tracks.shape[2]
+    _, tracks_xy, visible = torch.split(
+        tracks, [1, 2, 1], dim=-1
+    )  # time (B, T, N, 1) | xy (B, T, N, 2) | visibility (B, T, N, 1)
+    tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
+    tracks_n = tracks_n.clamp(-1, 1)
+    visible = visible.clamp(0, 1)
+
+    xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
+    yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
+
+    grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
+        tracks_xy.device
+    )
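+    # grid holds, for every latent pixel, its (x, y) coordinate in the same
+    # short-edge-normalized space the incoming tracks use.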
+
+    tracks_pad = tracks_xy[:, 1:]
+    visible_pad = visible[:, 1:]
+
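+    # The tracks cover 4 * (T - 1) pixel-rate frames after the first one; average each group
+    # of 4 (weighted by visibility) so the motion lines up with the 4x temporally compressed
+    # latent frames.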
+    visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
+    tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
+        1
+    ) / (visible_align + 1e-5)
+    dist_ = (
+        (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
+    )  # (T - 1, H, W, N)
+    weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
+        T - 1, 1, 1, N
+    )
+    vert_weight, vert_index = torch.topk(
+        weight, k=min(topk, weight.shape[-1]), dim=-1
+    )
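+    # Keep only the K highest-weighted (i.e. nearest visible) tracks per spatial location.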
+
+    grid_mode = "bilinear"
+    point_feature = torch.nn.functional.grid_sample(
+        vid.permute(1, 0, 2, 3)[:1],
+        tracks_n[:, :1].type(vid.dtype),
+        mode=grid_mode,
+        padding_mode="zeros",
+        align_corners=False,
+    )
+    point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0)  # N, C=16
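+    # point_feature now holds, for each track, the latent feature sampled from the first
+    # frame at that track's starting position; it is splatted along the track below.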
+
+    out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2)  # T - 1, H, W, C => C, T - 1, H, W
+    out_weight = vert_weight.sum(-1)  # T - 1, H, W
+
+    # out_feature is already soft-weighted; blend it with the original latent frames
+    mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
+
+    out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1)  # C, T, H, W
+    out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0)  # T, H, W
+
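+    # Replicate the single-channel weight mask across vae_divide[0] mask channels and return it
+    # together with the feature video that now carries the splatted track features.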
+    return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
+
+
+def patch_motion(
+    tracks: torch.FloatTensor,  # list of B tensors, each (1, T, N, 4)
+    vid: torch.FloatTensor,  # (B, C, T, H, W)
+    temperature: float = 220.0,
+    vae_divide: tuple = (4, 16),
+    topk: int = 2,
+):
+    B = len(tracks)
+
+    # Process each batch element separately
+    out_masks = []
+    out_features = []
+
+    for b in range(B):
+        mask, feature = _patch_motion_single(
+            tracks[b],  # (1, T, N, 4)
+            vid[b],  # (C, T, H, W)
+            temperature,
+            vae_divide,
+            topk
+        )
+        out_masks.append(mask)
+        out_features.append(feature)
+
+    # Stack results: masks (B, vae_divide[0], T, H, W), features (B, C, T, H, W)
+    out_mask_full = torch.stack(out_masks, dim=0)
+    out_feature_full = torch.stack(out_features, dim=0)
+
+    return out_mask_full, out_feature_full
+
+class WanTrackToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "positive": ("CONDITIONING", ),
+                    "negative": ("CONDITIONING", ),
+                    "vae": ("VAE", ),
+                    "tracks": ("STRING", {"multiline": True, "default": "[]"}),
+                    "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                    "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                    "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
+                    "topk": ("INT", {"default": 2, "min": 1, "max": 10}),
+                    "start_image": ("IMAGE", ),
+                },
+                "optional": {
+                    "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
+               temperature, topk, start_image=None, clip_vision_output=None):
+
+        tracks_data = parse_json_tracks(tracks)
+
+        # Without tracks this node degenerates to plain image-to-video conditioning
+        if not tracks_data:
+            return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
+
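+        # The Wan VAE downscales 8x spatially and 4x temporally, which gives the latent
+        # shape below: 16 channels, (length - 1) // 4 + 1 frames at height // 8 x width // 8.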
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+                             device=comfy.model_management.intermediate_device())
+
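+        # A single flat list of tracks means one batch; wrap it so every entry of tracks_data
+        # below is a batch of tracks.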
+        if isinstance(tracks_data[0][0], dict):
+            tracks_data = [tracks_data]
+
+        processed_tracks = []
+        for batch in tracks_data:
+            arrs = []
+            for track in batch:
+                pts = pad_pts(track)
+                arrs.append(pts)
+
+            tracks_np = np.stack(arrs, axis=0)
+            processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            for i in range(start_image.shape[0]):
+                videos[i, 0] = start_image[i]
+
+            latent_videos = []
+            videos = comfy.utils.resize_to_batch_size(videos, batch_size)
+            for i in range(batch_size):
+                latent_videos += [vae.encode(videos[i, :, :, :, :3])]
+            y = torch.cat(latent_videos, dim=0)
+
+            # Scale the latent into model space, since patch_motion is non-linear
+            y = comfy.latent_formats.Wan21().process_in(y)
+
+            processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
+            res = patch_motion(
+                processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
+            )
+
+            mask, concat_latent_image = res
+            concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
+            mask = -mask + 1.0  # Invert the mask to match the expected concat_mask convention
+            positive = node_helpers.conditioning_set_values(positive,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+            negative = node_helpers.conditioning_set_values(negative,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
+    "WanTrackToVideo": WanTrackToVideo,
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
     "WanFunInpaintToVideo": WanFunInpaintToVideo,