6 changes: 5 additions & 1 deletion comfy/model_management.py
@@ -1098,13 +1098,14 @@ def cast_to_device(tensor, device, dtype, copy=False):
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])

def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False

if type(tensor) is not torch.nn.parameter.Parameter:
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False

if not is_device_cpu(tensor.device):
@@ -1124,6 +1125,9 @@ def pin_memory(tensor):
return False

ptr = tensor.data_ptr()
if ptr == 0:
return False

if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
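
Note on the pin_memory() change above: eligibility is now keyed on the type name, so QuantizedTensor objects can be pinned alongside plain Parameters, and tensors reporting a null data pointer are rejected before cudaHostRegister is attempted. A minimal sketch of that guard logic, assuming CUDA is available; can_pin is a hypothetical helper, not part of the PR:

    # Illustrative sketch only -- not part of the diff.
    import torch

    PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])

    def can_pin(tensor):
        # Only types on the allow-list are considered for host pinning.
        if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
            return False
        # Empty or meta tensors report a null data pointer; registering address 0
        # with cudaHostRegister would be invalid, so bail out early.
        if tensor.data_ptr() == 0:
            return False
        # Pinning only makes sense for CPU-resident tensors.
        return tensor.device.type == "cpu"
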
3 changes: 1 addition & 2 deletions comfy/model_patcher.py
@@ -231,7 +231,6 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
@@ -286,7 +285,7 @@ def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter

def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
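
A hedged reading of the model_patcher.py change: the size is no longer computed eagerly in __init__; clone() now asks model_size() for it, which presumably computes and caches the value on first use. A minimal sketch of such a lazily cached accessor, assuming the existing self.size field and a module-size helper in comfy.model_management; the real method may differ:

    def model_size(self):
        # Sketch: compute once on demand, then reuse the cached value.
        if self.size > 0:
            return self.size
        self.size = comfy.model_management.module_size(self.model)
        return self.size
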
191 changes: 97 additions & 94 deletions comfy/ops.py
@@ -540,115 +540,118 @@ def forward(self, *args, **kwargs):
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS

class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16

class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()

self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}

self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)

self.tensor_class = None
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_layer_quant_config = layer_quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm

class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()

self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}

self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)

def reset_parameters(self):
return None
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm

def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
def reset_parameters(self):
return None

device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):

manually_loaded_keys = [weight_key]
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")

if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")

qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]

weight_scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)

self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)

for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)

super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)

def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
manually_loaded_keys = [weight_key]

def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")

qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]

weight_scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)

self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)

for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)

super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)

def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)

def forward(self, input, *args, **kwargs):
run_every_op()
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def forward(self, input, *args, **kwargs):
run_every_op()

if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
return MixedPrecisionOps
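
For context, a hedged sketch of how the new mixed_precision_ops() factory might be consumed; the layer name and "format" value below are assumptions, not values taken from the PR (valid formats are the keys of QUANT_ALGOS):

    # Hypothetical usage -- illustrative only.
    import torch

    layer_quant_config = {
        "model.layers.0.mlp.down_proj": {"format": "example_fp8_format"},  # assumed names
    }
    ops = mixed_precision_ops(layer_quant_config, compute_dtype=torch.bfloat16,
                              full_precision_mm=True)
    linear = ops.Linear(4096, 4096, bias=True, device="cpu")
    # Weights arrive via _load_from_state_dict: quantized layers become
    # QuantizedTensor parameters, everything else is cast to compute_dtype.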

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)

fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

24 changes: 22 additions & 2 deletions comfy/quant_ops.py
@@ -228,6 +228,14 @@ def dequant_arg(arg):
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)

def data_ptr(self):
return self._qdata.data_ptr()

def is_pinned(self):
return self._qdata.is_pinned()

def is_contiguous(self):
return self._qdata.is_contiguous()
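
These pass-throughs let a QuantizedTensor behave like its underlying storage for the host-pinning path above (model_management.pin_memory now allows the "QuantizedTensor" type name and probes the data pointer). A small hedged illustration, where qt is assumed to be an existing QuantizedTensor:

    # Illustrative only: pinning-related checks see the wrapped _qdata payload.
    ptr = qt.data_ptr()                       # address of the quantized payload
    ok = ptr != 0 and qt.is_contiguous() and not qt.is_pinned()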

# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
@@ -338,6 +346,18 @@ def generic_copy_(func, args, kwargs):
return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
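
A hedged illustration of what the handler above implies for callers, where q is assumed to be an existing QuantizedTensor:

    # Illustrative only: .to(dtype) on a QuantizedTensor is a metadata update.
    q_half = q.to(torch.float16)   # no dequantize/requantize pass happens here
    # Only the recorded orig_dtype changes, so a later dequantization will
    # produce float16 output instead of the previous compute dtype.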


@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@@ -385,8 +405,8 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):

tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penalty
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

layout_params = {
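
A hedged note on re-enabling the clamp above: float8_e4m3fn has a finite maximum (about 448) and no inf encoding, so scaled values outside that range can overflow badly when cast; clamping first saturates them at the representable extreme. Tiny illustration, assuming a torch build with float8 support:

    # Illustrative only.
    import torch

    lp_amax = torch.finfo(torch.float8_e4m3fn).max      # ~448 for e4m3fn
    x = torch.tensor([1000.0])
    safe = x.clamp(min=-lp_amax, max=lp_amax).to(torch.float8_e4m3fn)  # saturates instead of overflowing
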
9 changes: 8 additions & 1 deletion comfy/sd.py
@@ -917,7 +917,12 @@ class CLIPType(Enum):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
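
A hedged reading of this hunk: per-file quantization metadata is carried along in the state dict under the "_quantization_metadata" key so the text-encoder loaders can see it; the hunk below pops it back out before counting parameters. Sketch of the hand-off, where path is a placeholder and the JSON layout follows the sd1_clip.py change further down:

    # Illustrative only.
    import json
    import comfy.utils

    sd, metadata = comfy.utils.load_torch_file(path, safe_load=True, return_metadata=True)
    if metadata is not None and metadata.get("_quantization_metadata") is not None:
        sd["_quantization_metadata"] = metadata["_quantization_metadata"]

    # Downstream, the payload is parsed as JSON and its "layers" mapping decides
    # which linear layers get mixed-precision ops.
    layer_quant_config = json.loads(sd["_quantization_metadata"]).get("layers", None)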


@@ -1142,6 +1147,8 @@ class EmptyClass:

parameters = 0
for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

18 changes: 14 additions & 4 deletions comfy/sd1_clip.py
@@ -109,13 +109,23 @@ def __init__(self, device="cpu", max_length=77,

operations = model_options.get("custom_operations", None)
scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)

if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
layer_quant_config = None
if quantization_metadata is not None:
layer_quant_config = json.loads(quantization_metadata).get("layers", None)

if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else:
operations = comfy.ops.manual_cast
# Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast

self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
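
A hedged restatement of the operation-selection order this hunk establishes; choose_text_encoder_ops is a hypothetical helper written only to make the precedence explicit, while the real logic stays inline in the __init__ shown above:

    # Illustrative only.
    import json
    import comfy.ops

    def choose_text_encoder_ops(model_options, dtype):
        ops = model_options.get("custom_operations", None)
        if ops is not None:
            return ops                                   # 1) explicit override wins
        meta = model_options.get("quantization_metadata", None)
        layers = json.loads(meta).get("layers", None) if meta is not None else None
        if layers is not None:                           # 2) per-layer quantization metadata
            return comfy.ops.mixed_precision_ops(layers, dtype, full_precision_mm=True)
        scaled_fp8 = model_options.get("scaled_fp8", None)
        if scaled_fp8 is not None:                       # 3) legacy scaled-fp8 checkpoints
            return comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
        return comfy.ops.manual_cast                     # 4) default
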
3 changes: 3 additions & 0 deletions comfy/text_encoders/hunyuan_video.py
@@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]

return out

