diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 2749c53f516d..55ab550f879a 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -43,6 +43,7 @@ class HunyuanVideoParams:
     meanflow: bool
     use_cond_type_embedding: bool
     vision_in_dim: int
+    meanflow_sum: bool
 
 
 class SelfAttentionRef(nn.Module):
@@ -317,7 +318,7 @@ def forward_orig(
                 timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
                 timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
                 vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
-                vec = (vec + vec_r) / 2
+                vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
 
         if ref_latent is not None:
             ref_latent_ids = self.img_ids(ref_latent)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 19e6aa954259..1f5d34bddae4 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["use_cond_type_embedding"] = False
         if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
             dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+            dit_config["meanflow_sum"] = True
         else:
             dit_config["vision_in_dim"] = None
+            dit_config["meanflow_sum"] = False
         return dit_config
 
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 571d3f760624..cd96541d78eb 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -399,7 +399,10 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_roun
         orig_dtype = tensor.dtype
 
         if isinstance(scale, str) and scale == "recalculate":
-            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
+            scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
+            if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
+                tensor_info = torch.finfo(tensor.dtype)
+                scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
 
         if scale is not None:
             if not isinstance(scale, torch.Tensor):
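
A minimal standalone sketch of the new "recalculate" branch in quantize(), kept outside the patch for illustration. The helper name recalculate_scale and the sample values are assumptions of this sketch, and it needs a PyTorch build that ships the float8 dtypes. It mirrors the two changes above: the max-abs reduction is done in float32, and for reduced-precision inputs the reciprocal of the scale is clamped to the input dtype's finite range, which (as I read the "prevent scale from being too small" comment) keeps the scale representable in that dtype.

# Standalone sketch, not part of the patch; recalculate_scale and the values below are illustrative.
import torch

def recalculate_scale(tensor: torch.Tensor, qdtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Max-abs scale computed in float32 so the reduction/division does not lose precision.
    scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(qdtype).max
    if tensor.dtype not in (torch.float32, torch.bfloat16):  # prevent scale from being too small
        # Clamp 1/scale to the input dtype's finite range so the scale stays
        # representable in that dtype (e.g. float16) instead of underflowing.
        info = torch.finfo(tensor.dtype)
        scale = 1.0 / torch.clamp(1.0 / scale, min=info.min, max=info.max)
    return scale

x = torch.full((8,), 1e-6, dtype=torch.float16)  # tiny weights -> tiny raw scale
s = recalculate_scale(x)
# The clamped scale survives a cast back to float16 (the unclamped one would flush to 0).
print(s.item(), s.to(torch.float16).item())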