comfy/ldm/ace/ace_step15.py (5 changes: 4 additions & 1 deletion)
@@ -1110,7 +1110,7 @@ def prepare_condition(

         return encoder_hidden, encoder_mask, context_latents

-    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
+    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
         text_attention_mask = None
         lyric_attention_mask = None
         refer_audio_order_mask = None
@@ -1140,6 +1140,9 @@ def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audi
             src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
         )

+        if replace_with_null_embeds:
+            enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
+
         out = self.decoder(hidden_states=x,
                            timestep=timestep,
                            timestep_r=timestep,
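Note: the change above threads a replace_with_null_embeds flag through forward() and, when it is set, overwrites the conditioning encoder output in place with the model's learned null-condition embedding. The following is a minimal sketch of that pattern only; TinyConditioner, its dimensions, and the projection layer are made-up placeholders, not the ACE-Step model.

import torch
import torch.nn as nn

class TinyConditioner(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        # One learned "null" embedding, broadcast over batch and sequence.
        self.null_condition_emb = nn.Parameter(torch.zeros(1, 1, dim))
        self.proj = nn.Linear(dim, dim)

    def forward(self, enc_hidden, replace_with_null_embeds=False):
        enc_hidden = self.proj(enc_hidden)
        if replace_with_null_embeds:
            # In-place slice assignment keeps the tensor's shape while the
            # (1, 1, dim) parameter broadcasts across batch and sequence;
            # .to(enc_hidden) matches dtype and device.
            enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
        return enc_hidden

cond = TinyConditioner()
tokens = torch.randn(2, 5, 16)
print(cond(tokens, replace_with_null_embeds=True).shape)  # torch.Size([2, 5, 16])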
comfy/ldm/cosmos/predict2.py (26 changes: 18 additions & 8 deletions)
@@ -335,7 +335,7 @@ def __init__(
         device=None, dtype=None, operations=None
     ):
         super().__init__()
-        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.linear = operations.Linear(
             hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
         )
@@ -463,6 +463,8 @@ def forward(
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
         transformer_options: Optional[dict] = {},
     ) -> torch.Tensor:
+        residual_dtype = x_B_T_H_W_D.dtype
+        compute_dtype = emb_B_T_D.dtype
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

@@ -512,7 +514,7 @@ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
         result_B_T_H_W_D = rearrange(
             self.self_attn(
                 # normalized_x_B_T_HW_D,
-                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                 None,
                 rope_emb=rope_emb_L_1_1_D,
                 transformer_options=transformer_options,
@@ -522,7 +524,7 @@ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
             h=H,
             w=W,
         )
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)

         def _x_fn(
             _x_B_T_H_W_D: torch.Tensor,
@@ -536,7 +538,7 @@ def _x_fn(
             )
             _result_B_T_H_W_D = rearrange(
                 self.cross_attn(
-                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                     crossattn_emb,
                     rope_emb=rope_emb_L_1_1_D,
                     transformer_options=transformer_options,
@@ -555,16 +557,16 @@
             shift_cross_attn_B_T_1_1_D,
             transformer_options=transformer_options,
         )
-        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
+        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D

         normalized_x_B_T_H_W_D = _fn(
             x_B_T_H_W_D,
             self.layer_norm_mlp,
             scale_mlp_B_T_1_1_D,
             shift_mlp_B_T_1_1_D,
         )
-        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
+        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
         return x_B_T_H_W_D


@@ -876,6 +878,14 @@ def _forward(
             "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             "transformer_options": kwargs.get("transformer_options", {}),
         }
+
+        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
+        # in fp32, but run attention and MLP modules in fp16.
+        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
+        # quality degradation and visual artifacts.
+        if x_B_T_H_W_D.dtype == torch.float16:
+            x_B_T_H_W_D = x_B_T_H_W_D.float()
+
         for block in self.blocks:
             x_B_T_H_W_D = block(
                 x_B_T_H_W_D,
@@ -884,6 +894,6 @@
                 **block_kwargs,
             )

-        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
         x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
         return x_B_C_Tt_Hp_Wp
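Note: the comment added in the hunk above states the core idea of this file's changes: keep the residual stream in fp32 while running attention and MLP in fp16, casting at module boundaries (and casting back before final_layer). The sketch below illustrates that mixed-precision residual pattern on a generic pre-norm transformer block; MixedPrecisionBlock, its sizes, and the CPU fallback are illustrative assumptions, not the Cosmos Predict2 implementation.

import torch
import torch.nn as nn

class MixedPrecisionBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, compute_dtype=torch.float16):
        # x arrives promoted to fp32; the residual additions stay in fp32 so
        # large residual magnitudes cannot overflow fp16.
        residual_dtype = x.dtype
        h = self.norm1(x).to(compute_dtype)          # attention runs in the compute dtype
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out.to(residual_dtype)          # accumulate in fp32
        h = self.norm2(x).to(compute_dtype)          # MLP runs in the compute dtype
        x = x + self.mlp(h).to(residual_dtype)
        return x

block = MixedPrecisionBlock()
x = torch.randn(1, 8, 64)                            # fp32 residual stream
if torch.cuda.is_available():
    block, x = block.cuda().half(), x.cuda()         # fp16 weights, fp32 residual in
    out = block(x, compute_dtype=torch.float16)
else:
    out = block(x, compute_dtype=torch.float32)      # CPU fallback for the demo
print(out.dtype)                                     # torch.float32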
comfy/model_base.py (2 changes: 2 additions & 0 deletions)
@@ -1552,6 +1552,8 @@ def extra_conds(self, **kwargs):

         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
+            if torch.count_nonzero(cross_attn) == 0:
+                out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

         conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
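Note: the hunk above treats an all-zero cross-attention tensor as "no text conditioning" and asks the model to substitute its learned null embedding instead of attending to zeros. A small standalone sketch of the check, using a plain dict in place of the comfy.conds.CONDConstant / CONDRegular wrappers:

import torch

def build_extra_conds(cross_attn):
    out = {}
    if cross_attn is not None:
        # An all-zero conditioning tensor signals "use the null embedding".
        if torch.count_nonzero(cross_attn) == 0:
            out["replace_with_null_embeds"] = True
        out["c_crossattn"] = cross_attn
    return out

print("replace_with_null_embeds" in build_extra_conds(torch.zeros(1, 4, 8)))   # True
print("replace_with_null_embeds" in build_extra_conds(torch.randn(1, 4, 8)))   # False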
comfy/supported_models.py (14 changes: 8 additions & 6 deletions)
@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):

     memory_usage_factor = 1.0

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

     def __init__(self, unet_config):
         super().__init__(unet_config)
@@ -1023,11 +1023,7 @@ class Anima(supported_models_base.BASE):

     memory_usage_factor = 1.0

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
-
-    def __init__(self, unet_config):
-        super().__init__(unet_config)
-        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Anima(self, device=device)
@@ -1038,6 +1034,12 @@ def clip_target(self, state_dict={}):
         detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))

+    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
+        self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
+        if dtype is torch.float16:
+            self.memory_usage_factor *= 1.4
+        return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
+
 class CosmosI2VPredict2(CosmosT2IPredict2):
     unet_config = {
         "image_model": "cosmos_predict2",
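Note: the hunk above replaces Anima's __init__ override with a set_inference_dtype hook, so the memory_usage_factor can be recomputed per dtype and padded by 1.4x when fp16 inference is selected. A simplified sketch of that hook; Base here is a stand-in for supported_models_base.BASE, not the real class.

import torch

class Base:
    def __init__(self, unet_config):
        self.unet_config = unet_config

    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
        self.dtype = dtype
        self.manual_cast_dtype = manual_cast_dtype

class Anima(Base):
    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
        # Scale the estimate with model width, then pad it for fp16 inference
        # (the 1.4x factor matching the diff above).
        self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
        if dtype is torch.float16:
            self.memory_usage_factor *= 1.4
        return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)

m = Anima({"model_channels": 1024})
m.set_inference_dtype(torch.float16, None)
print(round(m.memory_usage_factor, 3))  # 0.665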
comfy/text_encoders/anima.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
-        out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
+        out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
         out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

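Note: the tokenizer fix above keeps the word id (the third element of each token tuple) when return_word_ids is requested, while still forcing every weight to 1.0; the old comprehension unpacked exactly two elements per token and so could not handle the three-element form. A standalone sketch with made-up token values:

def reset_weights(token_lists, return_word_ids):
    # Tokens are (id, weight) pairs, or (id, weight, word_id) triples when
    # word ids are requested; only the weight is reset.
    return [
        [(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list]
        for inner_list in token_lists
    ]

with_ids = [[(101, 0.8, 0), (2024, 1.2, 1)]]
print(reset_weights(with_ids, return_word_ids=True))      # [[(101, 1.0, 0), (2024, 1.0, 1)]]

without_ids = [[(101, 0.8), (2024, 1.2)]]
print(reset_weights(without_ids, return_word_ids=False))  # [[(101, 1.0), (2024, 1.0)]]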