
Commit dc202a2

Properly save mixed ops. (Comfy-Org#11772)
1 parent: 153bc52
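
Summary (inferred from the diff below): the quantized layer's `state_dict()` previously called `super().state_dict()` and then patched in layout-specific scale keys by hand. It now builds the state dict itself: the bias is saved directly, quantized weights delegate their key layout to `QuantizedTensor.state_dict()`, an `input_scale` is saved when present, and non-quantized weights fall through to a plain `weight` entry, so models that mix quantized and full-precision layers serialize correctly.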

2 files changed: +20 -12 lines


comfy/ops.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -625,21 +625,29 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
             missing_keys.remove(key)
 
     def state_dict(self, *args, destination=None, prefix="", **kwargs):
-        sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
-        if isinstance(self.weight, QuantizedTensor):
-            layout_cls = self.weight._layout_cls
+        if destination is not None:
+            sd = destination
+        else:
+            sd = {}
+
+        if self.bias is not None:
+            sd["{}bias".format(prefix)] = self.bias
 
-            # Check if it's any FP8 variant (E4M3 or E5M2)
-            if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
-                sd["{}weight_scale".format(prefix)] = self.weight._params.scale
-            elif layout_cls == "TensorCoreNVFP4Layout":
-                sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
-                sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+        if isinstance(self.weight, QuantizedTensor):
+            sd_out = self.weight.state_dict("{}weight".format(prefix))
+            for k in sd_out:
+                sd[k] = sd_out[k]
 
             quant_conf = {"format": self.quant_format}
             if self._full_precision_mm_config:
                 quant_conf["full_precision_matrix_mult"] = True
             sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+
+            input_scale = getattr(self, 'input_scale', None)
+            if input_scale is not None:
+                sd["{}input_scale".format(prefix)] = input_scale
+        else:
+            sd["{}weight".format(prefix)] = self.weight
         return sd
 
     def _forward(self, input, weight, bias):
```
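
To make the new key layout concrete, here is a minimal runnable sketch that mirrors the control flow of the patched `state_dict()`. `MockQuantizedTensor` and `save_layer` are illustrative stand-ins, not comfy's real classes; only the key-naming scheme follows the diff above.

```python
# Minimal sketch of the new serialization scheme. MockQuantizedTensor
# and save_layer are hypothetical stand-ins for comfy's real classes;
# only the key-naming scheme mirrors the diff above.
import json
import torch

class MockQuantizedTensor:
    """Stands in for QuantizedTensor: packed data plus a scale."""
    def __init__(self, data, scale):
        self.data = data
        self.scale = scale

    def state_dict(self, key_prefix):
        # The real QuantizedTensor emits its packed weight and scale
        # tensors under layout-specific key names; this mock emits the
        # FP8-style pair exercised by the test below.
        return {key_prefix: self.data, key_prefix + "_scale": self.scale}

def save_layer(weight, bias, quant_format=None, prefix="", input_scale=None):
    """Mirrors the control flow of the patched state_dict()."""
    sd = {}
    if bias is not None:
        sd["{}bias".format(prefix)] = bias

    if isinstance(weight, MockQuantizedTensor):
        # Quantized path: delegate the weight keys to the tensor itself,
        # then record the quantization config as a uint8-encoded JSON blob.
        sd.update(weight.state_dict("{}weight".format(prefix)))
        quant_conf = {"format": quant_format}
        sd["{}comfy_quant".format(prefix)] = torch.tensor(
            list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
        if input_scale is not None:
            sd["{}input_scale".format(prefix)] = input_scale
    else:
        # Plain path: the weight is stored as-is.
        sd["{}weight".format(prefix)] = weight
    return sd

q = MockQuantizedTensor(torch.zeros(4, 4, dtype=torch.uint8), torch.tensor(3.0))
print(sorted(save_layer(q, torch.zeros(4), "float8_e4m3fn", prefix="layer1.")))
# ['layer1.bias', 'layer1.comfy_quant', 'layer1.weight', 'layer1.weight_scale']
print(sorted(save_layer(torch.zeros(4, 4), None, prefix="layer2.")))
# ['layer2.weight']
```

Note how a plain-weight layer contributes only a `weight` key while a quantized layer contributes `weight`, `weight_scale`, and the `comfy_quant` metadata blob; that is what lets both kinds of layers coexist in one checkpoint.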

tests-unit/comfy_quant/test_mixed_precision.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -153,9 +153,9 @@ def test_state_dict_quantized_preserved(self):
         state_dict2 = model.state_dict()
 
         # Verify layer1.weight is a QuantizedTensor with scale preserved
-        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
-        self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
+        self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
+        self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
+        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
```
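
One note on the updated assertions (my reading; the commit message doesn't spell it out): PyTorch implements only a limited operator set for FP8 dtypes, so the test compares the stored weight bitwise by reinterpreting its storage as `uint8`, and it reads the scale from the separate `layer1.weight_scale` key now that the state dict entry need not be a `QuantizedTensor`. A standalone illustration, assuming a PyTorch build with `torch.float8_e4m3fn` (2.1+):

```python
import torch

# Two FP8 tensors with identical contents.
w1 = torch.randn(4, 4).to(torch.float8_e4m3fn)
w2 = w1.clone()

# view(torch.uint8) reinterprets the 1-byte FP8 storage as raw bytes,
# so torch.equal compares bit patterns without needing FP8 comparison ops.
assert torch.equal(w1.view(torch.uint8), w2.view(torch.uint8))
print("bitwise-equal FP8 tensors")
```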
