Skip to content

Commit b27d3d8

Browse files
committed
ops: Implement prefetching API
Implement an API that allows instrumenting a model with a prefetch queue. Units of work are on the nn.Module level.
1 parent 7436674 commit b27d3d8

File tree

1 file changed

+106
-24
lines changed

1 file changed

+106
-24
lines changed

comfy/ops.py

Lines changed: 106 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7171
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
7272

7373

74+
def cast_prefetch_all(module, device):
    """Queue copies of a module tree's offloaded weights/biases onto ``device``.

    Every submodule yielded by ``module.named_modules()`` that has a
    ``comfy_cast_weights`` attribute is inspected. For its ``weight`` and
    ``bias``, a copy is requested through ``comfy.model_management.cast_to``
    (``non_blocking=True``, ``copy=True``, on the offload stream) and stored
    on the submodule as ``weight_prefetch`` / ``bias_prefetch``. A tensor is
    skipped when it is ``None``, already on ``device``, or already has a
    prefetch attribute.

    Returns:
        The offload stream the copies were issued on, or ``None`` when the
        device does not support non-blocking transfers, when no offload
        stream could be obtained, or when nothing needed to be copied.
    """
    if not comfy.model_management.device_supports_non_blocking(device):
        # Prefetching works against you when the CPU cannot run ahead of
        # the copy, so do nothing on such devices.
        return None

    stream = None

    for _, submodule in module.named_modules():
        if not hasattr(submodule, "comfy_cast_weights"):
            continue

        for attr in ("weight", "bias"):
            tensor = getattr(submodule, attr)
            slot = attr + "_prefetch"
            if tensor is not None and tensor.device != device and not hasattr(submodule, slot):
                # The offload stream is acquired lazily, on the first tensor
                # that actually needs moving.
                if stream is None:
                    stream = comfy.model_management.get_offload_stream(device)
                    if stream is None:
                        return None
                prefetched = comfy.model_management.cast_to(
                    tensor, None, device, non_blocking=True, copy=True, stream=stream
                )
                setattr(submodule, slot, prefetched)

    return stream
97+
98+
99+
def uncast_prefetch_all(module):
    """Drop the prefetched tensors stored on ``module`` and its submodules.

    For every submodule yielded by ``module.named_modules()`` that has a
    ``comfy_cast_weights`` attribute, the ``weight_prefetch`` and
    ``bias_prefetch`` attributes are deleted when present. Submodules without
    ``comfy_cast_weights`` are left untouched.
    """
    for _, submodule in module.named_modules():
        if not hasattr(submodule, "comfy_cast_weights"):
            continue
        for slot in ("weight_prefetch", "bias_prefetch"):
            if hasattr(submodule, slot):
                delattr(submodule, slot)
106+
107+
108+
def prefetch_queue_pop(queue, device, module):
    """Advance a prefetch queue by one unit of work.

    ``queue`` is a list in the layout produced by ``make_prefetch_queue``:
    entries are ``None`` (padding), a module that has not been prefetched
    yet, or an ``(offload_stream, module)`` tuple for a module whose prefetch
    has been issued. The list is modified in place. Each call does three
    things, in this order:

    1. Pops the head entry (the unit that was previously active). If it is a
       tuple, the offload stream is made to wait on the current compute
       stream and the module's prefetched tensors are released.
    2. Checks that the new head, if any, refers to ``module`` and syncs the
       compute side with that entry's offload stream.
    3. Starts prefetching the entry after the head and replaces it with an
       ``(offload_stream, module)`` tuple.

    Args:
        queue: Prefetch queue list; must keep at least two entries after the
            pop (``make_prefetch_queue`` pads for this).
        device: Device the weights are being prefetched to.
        module: The module about to run; compared against the new head.

    Raises:
        AssertionError: If the new head entry does not refer to ``module``.
        IndexError: If ``queue`` is too short.
    """
    consumed = queue.pop(0)
    if consumed is not None:
        offload_stream, finished = consumed
        # Sync the offload stream with compute so that when it starts
        # freeing the prefetches the compute stream has finished with them.
        if offload_stream is not None:
            offload_stream.wait_stream(comfy.model_management.current_stream(device))
        uncast_prefetch_all(finished)

    active = queue[0]
    if active is not None:
        offload_stream, expected = active
        assert expected == module, "prefetch queue is out of step with the module being executed"
        # Wait for the prefetch to complete before using the data.
        if offload_stream is not None:
            comfy.model_management.sync_stream(device, offload_stream)

    upcoming = queue[1]
    if upcoming is not None:
        # Call the same-file sibling directly, like uncast_prefetch_all
        # above, rather than through the ``comfy.ops`` package attribute.
        offload_stream = cast_prefetch_all(upcoming, device)
        queue[1] = (offload_stream, upcoming)
