diff --git a/dia/model.py b/dia/model.py index a3b0f973..8e7a50a3 100644 --- a/dia/model.py +++ b/dia/model.py @@ -433,7 +433,7 @@ def _decoder_step( uncond_logits_BxCxV = logits_last_Bx2xCxV[:, 0, :, :] # Shape [B, C, V] cond_logits_BxCxV = logits_last_Bx2xCxV[:, 1, :, :] # Shape [B, C, V] - logits_BxCxV = cond_logits_BxCxV + cfg_scale * (cond_logits_BxCxV - uncond_logits_BxCxV) + logits_BxCxV = uncond_logits_BxCxV + cfg_scale * (cond_logits_BxCxV - uncond_logits_BxCxV) logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like( logits_BxCxV[:, :, audio_eos_value + 1 :],