Skip to content

Commit 6125b80

Browse files
Add LLM sampling options and make reference audio work on ACE-Step 1.5 (Comfy-Org#12295)
1 parent c8fcbd6 commit 6125b80

File tree

4 files changed: 45 additions (+45) and 20 deletions (-20).

comfy/ldm/ace/ace_step15.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,8 +1035,7 @@ def prepare_condition(
10351035
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
10361036
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
10371037
else:
1038-
assert False
1039-
# TODO ?
1038+
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
10401039

10411040
lm_hints = self.detokenizer(lm_hints_5Hz)
10421041

comfy/model_base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,7 @@ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
15481548
def extra_conds(self, **kwargs):
15491549
out = super().extra_conds(**kwargs)
15501550
device = kwargs["device"]
1551+
noise = kwargs["noise"]
15511552

15521553
cross_attn = kwargs.get("cross_attn", None)
15531554
if cross_attn is not None:
@@ -1571,15 +1572,19 @@ def extra_conds(self, **kwargs):
15711572
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
15721573
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
15731574
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
1574-
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
1575+
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
1576+
pass_audio_codes = True
15751577
else:
1576-
refer_audio = refer_audio[-1]
1577-
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
1578+
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
1579+
pass_audio_codes = False
15781580

1579-
audio_codes = kwargs.get("audio_codes", None)
1580-
if audio_codes is not None:
1581-
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
1581+
if pass_audio_codes:
1582+
audio_codes = kwargs.get("audio_codes", None)
1583+
if audio_codes is not None:
1584+
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
1585+
refer_audio = refer_audio[:, :, :750]
15821586

1587+
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
15831588
return out
15841589

15851590
class Omnigen2(BaseModel):

comfy/text_encoders/ace15.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def sample_manual_loop_no_classes(
101101
return output_audio_codes
102102

103103

104-
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
105-
cfg_scale = 2.0
106-
104+
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
107105
positive = [[token for token, _ in inner_list] for inner_list in positive]
108106
negative = [[token for token, _ in inner_list] for inner_list in negative]
109107
positive = positive[0]
@@ -120,7 +118,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
120118
positive = [model.special_tokens["pad"]] * pos_pad + positive
121119

122120
paddings = [pos_pad, neg_pad]
123-
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
121+
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
124122

125123

126124
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -137,6 +135,12 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
137135
language = kwargs.get("language", "en")
138136
seed = kwargs.get("seed", 0)
139137

138+
generate_audio_codes = kwargs.get("generate_audio_codes", True)
139+
cfg_scale = kwargs.get("cfg_scale", 2.0)
140+
temperature = kwargs.get("temperature", 0.85)
141+
top_p = kwargs.get("top_p", 0.9)
142+
top_k = kwargs.get("top_k", 0.0)
143+
140144
duration = math.ceil(duration)
141145
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
142146
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
@@ -147,7 +151,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
147151

148152
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
149153
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
150-
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
154+
out["lm_metadata"] = {"min_tokens": duration * 5,
155+
"seed": seed,
156+
"generate_audio_codes": generate_audio_codes,
157+
"cfg_scale": cfg_scale,
158+
"temperature": temperature,
159+
"top_p": top_p,
160+
"top_k": top_k,
161+
}
151162
return out
152163

153164

@@ -203,10 +214,14 @@ def encode_token_weights(self, token_weight_pairs):
203214
self.qwen3_06b.set_clip_options({"layer": [0]})
204215
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
205216

217+
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
218+
206219
lm_metadata = token_weight_pairs["lm_metadata"]
207-
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
220+
if lm_metadata["generate_audio_codes"]:
221+
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
222+
out["audio_codes"] = [audio_codes]
208223

209-
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
224+
return base_out, None, out
210225

211226
def set_clip_options(self, options):
212227
self.qwen3_06b.set_clip_options(options)

comfy_extras/nodes_ace.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,18 @@ def define_schema(cls):
4444
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
4545
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
4646
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
47+
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
48+
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
49+
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
50+
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
51+
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
4752
],
4853
outputs=[io.Conditioning.Output()],
4954
)
5055

5156
@classmethod
52-
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
53-
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
57+
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
58+
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
5459
conditioning = clip.encode_from_tokens_scheduled(tokens)
5560
return io.NodeOutput(conditioning)
5661

@@ -100,14 +105,15 @@ def execute(cls, seconds, batch_size) -> io.NodeOutput:
100105
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
101106
return io.NodeOutput({"samples": latent, "type": "audio"})
102107

103-
class ReferenceTimbreAudio(io.ComfyNode):
108+
class ReferenceAudio(io.ComfyNode):
104109
@classmethod
105110
def define_schema(cls):
106111
return io.Schema(
107112
node_id="ReferenceTimbreAudio",
113+
display_name="Reference Audio",
108114
category="advanced/conditioning/audio",
109115
is_experimental=True,
110-
description="This node sets the reference audio for timbre (for ace step 1.5)",
116+
description="This node sets the reference audio for ace step 1.5",
111117
inputs=[
112118
io.Conditioning.Input("conditioning"),
113119
io.Latent.Input("latent", optional=True),
@@ -131,7 +137,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
131137
EmptyAceStepLatentAudio,
132138
TextEncodeAceStepAudio15,
133139
EmptyAceStep15LatentAudio,
134-
ReferenceTimbreAudio,
140+
ReferenceAudio,
135141
]
136142

137143
async def comfy_entrypoint() -> AceExtension:

0 commit comments

Comments
 (0)