
Commit 7436674

ops: prioritize mem transfer
The async offload stream's reason for existence is to transfer from RAM to GPU. The post-processing compute steps are a bonus on the side stream, but if the compute stream is running a long kernel, it can stall the side stream, which waits to type-cast the bias before transferring the weight. So do a pure transfer of the weight first, then do everything for the bias, then go back to fix the weight type and apply the weight patches.
1 parent fcdb4a5 commit 7436674


1 file changed: +11 -7 lines changed


comfy/ops.py

Lines changed: 11 additions & 7 deletions
@@ -95,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         wf_context = contextlib.nullcontext()
 
-    bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    weight_has_function = len(s.weight_function) > 0
+    bias_has_function = len(s.bias_function) > 0
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    bias = None
     if s.bias is not None:
-        has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
 
-        if has_function:
+        if bias_has_function:
             with wf_context:
                 for f in s.bias_function:
                     bias = f(bias)
 
-    has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
-    if has_function:
+    weight = weight.to(dtype=dtype)
+    if weight_has_function:
         with wf_context:
             for f in s.weight_function:
                 weight = f(weight)
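
The reordering described in the commit message can be illustrated with a toy timeline model. This is only a sketch under stated assumptions, not a measurement of ComfyUI: it assumes the side stream runs its operations strictly in order, that a pure copy can start immediately, and that any step needing GPU compute (a type-cast or patch) cannot start until the long kernel on the compute stream has finished. The `Op` class, the `simulate` helper, and all durations are hypothetical and chosen for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    kind: str        # "copy": pure transfer; "compute": needs GPU compute units
    duration: float  # hypothetical duration in milliseconds


def simulate(ops, compute_busy_until):
    """Run ops in order on one side stream; return {name: (start, end)}.

    Toy assumption: a "compute" op cannot start before the long kernel on
    the main compute stream finishes, while a "copy" op can start at once.
    """
    t = 0.0
    timeline = {}
    for op in ops:
        start = max(t, compute_busy_until) if op.kind == "compute" else t
        end = start + op.duration
        timeline[op.name] = (start, end)
        t = end  # in-order stream: the next op waits for this one
    return timeline


# Before the commit: the bias is handled first, so its cast blocks the weight copy.
OLD_ORDER = [
    Op("bias_copy", "copy", 0.1),
    Op("bias_cast", "compute", 0.1),
    Op("weight_copy", "copy", 5.0),
    Op("weight_cast", "compute", 0.5),
]

# After the commit: raw weight transfer first, then the bias, then the weight cast.
NEW_ORDER = [
    Op("weight_copy", "copy", 5.0),
    Op("bias_copy", "copy", 0.1),
    Op("bias_cast", "compute", 0.1),
    Op("weight_cast", "compute", 0.5),
]

LONG_KERNEL_MS = 20.0  # hypothetical long kernel on the compute stream
old = simulate(OLD_ORDER, LONG_KERNEL_MS)
new = simulate(NEW_ORDER, LONG_KERNEL_MS)

print("old order: weight copy starts at", old["weight_copy"][0], "ms")
print("new order: weight copy starts at", new["weight_copy"][0], "ms")
```

In this model the old order cannot begin the large weight transfer until the long kernel has finished, while the new order overlaps the transfer with the kernel and leaves only the short cast steps waiting.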
