Commit edde0b5

WanSoundImageToVideoExtend node to manually extend s2v video. (Comfy-Org#9606)
1 parent 0063610 commit edde0b5

1 file changed: +97 −48 lines

comfy_extras/nodes_wan.py
@@ -877,6 +877,67 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_
     return batch_audio_eb, min_batch_num


+def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
+    latent_t = ((length - 1) // 4) + 1
+    if audio_encoder_output is not None:
+        feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
+        video_rate = 30
+        fps = 16
+        feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
+        batch_frames = latent_t * 4
+        audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
+        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
+        if len(audio_embed_bucket.shape) == 3:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
+        elif len(audio_embed_bucket.shape) == 4:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
+
+        audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
+        positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
+        negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
+        frame_offset += batch_frames
+
+    if ref_image is not None:
+        ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        ref_latent = vae.encode(ref_image[:, :, :, :3])
+        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
+        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
+
+    if ref_motion is not None:
+        if ref_motion.shape[0] > 73:
+            ref_motion = ref_motion[-73:]
+
+        ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        if ref_motion.shape[0] < 73:
+            r = torch.ones([73, height, width, 3]) * 0.5
+            r[-ref_motion.shape[0]:] = ref_motion
+            ref_motion = r
+
+        ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
+
+    if ref_motion_latent is not None:
+        ref_motion_latent = ref_motion_latent[:, :, -19:]
+        positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
+
+    latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+
+    control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
+    if control_video is not None:
+        control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        control_video = vae.encode(control_video[:, :, :, :3])
+        control_video_out[:, :, :control_video.shape[2]] = control_video
+
+    # TODO: check if zero is better than none if none provided
+    positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
+    negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
+
+    out_latent = {}
+    out_latent["samples"] = latent
+    return positive, negative, out_latent, frame_offset
+
+
 class WanSoundImageToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
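
The new helper's frame bookkeeping is what makes extension possible: length pixel frames map to latent_t = ((length - 1) // 4) + 1 latent frames, each call consumes batch_frames = latent_t * 4 frames of the audio embedding starting at frame_offset, and the offset is advanced so the next call picks up where this one ended. A minimal sketch of that arithmetic (chunk_window is an illustrative helper, not part of the diff):

def chunk_window(length, frame_offset=0):
    # Mirrors the frame arithmetic in wan_sound_to_video.
    latent_t = ((length - 1) // 4) + 1   # e.g. length=77 -> 20 latent frames
    batch_frames = latent_t * 4          # 20 latent frames -> 80 audio frames
    return latent_t, (frame_offset, frame_offset + batch_frames)

print(chunk_window(77))      # (20, (0, 80))   -- first chunk
print(chunk_window(77, 80))  # (20, (80, 160)) -- an extension chunk
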
@@ -906,57 +967,44 @@ def define_schema(cls):

     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
-        latent_t = ((length - 1) // 4) + 1
-        if audio_encoder_output is not None:
-            feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
-            video_rate = 30
-            fps = 16
-            feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
-            audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate)
-            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
-            if len(audio_embed_bucket.shape) == 3:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
-            elif len(audio_embed_bucket.shape) == 4:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
-
-            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
-            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
-
-        if ref_image is not None:
-            ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            ref_latent = vae.encode(ref_image[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
-            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
-
-        if ref_motion is not None:
-            if ref_motion.shape[0] > 73:
-                ref_motion = ref_motion[-73:]
-
-            ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-
-            if ref_motion.shape[0] < 73:
-                r = torch.ones([73, height, width, 3]) * 0.5
-                r[-ref_motion.shape[0]:] = ref_motion
-                ref_motion = r
-
-            ref_motion = vae.encode(ref_motion[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion})
-            negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion})
-
-        latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=ref_motion)
+        return io.NodeOutput(positive, negative, out_latent)

-        control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
-        if control_video is not None:
-            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            control_video = vae.encode(control_video[:, :, :, :3])
-            control_video_out[:, :, :control_video.shape[2]] = control_video

-        # TODO: check if zero is better than none if none provided
-        positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
-        negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
+class WanSoundImageToVideoExtend(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanSoundImageToVideoExtend",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Latent.Input("video_latent"),
+                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
+                io.Image.Input("ref_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+            is_experimental=True,
+        )

-        out_latent = {}
-        out_latent["samples"] = latent
+    @classmethod
+    def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
+        video_latent = video_latent["samples"]
+        width = video_latent.shape[-1] * 8
+        height = video_latent.shape[-2] * 8
+        batch_size = video_latent.shape[0]
+        frame_offset = video_latent.shape[-3] * 4
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
         return io.NodeOutput(positive, negative, out_latent)

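
Note how execute derives everything except length from the incoming latent: the Wan VAE downscales 8x spatially and 4x temporally, so the latent's spatial size times 8 gives the pixel resolution, and its temporal size times 4 gives the number of pixel frames already generated, which becomes the audio frame_offset. A hand-worked sketch, assuming a 640x640, 77-frame first pass:

import torch

# Assumed shape for illustration: batch 1, 16 channels, 20 latent frames,
# 80x80 latent pixels (the result of a 640x640, 77-frame first pass).
video_latent = {"samples": torch.zeros([1, 16, 20, 80, 80])}

samples = video_latent["samples"]
width = samples.shape[-1] * 8         # 640
height = samples.shape[-2] * 8        # 640
batch_size = samples.shape[0]         # 1
frame_offset = samples.shape[-3] * 4  # 80 pixel frames already generated
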

@@ -1019,6 +1067,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
         WanCameraImageToVideo,
         WanPhantomSubjectToVideo,
         WanSoundImageToVideo,
+        WanSoundImageToVideoExtend,
         Wan22ImageToVideoLatent,
     ]

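Putting the two nodes together, a rough sketch of how an extension pass chains onto a first pass through the shared helper; positive, negative, vae, ref_image, and audio are assumed to come from upstream nodes, and sample() is a hypothetical stand-in for a KSampler pass:

# First pass: frame_offset starts at 0, audio frames [0, 80) are consumed.
positive1, negative1, latent1, offset = wan_sound_to_video(
    positive, negative, vae, 640, 640, 77, 1,
    ref_image=ref_image, audio_encoder_output=audio)
clip1 = sample(positive1, negative1, latent1)  # hypothetical sampling step

# Extension pass: the sampled latent supplies both the motion reference
# (its last 19 latent frames) and the audio offset (20 * 4 = 80 frames).
positive2, negative2, latent2, offset = wan_sound_to_video(
    positive, negative, vae, 640, 640, 77, 1,
    frame_offset=clip1["samples"].shape[-3] * 4,
    audio_encoder_output=audio,
    ref_motion_latent=clip1["samples"])
clip2 = sample(positive2, negative2, latent2)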