Merged
27 changes: 18 additions & 9 deletions comfy/model_management.py
@@ -1202,27 +1202,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
assert r is None
assert stream is None

r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])

if dtype is None:
dtype = weight._model_dtype

r = torch.empty_like(weight, dtype=dtype, device=device)

signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]

if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r

if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)

if signature is not None:
weight._v_signature = signature
v_tensor.copy_(r)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)

return r

if device is None or weight.device == device:
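Note on the cast_to hunk above: when dtype is not given, it now falls back to weight._model_dtype, and a weight whose storage dtype matches neither the requested dtype nor the model dtype is first cast through _model_dtype, so quantization rounds identically for loaded and offloaded weights. A minimal sketch of just that dtype handling, with the vbar/aimdo signature logic omitted (cast_offloaded is an illustrative name, not part of the patch):

import torch

def cast_offloaded(weight, dtype=None, device=None, non_blocking=False):
    # Fall back to the dtype the model was built with when none is requested.
    if dtype is None:
        dtype = weight._model_dtype

    r = torch.empty_like(weight, dtype=dtype, device=device)

    if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
        # Force the same double cast the regular (non-offloaded) path performs,
        # so quantized weights come out identical whether or not they were offloaded.
        cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
        cast_buffer.copy_(weight, non_blocking=non_blocking)
        weight = cast_buffer

    r.copy_(weight, non_blocking=non_blocking)
    return r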
17 changes: 11 additions & 6 deletions comfy/model_patcher.py
@@ -161,6 +161,11 @@ def get_key_weight(model, key):

return weight, set_func, convert_func

def key_param_name_to_key(key, param):
if len(key) == 0:
return param
return "{}.{}".format(key, param)

class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
self.model = model
@@ -795,7 +800,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
continue

for param in params:
key = "{}.{}".format(n, param)
key = key_param_name_to_key(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
@@ -811,7 +816,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.pin_weight_to_device(key_param_name_to_key(n, param))

usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
@@ -917,7 +922,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True
for param in params:
key = "{}.{}".format(n, param)
key = key_param_name_to_key(n, param)
bk = self.backup.get(key, None)
if bk is not None:
if not lowvram_possible:
@@ -968,7 +973,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
logging.debug("freed {}".format(n))

for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.pin_weight_to_device(key_param_name_to_key(n, param))


self.model.model_lowvram = True
@@ -1501,7 +1506,7 @@ def set_dirty(item, dirty):

def setup_param(self, m, n, param_key):
nonlocal num_patches
key = "{}.{}".format(n, param_key)
key = key_param_name_to_key(n, param_key)

weight_function = []

@@ -1540,7 +1545,7 @@ def setup_param(self, m, n, param_key):

else:
for param in params:
key = "{}.{}".format(n, param)
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
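Note on the model_patcher.py hunks above: the repeated "{}.{}".format(n, param) calls are replaced by key_param_name_to_key, which avoids producing a key with a leading dot when the module name is empty (parameters living on the root module). A quick illustration of the intended behaviour (the example key names are made up):

def key_param_name_to_key(key, param):
    if len(key) == 0:
        return param
    return "{}.{}".format(key, param)

assert key_param_name_to_key("", "weight") == "weight"  # root-module parameter, no leading dot
assert key_param_name_to_key("diffusion_model.out", "bias") == "diffusion_model.out.bias"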
6 changes: 5 additions & 1 deletion comfy/utils.py
@@ -32,6 +32,7 @@
import json
import time
import mmap
import warnings

MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -85,7 +86,10 @@ def load_safetensors(ckpt):
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))

data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]

sd = {}
for name, info in header.items():
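Note on the load_safetensors hunk above: torch.frombuffer emits a UserWarning when handed a non-writable buffer such as this read-only mmap, and the mapping is only ever read, so the warning is filtered. A standalone sketch of the same pattern (the filename is illustrative):

import mmap
import warnings
import torch

with open("model.safetensors", "rb") as f:
    mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

with warnings.catch_warnings():
    # The mapping is read-only by design; torch.frombuffer warns only because it
    # cannot return a writable tensor, and the data is never written to here.
    warnings.filterwarnings("ignore", message="The given buffer is not writable")
    data_area = torch.frombuffer(mapping, dtype=torch.uint8)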
2 changes: 1 addition & 1 deletion requirements.txt
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.6
comfy-aimdo>=0.1.7
requests

#non essential dependencies: