2 changes: 1 addition & 1 deletion .github/workflows/stable-release.yml
@@ -117,7 +117,7 @@ jobs:
       ./python.exe get-pip.py
       ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
-      grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
+      grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
       ./python.exe -s -m pip install -r requirements_comfyui.txt
       rm requirements_comfyui.txt
2 changes: 2 additions & 0 deletions .github/workflows/test-ci.yml
@@ -20,6 +20,7 @@ jobs:
   test-stable:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux, windows]
         # os: [macos, linux]
@@ -74,6 +75,7 @@ jobs:
   test-unix-nightly:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux]
         os: [linux]
4 changes: 3 additions & 1 deletion comfy/latent_formats.py
@@ -408,7 +408,9 @@ def __init__(self):
         self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
 
 class LTXAV(LTXV):
-    pass
+    def __init__(self):
+        self.latent_rgb_factors = None
+        self.latent_rgb_factors_bias = None
 
 class HunyuanVideo(LatentFormat):
     latent_channels = 16
2 changes: 1 addition & 1 deletion comfy/ldm/lightricks/embeddings_connector.py
@@ -276,7 +276,7 @@ def forward(
             max(1024, hidden_states.shape[1]) / self.num_learnable_registers
         )
         learnable_registers = torch.tile(
-            self.learnable_registers, (num_registers_duplications, 1)
+            self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
         )
 
         hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
10 changes: 10 additions & 0 deletions comfy/model_management.py
@@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None):
 
     return True
 
+def supports_nvfp4_compute(device=None):
+    if not is_nvidia():
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major < 10:
+        return False
+
+    return True
+
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
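For orientation, a minimal sketch (illustrative, not part of this diff) of what the new supports_nvfp4_compute() check evaluates on a given machine: an NVIDIA GPU with CUDA compute capability major version 10 or newer reports NVFP4 compute support, everything else does not. The claim that major >= 10 corresponds to Blackwell-generation hardware is an assumption.

import torch

def describe_nvfp4_support():
    # No CUDA device at all: nothing to query.
    if not torch.cuda.is_available():
        return "no CUDA device: nvfp4 compute unsupported"
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    supported = props.major >= 10  # same threshold as supports_nvfp4_compute() above
    return f"sm_{props.major}{props.minor}: nvfp4 compute {'supported' if supported else 'unsupported'}"

print(describe_nvfp4_support())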
25 changes: 19 additions & 6 deletions comfy/ops.py
@@ -427,12 +427,12 @@ def fp8_linear(self, input):
         input = torch.clamp(input, min=-448, max=448, out=input)
         input_fp8 = input.to(dtype).contiguous()
         layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
-        quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
+        quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
-        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
         uncast_bias_weight(self, w, bias, offload_stream)
@@ -493,11 +493,12 @@ def forward(self, *args, **kwargs):
         )
 
 
-def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
+        _disabled = disabled
 
         class Linear(torch.nn.Module, CastWeightBiasOp):
             def __init__(
@@ -522,6 +523,7 @@ def __init__(
 
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+                self._full_precision_mm_config = False
 
             def reset_parameters(self):
                 return None
@@ -556,8 +558,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
+                    self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
                     if not self._full_precision_mm:
-                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+                        self._full_precision_mm = self._full_precision_mm_config
+
+                    if self.quant_format in MixedPrecisionOps._disabled:
+                        self._full_precision_mm = True
 
                     if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
@@ -630,7 +636,7 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
                     sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
                 quant_conf = {"format": self.quant_format}
-                if self._full_precision_mm:
+                if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                 return sd
@@ -711,10 +717,17 @@ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving wei
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
 
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
-        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+        disabled = set()
+        if not nvfp4_compute:
+            disabled.add("nvfp4")
+        if not fp8_compute:
+            disabled.add("float8_e4m3fn")
+            disabled.add("float8_e5m2")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
 
     if (
         fp8_compute and
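Taken together, pick_operations() now builds a set of quantization formats that lack a hardware path on the current device and hands it to mixed_precision_ops(); layers stored in one of those formats keep their quantized weights but are routed through the full-precision matmul path. A hedged usage sketch mirroring the call above (the empty quant_config is a placeholder, not a real checkpoint config):

import torch
from comfy.ops import mixed_precision_ops  # patched in this diff

quant_config = {}  # placeholder; normally read from the checkpoint's quantization metadata
disabled = {"nvfp4", "float8_e4m3fn", "float8_e5m2"}  # formats without hardware support here

# Layers whose quant_format is in `disabled` get self._full_precision_mm = True,
# so they fall back to full-precision matrix multiplication at compute_dtype.
ops = mixed_precision_ops(quant_config, compute_dtype=torch.bfloat16, disabled=disabled)
linear_cls = ops.Linear  # the Linear class used when instantiating the model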
7 changes: 7 additions & 0 deletions comfy/quant_ops.py
@@ -13,6 +13,13 @@
         get_layout_class,
     )
     _CK_AVAILABLE = True
+    if torch.version.cuda is None:
+        ck.registry.disable("cuda")
+    else:
+        cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
+        if cuda_version < (13,):
+            ck.registry.disable("cuda")
+
     ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
         logging.info(f"Found comfy_kitchen backend {k}: {v}")
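The gate above keeps the comfy_kitchen cuda backend only when torch was built against CUDA 13 or newer; torch.version.cuda is a string such as "12.8" (or None on builds without CUDA), so the check compares integer tuples. A standalone illustration of the same comparison (not part of this diff):

def cuda_backend_allowed(cuda_version_str):
    # None means this torch build has no CUDA runtime, so the backend is disabled.
    if cuda_version_str is None:
        return False
    return tuple(map(int, cuda_version_str.split('.'))) >= (13,)

print(cuda_backend_allowed(None))    # False
print(cuda_backend_allowed("12.8"))  # False: (12, 8) < (13,)
print(cuda_backend_allowed("13.0"))  # True:  (13, 0) >= (13,)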
16 changes: 9 additions & 7 deletions comfy/text_encoders/lt.py
@@ -36,10 +36,10 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
 class Gemma3_12BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
 
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -86,17 +86,19 @@ def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         )
 
     def set_clip_options(self, options):
+        self.execution_device = options.get("execution_device", self.execution_device)
         self.gemma3_12b.set_clip_options(options)
 
     def reset_clip_options(self):
         self.gemma3_12b.reset_clip_options()
+        self.execution_device = None
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs["gemma3_12b"]
 
         out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
         out_device = out.device
-        out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device)
+        out = out.movedim(1, -1).to(self.execution_device)
         out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
@@ -117,12 +119,12 @@ def load_sd(self, sd):
         return self.load_state_dict(sdo, strict=False)
 
 
-def ltxav_te(dtype_llama=None, llama_scaled_fp8=None):
+def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
     class LTXAVTEModel_(LTXAVTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)