From 2af2e9d9206f3c0b5a7a82e062fea8a29e2ea788 Mon Sep 17 00:00:00 2001 From: 0xDELUXA Date: Wed, 18 Mar 2026 20:44:03 +0200 Subject: [PATCH] Fix quant_state None on AMD GPUs by caching quant_state_dict at load time --- README.md | 2 +- ops.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 97752c2..57f6180 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ git clone https://github.com/mengqin/ComfyUI-UnetBnbModelLoader ComfyUI/custom_n .\python_embeded\python.exe -s -m pip install -r .\ComfyUI\custom_nodes\ComfyUI-UnetBnbModelLoader\requirements.txt ``` -Because this plugin relies on bitsandbytes, we are unable to support macOS and AMD GPUs. +Because this plugin relies on bitsandbytes, we are unable to support macOS. ## Usage diff --git a/ops.py b/ops.py index 509b954..c4b2b99 100644 --- a/ops.py +++ b/ops.py @@ -55,7 +55,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, data=weight_data, quantized_stats=quant_state_dict, device=device ) self.weight = bnb_param - + self._bnb_quant_state_dict = quant_state_dict + for k in bnb_state_dict.keys(): state_dict.pop(k) if k in unexpected_keys: unexpected_keys.remove(k) @@ -94,9 +95,10 @@ def forward(self, x): if getattr(self, "is_bnb_quantized", lambda : False)(): if not patches_for_this_layer: bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None - return bnb.matmul_4bit( - x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None) - ).to(x.dtype) + qs = getattr(self.weight, "quant_state", None) + if qs is None and hasattr(self, "_bnb_quant_state_dict"): + qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device) + return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype) try: base_w = self.weight.to(x.device) @@ -113,9 +115,10 @@ def forward(self, x): if weight_final_fp32 is None: bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None - return bnb.matmul_4bit( - x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None) - ).to(x.dtype) + qs = getattr(self.weight, "quant_state", None) + if qs is None and hasattr(self, "_bnb_quant_state_dict"): + qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device) + return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype) weight_final = comfy.float.stochastic_rounding(weight_final_fp32, x.dtype) bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None