diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index e643b4414629..70d1738890cb 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams): nerf_final_head_type: str # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] - + use_x0: bool class ChromaRadiance(Chroma): """ @@ -159,6 +159,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, self.skip_dit = [] self.lite = False + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -276,6 +279,12 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: params_dict |= overrides return params.__class__(**params_dict) + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + def _forward( self, x: Tensor, @@ -316,4 +325,11 @@ def _forward( transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c547427b43..19e6aa954259 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys diff --git a/comfy/ops.py b/comfy/ops.py index 35237c9f73e2..6f34d50fc051 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,7 +22,6 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm -import contextlib import json def run_every_op(): @@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None - if offload_stream is not None: - wf_context = offload_stream - if hasattr(wf_context, "as_context"): - wf_context = wf_context.as_context(offload_stream) - else: - wf_context = contextlib.nullcontext() - non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0