Skip to content

Commit b680542

Browse files
Allow pinning quantized tensors. (Comfy-Org#10873)
1 parent 25022e0 commit b680542

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

comfy/model_management.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,13 +1098,14 @@ def cast_to_device(tensor, device, dtype, copy=False):
10981098
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
10991099
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
11001100

1101+
# Class names (compared against type(tensor).__name__) whose instances are
# eligible for host-memory pinning in pin_memory(). Plain torch.Tensor is
# deliberately excluded; only Parameters and QuantizedTensors qualify.
PINNING_ALLOWED_TYPES = {"Parameter", "QuantizedTensor"}
11011102

11021103
def pin_memory(tensor):
11031104
global TOTAL_PINNED_MEMORY
11041105
if MAX_PINNED_MEMORY <= 0:
11051106
return False
11061107

1107-
if type(tensor) is not torch.nn.parameter.Parameter:
1108+
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
11081109
return False
11091110

11101111
if not is_device_cpu(tensor.device):
@@ -1124,6 +1125,9 @@ def pin_memory(tensor):
11241125
return False
11251126

11261127
ptr = tensor.data_ptr()
1128+
if ptr == 0:
1129+
return False
1130+
11271131
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
11281132
PINNED_MEMORY[ptr] = size
11291133
TOTAL_PINNED_MEMORY += size

comfy/quant_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,14 @@ def dequant_arg(arg):
228228
new_kwargs = dequant_arg(kwargs)
229229
return func(*new_args, **new_kwargs)
230230

231+
def data_ptr(self):
    """Return the raw memory address of the underlying quantized storage.

    Delegates to the wrapped ``_qdata`` tensor so host-pinning code
    (which registers the address via cudaHostRegister) operates on the
    real backing buffer rather than the wrapper object.
    """
    return self._qdata.data_ptr()
233+
234+
def is_pinned(self):
    """Return True if the underlying quantized storage is in pinned (page-locked) host memory.

    Delegates to the wrapped ``_qdata`` tensor, mirroring
    ``torch.Tensor.is_pinned``.
    """
    return self._qdata.is_pinned()
236+
237+
def is_contiguous(self):
    """Return True if the underlying quantized storage is contiguous in memory.

    Delegates to the wrapped ``_qdata`` tensor, mirroring
    ``torch.Tensor.is_contiguous``.
    """
    return self._qdata.is_contiguous()
231239

232240
# ==============================================================================
233241
# Generic Utilities (Layout-Agnostic Operations)

0 commit comments

Comments
 (0)