@@ -877,6 +877,67 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_
877877 return batch_audio_eb , min_batch_num
878878
879879
def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
    """Build Wan sound-to-video conditioning and an empty output latent.

    Attaches audio embeddings, reference-image latents, reference-motion
    latents and a control-video latent to the positive/negative conditioning,
    then allocates a zero latent of the requested size.

    Returns:
        (positive, negative, out_latent, frame_offset) where ``frame_offset``
        has been advanced past the audio frames consumed by this call.
    """
    # latent temporal length: 4x temporal compression, first frame kept
    latent_t = ((length - 1) // 4) + 1

    if audio_encoder_output is not None:
        video_rate = 30
        fps = 16
        batch_frames = latent_t * 4
        feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
        feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
        audio_embed_bucket, _num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
        # move the frame axis last; rank depends on the audio encoder output
        if audio_embed_bucket.ndim == 3:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
        elif audio_embed_bucket.ndim == 4:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)

        # keep only this call's window of audio frames
        audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
        positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
        # negative conditioning gets a silenced (zeroed) embedding of the same shape
        negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
        frame_offset += batch_frames

    if ref_image is not None:
        ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        ref_latent = vae.encode(ref_image[:, :, :, :3])
        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)

    if ref_motion is not None:
        # the model expects exactly 73 motion frames: crop the newest 73,
        # or pad the front with mid-gray when fewer are provided
        if ref_motion.shape[0] > 73:
            ref_motion = ref_motion[-73:]

        ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

        if ref_motion.shape[0] < 73:
            padded = torch.ones([73, height, width, 3]) * 0.5
            padded[-ref_motion.shape[0]:] = ref_motion
            ref_motion = padded

        ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])

    if ref_motion_latent is not None:
        ref_motion_latent = ref_motion_latent[:, :, -19:]
        positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
        negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})

    # empty output latent: 16 channels, 8x spatial compression
    latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())

    # control video defaults to the latent-space encoding of all-zero frames
    control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
    if control_video is not None:
        control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        control_video = vae.encode(control_video[:, :, :, :3])
        control_video_out[:, :, :control_video.shape[2]] = control_video

    # TODO: check if zero is better than none if none provided
    positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
    negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})

    return positive, negative, {"samples": latent}, frame_offset
