diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 69338336d596..1d7dc59a8486 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -1110,7 +1110,7 @@ def prepare_condition( return encoder_hidden, encoder_mask, context_latents - def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs): + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs): text_attention_mask = None lyric_attention_mask = None refer_audio_order_mask = None @@ -1140,6 +1140,9 @@ def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audi src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes ) + if replace_with_null_embeds: + enc_hidden[:] = self.null_condition_emb.to(enc_hidden) + out = self.decoder(hidden_states=x, timestep=timestep, timestep_r=timestep, diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index c270e6333e1b..2268bff383a4 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -335,7 +335,7 @@ def __init__( device=None, dtype=None, operations=None ): super().__init__() - self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = operations.Linear( hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype ) @@ -463,6 +463,8 @@ def forward( extra_per_block_pos_emb: Optional[torch.Tensor] = None, transformer_options: Optional[dict] = {}, ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,7 +514,7 @@ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): result_B_T_H_W_D = rearrange( self.self_attn( # normalized_x_B_T_HW_D, - rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -522,7 +524,7 @@ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -536,7 +538,7 @@ def _x_fn( ) _result_B_T_H_W_D = rearrange( self.cross_attn( - rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -555,7 +557,7 @@ def _x_fn( shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -563,8 +565,8 @@ def _x_fn( scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D, ) - result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) return x_B_T_H_W_D @@ -876,6 +878,14 @@ def _forward( "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "transformer_options": kwargs.get("transformer_options", {}), } + + # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + for block in self.blocks: x_B_T_H_W_D = block( x_B_T_H_W_D, @@ -884,6 +894,6 @@ def _forward( **block_kwargs, ) - x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] return x_B_C_Tt_Hp_Wp diff --git a/comfy/model_base.py b/comfy/model_base.py index 3bb54f59e661..3aa345254b6e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1552,6 +1552,8 @@ def extra_conds(self, **kwargs): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: + if torch.count_nonzero(cross_attn) == 0: + out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) conditioning_lyrics = kwargs.get("conditioning_lyrics", None) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 77264ed28eca..d33db750775d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) @@ -1023,11 +1023,7 @@ class Anima(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] - - def __init__(self, unet_config): - super().__init__(unet_config) - self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): out = model_base.Anima(self, device=device) @@ -1038,6 +1034,12 @@ def clip_target(self, state_dict={}): detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs): + self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95 + if dtype is torch.float16: + self.memory_usage_factor *= 1.4 + return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index b6f58cb259db..d8c5a6f92220 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -23,7 +23,7 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) - out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out