Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions comfy/ldm/chroma_radiance/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]

use_x0: bool

class ChromaRadiance(Chroma):
"""
Expand Down Expand Up @@ -159,6 +159,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
self.skip_dit = []
self.lite = False

if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))

@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
Expand Down Expand Up @@ -276,6 +279,12 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
params_dict |= overrides
return params.__class__(**params_dict)

def _apply_x0_residual(self, predicted, noisy, timesteps):

# non zero during training to prevent 0 div
eps = 0.0
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)

def _forward(
self,
x: Tensor,
Expand Down Expand Up @@ -316,4 +325,11 @@ def _forward(
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]

out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]

# If x0 variant → v-pred, just return this instead
if hasattr(self, "__x0__"):
out = self._apply_x0_residual(out, img, timestep)
return out

2 changes: 2 additions & 0 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
Expand Down
8 changes: 0 additions & 8 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
import json

def run_every_op():
Expand Down Expand Up @@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None

if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()

non_blocking = comfy.model_management.device_supports_non_blocking(device)

weight_has_function = len(s.weight_function) > 0
Expand Down
Loading