
Commit 8817f8f

Mixed Precision Quantization System (Comfy-Org#10498)
* Implement mixed precision operations with a registry design and metadata for quant spec in checkpoint.
* Updated design using Tensor Subclasses
* Fix FP8 MM
* An actually functional POC
* Remove CK reference and ensure correct compute dtype
* Update unit tests
* ruff lint
* Fix missing keys
* Rename quant dtype parameter
* Fix unit tests for CPU build
1 parent 22e40d2 commit 8817f8f

File tree: 8 files changed, +1030 -19 lines

comfy/model_base.py

Lines changed: 9 additions & 1 deletion
@@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
 
+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
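
For reference, a hypothetical `_quantization_metadata` payload matching what `state_dict_for_saving` writes above. Layer names are illustrative; `float8_e4m3fn` is the only format registered in this commit's `QUANT_FORMAT_MIXINS`.

```python
# Hypothetical shape of the metadata written above; layer names are illustrative.
metadata = {
    "format_version": "1.0",
    "layers": {
        "input_blocks.1.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
        "middle_block.1.proj_in": {"format": "float8_e4m3fn"},
    },
}
```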

comfy/model_detection.py

Lines changed: 20 additions & 0 deletions
@@ -6,6 +6,20 @@
 import logging
 import torch
 
+
+def detect_layer_quantization(metadata):
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
+    return None
+
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         else:
             model_config.optimizations["fp8"] = True
 
+    # Detect per-layer quantization (mixed precision)
+    layer_quant_config = detect_layer_quantization(metadata)
+    if layer_quant_config:
+        model_config.layer_quant_config = layer_quant_config
+        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+
     return model_config
 
 def unet_prefix_from_state_dict(state_dict):
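
A minimal usage sketch of the detection path, assuming the checkpoint metadata carries the payload as a JSON string (as the `json.loads` call above implies). The layer name is illustrative.

```python
import json
from comfy.model_detection import detect_layer_quantization

# The "_quantization_metadata" entry is popped, parsed, and the per-layer map returned.
checkpoint_metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {"middle_block.1.proj_in": {"format": "float8_e4m3fn"}},  # illustrative layer name
    })
}
layers = detect_layer_quantization(checkpoint_metadata)
# layers == {"middle_block.1.proj_in": {"format": "float8_e4m3fn"}}
```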

comfy/ops.py

Lines changed: 133 additions & 13 deletions
@@ -344,6 +344,10 @@ class Embedding(disable_weight_init.Embedding):
 
 
 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None
@@ -355,9 +359,9 @@ def fp8_linear(self, input):
 
     input_shape = input.shape
     input_dtype = input.dtype
+
     if len(input.shape) == 3:
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -368,23 +372,18 @@ def fp8_linear(self, input):
 
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
 
-        if isinstance(o, tuple):
-            o = o[0]
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
-
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
 
     return None
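
quant_ops.py is not included in this excerpt, so the exact dispatch handler is an assumption; conceptually, the `F.linear` call on two FP8 `QuantizedTensor`s should reduce to the same scaled matmul the removed inline code performed. A standalone sketch of that computation (function name and shapes are illustrative, not part of the commit):

```python
import torch

def scaled_fp8_linear_reference(x, w_fp8, scale_input, scale_weight, bias=None):
    # x: (tokens, in_features) in bf16/fp16; w_fp8: (out_features, in_features) in float8_e4m3fn
    # scale_input / scale_weight: scalar float32 tensors, as in the removed code above
    x_fp8 = (x * (1.0 / scale_input).to(x.dtype)).clamp(-448, 448).to(torch.float8_e4m3fn)
    # torch._scaled_mm expects the second operand column-major, hence the transpose of the row-major weight
    return torch._scaled_mm(
        x_fp8.contiguous(), w_fp8.t(),
        scale_a=scale_input, scale_b=scale_weight,
        bias=bias, out_dtype=x.dtype,
    )
```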
@@ -478,7 +477,128 @@ def forward_comfy_cast_weights(self, input):
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": TensorCoreFP8Layout,
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
+
+class MixedPrecisionOps(disable_weight_init):
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+            self.in_features = in_features
+            self.out_features = out_features
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+
+            self.tensor_class = None
+
+        def reset_parameters(self):
+            return None
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            manually_loaded_keys = [weight_key]
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                scale_key = f"{prefix}weight_scale"
+                layout_params = {
+                    'scale': state_dict.pop(scale_key, None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+            if (getattr(self, 'layout_type', None) is not None and
+                    getattr(self, 'input_scale', None) is not None and
+                    not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        MixedPrecisionOps._compute_dtype = compute_dtype
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return MixedPrecisionOps
+
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
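
A hedged usage sketch of the new entry point: when `model_config.layer_quant_config` is populated, `pick_operations` returns `MixedPrecisionOps`, whose `Linear` defers weight materialization to `_load_from_state_dict`. The `SimpleNamespace` stand-in and the layer name are illustrative, not ComfyUI API.

```python
import types
import torch
import comfy.ops

# Stand-in for a detected model_config; in ComfyUI this comes from model_detection.
model_config = types.SimpleNamespace(
    layer_quant_config={"input_blocks.1.1.proj_in": {"format": "float8_e4m3fn"}},  # illustrative
)

ops = comfy.ops.pick_operations(torch.float8_e4m3fn, torch.bfloat16, model_config=model_config)
assert ops is comfy.ops.MixedPrecisionOps

# Unquantized layers load as compute-dtype Parameters; layers listed in
# layer_quant_config are wrapped in QuantizedTensor when their state dict entries load.
layer = ops.Linear(320, 640, bias=True, device="cpu")
```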