130+
131+
132+
def make_prefetch_queue(queue):
    """Return a new list holding ``queue``'s items padded with ``None``.

    Two ``None`` entries are placed before and after the items. The
    ``prefetch_queue_pop`` function pops one entry and then indexes
    ``queue[0]`` and ``queue[1]``, so the padding keeps those reads in
    bounds at both ends of the sequence.

    Args:
        queue: Any iterable of work units (list, tuple, generator, ...).
            It is iterated once and never modified.

    Returns:
        A fresh ``list``: ``[None, None, *items, None, None]``.
    """
    # Unpacking accepts any iterable, whereas ``list + queue`` raises
    # TypeError for everything that is not a list.
    return [None, None, *queue, None, None]
134+
135+
136+
def move_bias_weight(s, device, offloadable=False):
    """Cast a layer's weight and bias to ``device`` and sync with the offload stream.

    An offload stream is requested only when ``offloadable`` is true and at
    least one of these holds: the weight is not on ``device``, a non-``None``
    bias is not on ``device``, or the layer has weight/bias functions
    attached. The tensors are passed to ``comfy.model_management.cast_to``
    with ``dtype=None``; ``copy`` is enabled for a tensor exactly when it has
    functions attached. ``comfy.model_management.sync_stream`` is called with
    the stream before returning.

    Args:
        s: Layer exposing ``weight``, ``bias``, ``weight_function`` and
            ``bias_function``.
        device: Target device.
        offloadable: Whether an offload stream may be used.

    Returns:
        ``(weight, bias, offload_stream)``; ``bias`` is ``None`` when the
        layer has no bias, and ``offload_stream`` is ``None`` when no stream
        was requested.
    """
    has_bias_fn = len(s.bias_function) > 0
    has_weight_fn = len(s.weight_function) > 0

    needs_stream = offloadable and (
        s.weight.device != device
        or (s.bias is not None and s.bias.device != device)
        or has_bias_fn
        or has_weight_fn
    )
    stream = comfy.model_management.get_offload_stream(device) if needs_stream else None

    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    weight = comfy.model_management.cast_to(
        s.weight, None, device, non_blocking=non_blocking, copy=has_weight_fn, stream=stream
    )
    bias = (
        None
        if s.bias is None
        else comfy.model_management.cast_to(
            s.bias, None, device, non_blocking=non_blocking, copy=has_bias_fn, stream=stream
        )
    )

    comfy.model_management.sync_stream(device, stream)

    return weight, bias, stream
159+
160+
74161
@torch.compiler.disable()
75162
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
76163
# NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
@@ -84,40 +171,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
84171
if device is None:
85172
device = input.device
86173

87-
if offloadable and (device != s.weight.device or
88-
(s.bias is not None and device != s.bias.device)):
89-
offload_stream = comfy.model_management.get_offload_stream(device)
90-
else:
174+
bias_has_function = len(s.bias_function) > 0
175+
weight_has_function = len(s.weight_function) > 0
176+
177+
if hasattr(s, "weight_prefetch") or hasattr(s, "bias_prefetch"):
178+
weight = getattr(s, "weight_prefetch", None)
179+
bias = getattr(s, "bias_prefetch", None)
91180
offload_stream = None
181+
else:
182+
weight, bias, offload_stream = move_bias_weight(s, device, offloadable=offloadable)
92183

93184
if offload_stream is not None:
94185
wf_context = offload_stream
95186
else:
96187
wf_context = contextlib.nullcontext()
97188

98-
non_blocking = comfy.model_management.device_supports_non_blocking(device)
99-
100-
weight_has_function = len(s.weight_function) > 0
101-
bias_has_function = len(s.bias_function) > 0
102-
103-
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
104-
105-
bias = None
106-
if s.bias is not None:
107-
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
189+
if weight_has_function:
190+
weight=weight.to(dtype=dtype)
191+
for f in s.weight_function:
192+
weight = f(weight)
108193

109-
if bias_has_function:
110-
with wf_context:
111-
for f in s.bias_function:
112-
bias = f(bias)
194+
if s.bias is not None and bias_has_function:
195+
bias=bias.to(dtype=bias_dtype)
196+
for f in s.bias_function:
197+
bias = f(bias)
113198

114-
weight = weight.to(dtype=dtype)
115-
if weight_has_function:
116-
with wf_context:
117-
for f in s.weight_function:
118-
weight = f(weight)
199+
weight=weight.to(dtype=dtype)
200+
if bias is not None:
201+
bias=bias.to(dtype=bias_dtype)
119202

120-
comfy.model_management.sync_stream(device, offload_stream)
121203
if offloadable:
122204
return weight, bias, offload_stream
123205
else:

0 commit comments

Comments
 (0)