940+
880941class WanSoundImageToVideo (io .ComfyNode ):
881942 @classmethod
882943 def define_schema (cls ):
@@ -906,57 +967,44 @@ def define_schema(cls):
906967
907968 @classmethod
908969 def execute (cls , positive , negative , vae , width , height , length , batch_size , ref_image = None , audio_encoder_output = None , control_video = None , ref_motion = None ) -> io .NodeOutput :
909- latent_t = ((length - 1 ) // 4 ) + 1
910- if audio_encoder_output is not None :
911- feat = torch .cat (audio_encoder_output ["encoded_audio_all_layers" ])
912- video_rate = 30
913- fps = 16
914- feat = linear_interpolation (feat , input_fps = 50 , output_fps = video_rate )
915- audio_embed_bucket , num_repeat = get_audio_embed_bucket_fps (feat , fps = fps , batch_frames = latent_t * 4 , m = 0 , video_rate = video_rate )
916- audio_embed_bucket = audio_embed_bucket .unsqueeze (0 )
917- if len (audio_embed_bucket .shape ) == 3 :
918- audio_embed_bucket = audio_embed_bucket .permute (0 , 2 , 1 )
919- elif len (audio_embed_bucket .shape ) == 4 :
920- audio_embed_bucket = audio_embed_bucket .permute (0 , 2 , 3 , 1 )
921-
922- positive = node_helpers .conditioning_set_values (positive , {"audio_embed" : audio_embed_bucket })
923- negative = node_helpers .conditioning_set_values (negative , {"audio_embed" : audio_embed_bucket * 0.0 })
924-
925- if ref_image is not None :
926- ref_image = comfy .utils .common_upscale (ref_image [:1 ].movedim (- 1 , 1 ), width , height , "bilinear" , "center" ).movedim (1 , - 1 )
927- ref_latent = vae .encode (ref_image [:, :, :, :3 ])
928- positive = node_helpers .conditioning_set_values (positive , {"reference_latents" : [ref_latent ]}, append = True )
929- negative = node_helpers .conditioning_set_values (negative , {"reference_latents" : [ref_latent ]}, append = True )
930-
931- if ref_motion is not None :
932- if ref_motion .shape [0 ] > 73 :
933- ref_motion = ref_motion [- 73 :]
934-
935- ref_motion = comfy .utils .common_upscale (ref_motion .movedim (- 1 , 1 ), width , height , "bilinear" , "center" ).movedim (1 , - 1 )
936-
937- if ref_motion .shape [0 ] < 73 :
938- r = torch .ones ([73 , height , width , 3 ]) * 0.5
939- r [- ref_motion .shape [0 ]:] = ref_motion
940- ref_motion = r
941-
942- ref_motion = vae .encode (ref_motion [:, :, :, :3 ])
943- positive = node_helpers .conditioning_set_values (positive , {"reference_motion" : ref_motion })
944- negative = node_helpers .conditioning_set_values (negative , {"reference_motion" : ref_motion })
945-
946- latent = torch .zeros ([batch_size , 16 , latent_t , height // 8 , width // 8 ], device = comfy .model_management .intermediate_device ())
970+ positive , negative , out_latent , frame_offset = wan_sound_to_video (positive , negative , vae , width , height , length , batch_size , ref_image = ref_image , audio_encoder_output = audio_encoder_output ,
971+ control_video = control_video , ref_motion = ref_motion )
972+ return io .NodeOutput (positive , negative , out_latent )
947973
948- control_video_out = comfy .latent_formats .Wan21 ().process_out (torch .zeros_like (latent ))
949- if control_video is not None :
950- control_video = comfy .utils .common_upscale (control_video [:length ].movedim (- 1 , 1 ), width , height , "bilinear" , "center" ).movedim (1 , - 1 )
951- control_video = vae .encode (control_video [:, :, :, :3 ])
952- control_video_out [:, :, :control_video .shape [2 ]] = control_video
953974
954- # TODO: check if zero is better than none if none provided
955- positive = node_helpers .conditioning_set_values (positive , {"control_video" : control_video_out })
956- negative = node_helpers .conditioning_set_values (negative , {"control_video" : control_video_out })
class WanSoundImageToVideoExtend(io.ComfyNode):
    """Extend an existing Wan sound-to-video latent with more audio-driven frames.

    The previous generation's latent supplies the geometry (width/height/batch),
    the audio frame offset, and the reference motion for continuation.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanSoundImageToVideoExtend",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Latent.Input("video_latent"),
                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
                io.Image.Input("ref_image", optional=True),
                io.Image.Input("control_video", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
        samples = video_latent["samples"]
        # recover pixel-space geometry from the latent: 8x spatial scale,
        # 4x temporal scale (matches wan_sound_to_video's latent layout)
        width = samples.shape[-1] * 8
        height = samples.shape[-2] * 8
        batch_size = samples.shape[0]
        frame_offset = samples.shape[-3] * 4
        positive, negative, out_latent, frame_offset = wan_sound_to_video(
            positive, negative, vae, width, height, length, batch_size,
            frame_offset=frame_offset,
            ref_image=ref_image,
            audio_encoder_output=audio_encoder_output,
            control_video=control_video,
            ref_motion=None,
            # the prior clip's latent doubles as the reference motion source
            ref_motion_latent=samples,
        )
        return io.NodeOutput(positive, negative, out_latent)
9611009
9621010
@@ -1019,6 +1067,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
10191067 WanCameraImageToVideo ,
10201068 WanPhantomSubjectToVideo ,
10211069 WanSoundImageToVideo ,
1070+ WanSoundImageToVideoExtend ,
10221071 Wan22ImageToVideoLatent ,
10231072 ]
10241073
0 commit comments