Skip to content

Commit 25022e0

Browse files
Cleanup and fix issues with text encoder quants. (Comfy-Org#10872)
1 parent 22a2644 commit 25022e0

File tree

7 files changed

+138
-112
lines changed

7 files changed

+138
-112
lines changed

comfy/model_patcher.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
231231
self.object_patches_backup = {}
232232
self.weight_wrapper_patches = {}
233233
self.model_options = {"transformer_options":{}}
234-
self.model_size()
235234
self.load_device = load_device
236235
self.offload_device = offload_device
237236
self.weight_inplace_update = weight_inplace_update
@@ -286,7 +285,7 @@ def lowvram_patch_counter(self):
286285
return self.model.lowvram_patch_counter
287286

288287
def clone(self):
289-
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
288+
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
290289
n.patches = {}
291290
for k in self.patches:
292291
n.patches[k] = self.patches[k][:]

comfy/ops.py

Lines changed: 95 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -540,113 +540,115 @@ def forward(self, *args, **kwargs):
540540
# ==============================================================================
541541
from .quant_ops import QuantizedTensor, QUANT_ALGOS
542542

543-
class MixedPrecisionOps(disable_weight_init):
544-
_layer_quant_config = {}
545-
_compute_dtype = torch.bfloat16
546-
547-
class Linear(torch.nn.Module, CastWeightBiasOp):
548-
def __init__(
549-
self,
550-
in_features: int,
551-
out_features: int,
552-
bias: bool = True,
553-
device=None,
554-
dtype=None,
555-
) -> None:
556-
super().__init__()
557-
558-
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
559-
# self.factory_kwargs = {"device": device, "dtype": dtype}
560-
561-
self.in_features = in_features
562-
self.out_features = out_features
563-
if bias:
564-
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
565-
else:
566-
self.register_parameter("bias", None)
567543

568-
self.tensor_class = None
544+
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
545+
class MixedPrecisionOps(manual_cast):
546+
_layer_quant_config = layer_quant_config
547+
_compute_dtype = compute_dtype
548+
_full_precision_mm = full_precision_mm
549+
550+
class Linear(torch.nn.Module, CastWeightBiasOp):
551+
def __init__(
552+
self,
553+
in_features: int,
554+
out_features: int,
555+
bias: bool = True,
556+
device=None,
557+
dtype=None,
558+
) -> None:
559+
super().__init__()
560+
561+
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
562+
# self.factory_kwargs = {"device": device, "dtype": dtype}
563+
564+
self.in_features = in_features
565+
self.out_features = out_features
566+
if bias:
567+
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
568+
else:
569+
self.register_parameter("bias", None)
569570

570-
def reset_parameters(self):
571-
return None
571+
self.tensor_class = None
572+
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
572573

573-
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
574-
strict, missing_keys, unexpected_keys, error_msgs):
574+
def reset_parameters(self):
575+
return None
575576

576-
device = self.factory_kwargs["device"]
577-
layer_name = prefix.rstrip('.')
578-
weight_key = f"{prefix}weight"
579-
weight = state_dict.pop(weight_key, None)
580-
if weight is None:
581-
raise ValueError(f"Missing weight for layer {layer_name}")
577+
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
578+
strict, missing_keys, unexpected_keys, error_msgs):
582579

583-
manually_loaded_keys = [weight_key]
580+
device = self.factory_kwargs["device"]
581+
layer_name = prefix.rstrip('.')
582+
weight_key = f"{prefix}weight"
583+
weight = state_dict.pop(weight_key, None)
584+
if weight is None:
585+
raise ValueError(f"Missing weight for layer {layer_name}")
584586

585-
if layer_name not in MixedPrecisionOps._layer_quant_config:
586-
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
587-
else:
588-
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
589-
if quant_format is None:
590-
raise ValueError(f"Unknown quantization format for layer {layer_name}")
591-
592-
qconfig = QUANT_ALGOS[quant_format]
593-
self.layout_type = qconfig["comfy_tensor_layout"]
594-
595-
weight_scale_key = f"{prefix}weight_scale"
596-
layout_params = {
597-
'scale': state_dict.pop(weight_scale_key, None),
598-
'orig_dtype': MixedPrecisionOps._compute_dtype,
599-
'block_size': qconfig.get("group_size", None),
600-
}
601-
if layout_params['scale'] is not None:
602-
manually_loaded_keys.append(weight_scale_key)
603-
604-
self.weight = torch.nn.Parameter(
605-
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
606-
requires_grad=False
607-
)
608-
609-
for param_name in qconfig["parameters"]:
610-
param_key = f"{prefix}{param_name}"
611-
_v = state_dict.pop(param_key, None)
612-
if _v is None:
613-
continue
614-
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
615-
manually_loaded_keys.append(param_key)
616-
617-
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
618-
619-
for key in manually_loaded_keys:
620-
if key in missing_keys:
621-
missing_keys.remove(key)
622-
623-
def _forward(self, input, weight, bias):
624-
return torch.nn.functional.linear(input, weight, bias)
587+
manually_loaded_keys = [weight_key]
625588

