diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index e5883df90b55..12c5247011aa 100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -273,6 +273,7 @@ def encode( speaker_embeds: Optional[torch.FloatTensor] = None, lyric_token_idx: Optional[torch.LongTensor] = None, lyric_mask: Optional[torch.LongTensor] = None, + lyrics_strength=1.0, ): bs = encoder_text_hidden_states.shape[0] @@ -291,6 +292,8 @@ def encode( out_dtype=encoder_text_hidden_states.dtype, ) + encoder_lyric_hidden_states *= lyrics_strength + encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1) encoder_hidden_mask = None @@ -310,7 +313,6 @@ def decode( output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, - return_dict: bool = True, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ -353,6 +355,7 @@ def forward( lyric_mask: Optional[torch.LongTensor] = None, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + lyrics_strength=1.0, **kwargs ): hidden_states = x @@ -363,6 +366,7 @@ def forward( speaker_embeds=speaker_embeds, lyric_token_idx=lyric_token_idx, lyric_mask=lyric_mask, + lyrics_strength=lyrics_strength, ) output_length = hidden_states.shape[-1] diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index 3188bc7703a1..af81280eb0dd 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -1,7 +1,12 @@ # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py import torch from .autoencoder_dc import AutoencoderDC -import torchaudio +import logging +try: + import torchaudio 
+except Exception: + logging.warning("torchaudio missing, ACE model will be broken") + import torchvision.transforms as transforms from .music_vocoder import ADaMoSHiFiGANV1 diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py index d73d3f8e8a3b..9c584eb7fa75 100755 --- a/comfy/ldm/ace/vae/music_log_mel.py +++ b/comfy/ldm/ace/vae/music_log_mel.py @@ -2,7 +2,12 @@ import torch import torch.nn as nn from torch import Tensor -from torchaudio.transforms import MelScale +import logging +try: + from torchaudio.transforms import MelScale +except Exception: + logging.warning("torchaudio missing, ACE model will be broken") + import comfy.model_management class LinearSpectrogram(nn.Module): diff --git a/comfy/model_base.py b/comfy/model_base.py index 6408005b664f..6d27930dc139 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1139,4 +1139,5 @@ def extra_conds(self, **kwargs): if cross_attn is not None: out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics) out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype)) + out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ff4c29d7e2aa..28c586389a5e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -222,6 +222,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} dit_config["image_model"] = "ltxv" + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') + shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape + dit_config["attention_head_dim"] = shape[0] // 32 + dit_config["cross_attention_dim"] = shape[1] if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {})) return dit_config diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 36eb999d147f..cbfec15a2198 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -1,6 +1,6 @@ import torch import comfy.model_management - +import node_helpers class TextEncodeAceStepAudio: @classmethod @@ -9,15 +9,18 @@ def INPUT_TYPES(s): "clip": ("CLIP", ), "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}), "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" CATEGORY = "conditioning" - def encode(self, clip, tags, lyrics): + def encode(self, clip, tags, lyrics, lyrics_strength): tokens = clip.tokenize(tags, lyrics=lyrics) - return (clip.encode_from_tokens_scheduled(tokens), ) + conditioning = clip.encode_from_tokens_scheduled(tokens) + conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) + return (conditioning, ) class EmptyAceStepLatentAudio: