comfy/ops.py (22 changes: 11 additions & 11 deletions)
@@ -96,16 +96,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     pin = comfy.pinned_memory.get_pin(s)
     if pin is not None:
         xfer_source = [ pin ]
-    else:
-        for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
-            if data is None:
-                continue
-            if data.dtype != geometry.dtype:
-                cast_dest = xfer_dest
-                if cast_dest is None:
-                    cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
-                xfer_dest = None
-                break
+
+    for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
+        if data is None:
+            continue
+        if data.dtype != geometry.dtype:
+            cast_dest = xfer_dest
+            if cast_dest is None:
+                cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
+            xfer_dest = None
+            break
 
     dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
     offload_stream = comfy.model_management.get_offload_stream(device)
@@ -132,7 +132,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     comfy.model_management.sync_stream(device, offload_stream)
 
     if cast_dest is not None:
-        for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
+        for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
                                        comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
             if post_cast is not None:
                 post_cast.copy_(pre_cast)
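What the ops.py hunks change, in brief: the dtype-mismatch scan used to live in the else: branch and was therefore skipped whenever a pinned staging buffer served as the transfer source, so pinned transfers never allocated a cast buffer; it now runs unconditionally. Correspondingly, the post-transfer cast loop now views the transferred bytes through the geometries of s.weight/s.bias rather than through xfer_source, which after this change can be the single flat pin instead of the per-tensor list. The sketch below models that gather, transfer, then cast flow in isolation; aligned_size, views_into, and the (shape, dtype) geometry records are hypothetical stand-ins for the comfy.memory_management helpers, not their actual signatures.

import torch

ALIGN = 256  # assumed alignment unit; the real code aligns for VRAM transfers

def elem_size(dtype):
    return torch.empty((), dtype=dtype).element_size()

def aligned_size(geoms, align=ALIGN):
    # Bytes needed to gather all geometries into one flat buffer,
    # each entry rounded up to the alignment boundary.
    total = 0
    for shape, dtype in geoms:
        nbytes = torch.Size(shape).numel() * elem_size(dtype)
        total += -(-nbytes // align) * align  # ceil to align
    return total

def views_into(geoms, flat, align=ALIGN):
    # Reinterpret slices of a flat uint8 buffer as typed tensors laid
    # out per geoms (a stand-in for interpret_gathered_like).
    out, offset = [], 0
    for shape, dtype in geoms:
        nbytes = torch.Size(shape).numel() * elem_size(dtype)
        out.append(flat[offset:offset + nbytes].view(dtype).view(shape))
        offset += -(-nbytes // align) * align
    return out

# Weights stored in fp32; the compute geometry wants fp16.
weight = torch.randn(4, 4)
src_geometry = [(tuple(weight.shape), weight.dtype)]
cast_geometry = [(tuple(weight.shape), torch.float16)]

# 1. Cast detection now runs even when the transfer source is a pin.
needs_cast = any(w.dtype != dtype
                 for w, (_, dtype) in zip([weight], cast_geometry))

# 2. Transfer bytes at the source dtype into a flat staging buffer.
xfer_dest = torch.empty((aligned_size(src_geometry),), dtype=torch.uint8)
for view, src in zip(views_into(src_geometry, xfer_dest), [weight]):
    view.copy_(src)

# 3. Cast pass: view the staged bytes with the source-tensor geometries
#    (the PR's second hunk) and copy into the compute-dtype buffer;
#    copy_ performs the dtype conversion.
if needs_cast:
    cast_dest = torch.empty((aligned_size(cast_geometry),), dtype=torch.uint8)
    for pre_cast, post_cast in zip(views_into(src_geometry, xfer_dest),
                                   views_into(cast_geometry, cast_dest)):
        post_cast.copy_(pre_cast)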
comfy/pinned_memory.py (3 changes: 1 addition & 2 deletions)
@@ -11,8 +11,7 @@ def pin_memory(module):
     if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
         return
     #FIXME: This is a RAM cache trigger event
-    params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
-    size = comfy.memory_management.vram_aligned_size(params)
+    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
     pin = torch.empty((size,), dtype=torch.uint8)
     if comfy.model_management.pin_memory(pin):
         module._pin = pin
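The pinned_memory.py change drops the intermediate tensors_to_geometries call, which implies vram_aligned_size now accepts live tensors (with None standing in for an absent bias) directly. A plausible shape for such a helper, as a hedged sketch; the normalization logic and the (shape, dtype) record form are assumptions, not the actual comfy.memory_management implementation:

import torch

def vram_aligned_size(items, align=256):
    # Accepts a mix of tensors, (shape, dtype) records, and None entries,
    # so call sites like pin_memory() can pass [module.weight, module.bias]
    # without converting to geometries first.
    total = 0
    for item in items:
        if item is None:
            continue  # e.g. a Linear layer created with bias=False
        if isinstance(item, torch.Tensor):
            nbytes = item.numel() * item.element_size()
        else:  # assumed geometry record: (shape, dtype)
            shape, dtype = item
            nbytes = torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()
        total += -(-nbytes // align) * align  # round up to the alignment unit
    return total

# Usage mirroring the new pin_memory() call site:
weight, bias = torch.randn(8, 8), None
size = vram_aligned_size([weight, bias])
pin = torch.empty((size,), dtype=torch.uint8)  # flat staging buffer to be pinned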