Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ git clone https://github.com/mengqin/ComfyUI-UnetBnbModelLoader ComfyUI/custom_n
.\python_embeded\python.exe -s -m pip install -r .\ComfyUI\custom_nodes\ComfyUI-UnetBnbModelLoader\requirements.txt
```

Because this plugin relies on bitsandbytes, we are unable to support macOS and AMD GPUs.
Because this plugin relies on bitsandbytes, we are unable to support macOS.

## Usage

Expand Down
17 changes: 10 additions & 7 deletions ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
data=weight_data, quantized_stats=quant_state_dict, device=device
)
self.weight = bnb_param

self._bnb_quant_state_dict = quant_state_dict

for k in bnb_state_dict.keys():
state_dict.pop(k)
if k in unexpected_keys: unexpected_keys.remove(k)
Expand Down Expand Up @@ -94,9 +95,10 @@ def forward(self, x):
if getattr(self, "is_bnb_quantized", lambda : False)():
if not patches_for_this_layer:
bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
return bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None)
).to(x.dtype)
qs = getattr(self.weight, "quant_state", None)
if qs is None and hasattr(self, "_bnb_quant_state_dict"):
qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device)
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype)

try:
base_w = self.weight.to(x.device)
Expand All @@ -113,9 +115,10 @@ def forward(self, x):

if weight_final_fp32 is None:
bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
return bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None)
).to(x.dtype)
qs = getattr(self.weight, "quant_state", None)
if qs is None and hasattr(self, "_bnb_quant_state_dict"):
qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device)
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype)

weight_final = comfy.float.stochastic_rounding(weight_final_fp32, x.dtype)
bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
Expand Down