Commit 3fa7a5c

Speed up offloading using pinned memory. (Comfy-Org#10526)
To enable this feature, use: --fast pinned_memory
1 parent 210f7a1 commit 3fa7a5c
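
CUDA can only run truly asynchronous host-to-device copies from page-locked (pinned) host memory; pageable memory forces an extra staging copy. An illustrative micro-benchmark (not part of this commit) of the difference the feature exploits:

```python
import time

import torch

# Illustrative micro-benchmark (not part of this commit): host-to-device copy
# from pageable vs. pinned (page-locked) CPU memory. Pinned memory is what
# lets non_blocking=True copies run as true async DMA transfers.

def time_copy(host_tensor, iters=10):
    device_tensor = torch.empty_like(host_tensor, device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        device_tensor.copy_(host_tensor, non_blocking=True)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

if torch.cuda.is_available():
    pageable = torch.randn(16 * 1024 * 1024)  # 16M floats, ~64 MB, pageable
    pinned = pageable.pin_memory()            # page-locked copy of the same data
    print("pageable: {:.5f}s/copy  pinned: {:.5f}s/copy".format(
        time_copy(pageable), time_copy(pinned)))
```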

3 files changed, 56 additions and 1 deletion.

comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
```
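
The new enum value slots into the existing --fast machinery; at runtime the feature is tested by set membership, exactly as the model_management.py change below does. A minimal sketch (helper name mine):

```python
# Sketch (helper name mine) of how the new flag is consumed at runtime:
# args.fast holds the parsed PerformanceFeature values, so features are
# checked by membership, as model_management.py does below.
from comfy.cli_args import args, PerformanceFeature

def pinned_memory_enabled() -> bool:
    return PerformanceFeature.PinnedMem in args.fast
```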

comfy/model_management.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
 
```
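
Note that these helpers register the tensor's existing storage with cudaHostRegister rather than calling Tensor.pin_memory(), which would allocate and copy into a new pinned buffer; the flag value 1 appears to be cudaHostRegisterPortable. A standalone sketch of the same in-place pattern (helper names mine):

```python
import torch

# Standalone sketch (helper names mine) of the in-place pinning pattern above:
# cudaHostRegister page-locks the tensor's existing storage, so no new
# allocation or copy is needed; it returns cudaSuccess (0) when the pages
# were locked.

def pin_in_place(tensor: torch.Tensor) -> bool:
    if tensor.device.type != "cpu":
        return False
    nbytes = tensor.numel() * tensor.element_size()
    return torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), nbytes, 1) == 0

def unpin_in_place(tensor: torch.Tensor) -> bool:
    if tensor.device.type != "cpu":
        return False
    return torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0

if torch.cuda.is_available():
    torch.cuda.init()  # cudart() needs an initialized CUDA context
    w = torch.randn(1024, 1024)
    if pin_in_place(w):
        w_gpu = w.to("cuda", non_blocking=True)  # now a true async DMA copy
        torch.cuda.synchronize()
        unpin_in_place(w)
```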

comfy/model_patcher.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -238,6 +238,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()
 
         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -618,6 +619,21 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -683,6 +699,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                             patch_counter += 1
 
                     cast_weight = True
+                    for param in params:
+                        self.pin_weight_to_device("{}.{}".format(n, param))
                 else:
                     if hasattr(m, "comfy_cast_weights"):
                         wipe_lowvram_weight(m)
@@ -713,7 +731,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                     continue
 
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)
 
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
@@ -762,6 +782,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -857,6 +878,9 @@ def partially_unload(self, device_to, memory_to_free=0):
                     memory_freed += module_mem
                     logging.debug("freed {}".format(n))
 
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
```