626-
def forward_comfy_cast_weights(self, input):
627-
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
628-
x = self._forward(input, weight, bias)
629-
uncast_bias_weight(self, weight, bias, offload_stream)
630-
return x
589+
if layer_name not in MixedPrecisionOps._layer_quant_config:
590+
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
591+
else:
592+
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
593+
if quant_format is None:
594+
raise ValueError(f"Unknown quantization format for layer {layer_name}")
595+
596+
qconfig = QUANT_ALGOS[quant_format]
597+
self.layout_type = qconfig["comfy_tensor_layout"]
598+
599+
weight_scale_key = f"{prefix}weight_scale"
600+
layout_params = {
601+
'scale': state_dict.pop(weight_scale_key, None),
602+
'orig_dtype': MixedPrecisionOps._compute_dtype,
603+
'block_size': qconfig.get("group_size", None),
604+
}
605+
if layout_params['scale'] is not None:
606+
manually_loaded_keys.append(weight_scale_key)
607+
608+
self.weight = torch.nn.Parameter(
609+
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
610+
requires_grad=False
611+
)
612+
613+
for param_name in qconfig["parameters"]:
614+
param_key = f"{prefix}{param_name}"
615+
_v = state_dict.pop(param_key, None)
616+
if _v is None:
617+
continue
618+
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
619+
manually_loaded_keys.append(param_key)
620+
621+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
622+
623+
for key in manually_loaded_keys:
624+
if key in missing_keys:
625+
missing_keys.remove(key)
626+
627+
def _forward(self, input, weight, bias):
628+
return torch.nn.functional.linear(input, weight, bias)
631629

632-
def forward(self, input, *args, **kwargs):
633-
run_every_op()
630+
def forward_comfy_cast_weights(self, input):
631+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
632+
x = self._forward(input, weight, bias)
633+
uncast_bias_weight(self, weight, bias, offload_stream)
634+
return x
634635

635-
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
636-
return self.forward_comfy_cast_weights(input, *args, **kwargs)
637-
if (getattr(self, 'layout_type', None) is not None and
638-
getattr(self, 'input_scale', None) is not None and
639-
not isinstance(input, QuantizedTensor)):
640-
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
641-
return self._forward(input, self.weight, self.bias)
636+
def forward(self, input, *args, **kwargs):
637+
run_every_op()
642638

639+
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
640+
return self.forward_comfy_cast_weights(input, *args, **kwargs)
641+
if (getattr(self, 'layout_type', None) is not None and
642+
getattr(self, 'input_scale', None) is not None and
643+
not isinstance(input, QuantizedTensor)):
644+
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
645+
return self._forward(input, self.weight, self.bias)
646+
return MixedPrecisionOps
643647

644648
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
645649
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
646-
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
647-
MixedPrecisionOps._compute_dtype = compute_dtype
648650
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
649-
return MixedPrecisionOps
651+
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype)
650652

651653
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
652654
if scaled_fp8 is not None:

comfy/quant_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs):
338338
return func(*args, **kwargs)
339339

340340

341+
@register_generic_util(torch.ops.aten.to.dtype)
342+
def generic_to_dtype(func, args, kwargs):
343+
"""Handle .to(dtype) calls - dtype conversion only."""
344+
src = args[0]
345+
if isinstance(src, QuantizedTensor):
346+
# For a dtype-only conversion, only orig_dtype is updated (in place, on src); no real cast is needed
347+
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
348+
src._layout_params["orig_dtype"] = target_dtype
349+
return src
350+
return func(*args, **kwargs)
351+
352+
341353
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
342354
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
343355
return True

comfy/sd.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,12 @@ class CLIPType(Enum):
917917
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
918918
clip_data = []
919919
for p in ckpt_paths:
920-
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
920+
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
921+
if metadata is not None:
922+
quant_metadata = metadata.get("_quantization_metadata", None)
923+
if quant_metadata is not None:
924+
sd["_quantization_metadata"] = quant_metadata
925+
clip_data.append(sd)
921926
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
922927

923928

@@ -1142,6 +1147,8 @@ class EmptyClass:
11421147

11431148
parameters = 0
11441149
for c in clip_data:
1150+
if "_quantization_metadata" in c:
1151+
c.pop("_quantization_metadata")
11451152
parameters += comfy.utils.calculate_parameters(c)
11461153
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
11471154

comfy/sd1_clip.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,23 @@ def __init__(self, device="cpu", max_length=77,
109109

110110
operations = model_options.get("custom_operations", None)
111111
scaled_fp8 = None
112+
quantization_metadata = model_options.get("quantization_metadata", None)
112113

113114
if operations is None:
114-
scaled_fp8 = model_options.get("scaled_fp8", None)
115-
if scaled_fp8 is not None:
116-
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
115+
layer_quant_config = None
116+
if quantization_metadata is not None:
117+
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
118+
119+
if layer_quant_config is not None:
120+
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
121+
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
117122
else:
118-
operations = comfy.ops.manual_cast
123+
# Fallback to scaled_fp8_ops for backward compatibility
124+
scaled_fp8 = model_options.get("scaled_fp8", None)
125+
if scaled_fp8 is not None:
126+
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
127+
else:
128+
operations = comfy.ops.manual_cast
119129

120130
self.operations = operations
121131
self.transformer = model_class(config, dtype, device, self.operations)

comfy/text_encoders/hunyuan_video.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
1818
if scaled_fp8_key in state_dict:
1919
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
2020

21+
if "_quantization_metadata" in state_dict:
22+
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
23+
2124
return out
2225

2326

tests-unit/comfy_quant/test_mixed_precision.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
3737

3838
def test_all_layers_standard(self):
3939
"""Test that model with no quantization works normally"""
40-
# Configure no quantization
41-
ops.MixedPrecisionOps._layer_quant_config = {}
42-
4340
# Create model
44-
model = SimpleModel(operations=ops.MixedPrecisionOps)
41+
model = SimpleModel(operations=ops.mixed_precision_ops({}))
4542

4643
# Initialize weights manually
4744
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
@@ -76,7 +73,6 @@ def test_mixed_precision_load(self):
7673
"params": {}
7774
}
7875
}
79-
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
8076

8177
# Create state dict with mixed precision
8278
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -99,7 +95,7 @@ def test_mixed_precision_load(self):
9995
}
10096

10197
# Create model and load state dict (strict=False because custom loading pops keys)
102-
model = SimpleModel(operations=ops.MixedPrecisionOps)
98+
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
10399
model.load_state_dict(state_dict, strict=False)
104100

105101
# Verify weights are wrapped in QuantizedTensor
@@ -132,7 +128,6 @@ def test_state_dict_quantized_preserved(self):
132128
"params": {}
133129
}
134130
}
135-
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
136131

137132
# Create and load model
138133
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -146,7 +141,7 @@ def test_state_dict_quantized_preserved(self):
146141
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
147142
}
148143

149-
model = SimpleModel(operations=ops.MixedPrecisionOps)
144+
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
150145
model.load_state_dict(state_dict1, strict=False)
151146

152147
# Save state dict
@@ -170,7 +165,6 @@ def test_weight_function_compatibility(self):
170165
"params": {}
171166
}
172167
}
173-
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
174168

175169
# Create and load model
176170
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -184,7 +178,7 @@ def test_weight_function_compatibility(self):
184178
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
185179
}
186180

187-
model = SimpleModel(operations=ops.MixedPrecisionOps)
181+
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
188182
model.load_state_dict(state_dict, strict=False)
189183

190184
# Add a weight function (simulating LoRA)
@@ -210,7 +204,6 @@ def test_error_handling_unknown_format(self):
210204
"params": {}
211205
}
212206
}
213-
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
214207

215208
# Create state dict
216209
state_dict = {
@@ -223,7 +216,7 @@ def test_error_handling_unknown_format(self):
223216
}
224217

225218
# Load should raise KeyError for unknown format in QUANT_ALGOS
226-
model = SimpleModel(operations=ops.MixedPrecisionOps)
219+
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
227220
with self.assertRaises(KeyError):
228221
model.load_state_dict(state_dict, strict=False)
229222

0 commit comments

Comments
 (0)