@@ -71,6 +71,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7171 return comfy .model_management .cast_to (weight , input .dtype , input .device , non_blocking = non_blocking , copy = copy )
7272
7373
def cast_prefetch_all(module, device):
    """Start asynchronous copies of a module tree's weights and biases to ``device``.

    Every submodule yielded by ``module.named_modules()`` that carries a
    ``comfy_cast_weights`` attribute is inspected. For each of its ``weight``
    and ``bias`` tensors that is not ``None``, does not already live on
    ``device`` and has no pending prefetch, a copy is requested through
    ``comfy.model_management.cast_to`` (``non_blocking=True``, ``copy=True``)
    on the offload stream, and the result is stored on the submodule as
    ``weight_prefetch`` / ``bias_prefetch``.

    Args:
        module: Root object exposing ``named_modules()``.
        device: Target device for the prefetched tensors.

    Returns:
        The offload stream the copies were issued on, or ``None`` when the
        device does not support non-blocking transfers, when no offload
        stream could be obtained, or when nothing needed prefetching.
    """
    if not comfy.model_management.device_supports_non_blocking(device):
        # Prefetching is counter-productive when the CPU cannot run ahead of
        # the transfer, so bail out early.
        return None

    stream = None

    for _, submodule in module.named_modules():
        if not hasattr(submodule, "comfy_cast_weights"):
            continue

        # Weight first, then bias -- same order as attribute access matters
        # for which tensor triggers the lazy stream acquisition.
        for tensor_name in ("weight", "bias"):
            tensor = getattr(submodule, tensor_name)
            if tensor is None:
                continue

            prefetch_name = tensor_name + "_prefetch"
            if tensor.device != device and not hasattr(submodule, prefetch_name):
                if stream is None:
                    # The stream is only requested once a transfer is
                    # actually needed; without one there is nothing to do.
                    stream = comfy.model_management.get_offload_stream(device)
                    if stream is None:
                        return None
                prefetched = comfy.model_management.cast_to(
                    tensor, None, device, non_blocking=True, copy=True, stream=stream
                )
                setattr(submodule, prefetch_name, prefetched)

    return stream
97+
98+
def uncast_prefetch_all(module):
    """Drop any prefetched weight/bias copies attached to a module tree.

    Walks ``module.named_modules()`` and, for every submodule that has a
    ``comfy_cast_weights`` attribute (whatever its value), deletes the
    ``weight_prefetch`` and ``bias_prefetch`` attributes when present.
    Submodules without ``comfy_cast_weights`` are left untouched.

    Args:
        module: Root object exposing ``named_modules()``.

    Returns:
        None. The submodules are modified in place.
    """
    for _, submodule in module.named_modules():
        if not hasattr(submodule, "comfy_cast_weights"):
            continue
        for prefetch_name in ("weight_prefetch", "bias_prefetch"):
            if hasattr(submodule, prefetch_name):
                delattr(submodule, prefetch_name)
106+
107+
def prefetch_queue_pop(queue, device, module):
    """Advance the prefetch pipeline by one slot.

    The queue holds, per slot, either ``None``, a module that has not been
    prefetched yet, or an ``(offload_stream, module)`` tuple for a module
    whose prefetch has been started. One call performs three steps:

    1. Pops slot 0 (the module that was just used) and removes its
       prefetched tensors.
    2. Checks that the new slot 0 is ``module`` and waits for its prefetch.
    3. Starts prefetching the module in slot 1 and stores the resulting
       ``(offload_stream, module)`` tuple back into that slot.

    Args:
        queue: List of slots as described above; it is mutated in place and
            must still contain at least two entries after the pop.
        device: Compute device the weights are being moved to.
        module: The module about to run; must match the new slot 0 when that
            slot is not ``None``.

    Raises:
        AssertionError: If the active slot refers to a different module.
    """
    retired = queue.pop(0)
    if retired is not None:
        retired_stream, retired_module = retired
        # Make the offload stream wait on the compute stream so that the
        # prefetched tensors are not released while compute still uses them.
        if retired_stream is not None:
            retired_stream.wait_stream(comfy.model_management.current_stream(device))
        uncast_prefetch_all(retired_module)

    current = queue[0]
    if current is not None:
        current_stream, current_module = current
        assert current_module == module
        # The prefetch has to be complete before its data is consumed.
        if current_stream is not None:
            comfy.model_management.sync_stream(device, current_stream)

    upcoming = queue[1]
    if upcoming is not None:
        # Kick off the next module's transfer and remember which stream
        # (possibly None) it was issued on.
        upcoming_stream = comfy.ops.cast_prefetch_all(upcoming, device)
        queue[1] = (upcoming_stream, upcoming)
130+
131+
def make_prefetch_queue(queue):
    """Build a padded prefetch queue from a sequence of modules.

    Two ``None`` slots are added at each end. ``prefetch_queue_pop`` pops
    slot 0 and then reads slots 0 and 1 on every call, so the padding lets
    it prime the first module and drain the last one without bounds checks.

    Args:
        queue: Any iterable of modules (list, tuple, generator, ...), in
            execution order. It is not modified.

    Returns:
        A new list ``[None, None, *queue, None, None]``.
    """
    # Unpacking instead of list concatenation so that non-list iterables are
    # accepted too; for list inputs the result is identical.
    return [None, None, *queue, None, None]
134+
135+
def move_bias_weight(s, device, offloadable=False):
    """Move a layer's weight and bias to ``device`` without changing dtype.

    An offload stream is requested only when ``offloadable`` is true and
    there is work for it: the weight or bias lives on another device, or
    the layer has entries in ``weight_function`` / ``bias_function``.
    Both tensors go through ``comfy.model_management.cast_to`` with a
    ``None`` dtype; a copy is forced for a tensor whose function list is
    non-empty (presumably so those functions never touch the stored
    parameter -- confirm against ``cast_to``).

    Args:
        s: Layer exposing ``weight``, ``bias``, ``weight_function`` and
            ``bias_function``.
        device: Target device.
        offloadable: Whether an offload stream may be used for the transfer.

    Returns:
        Tuple ``(weight, bias, offload_stream)``; ``bias`` is ``None`` when
        the layer has no bias and ``offload_stream`` is ``None`` when no
        stream was used.
    """
    bias_has_function = len(s.bias_function) > 0
    weight_has_function = len(s.weight_function) > 0

    offload_stream = None
    if offloadable:
        on_other_device = s.weight.device != device or (
            s.bias is not None and s.bias.device != device
        )
        if on_other_device or bias_has_function or weight_has_function:
            offload_stream = comfy.model_management.get_offload_stream(device)

    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    weight = comfy.model_management.cast_to(
        s.weight,
        None,
        device,
        non_blocking=non_blocking,
        copy=weight_has_function,
        stream=offload_stream,
    )

    if s.bias is None:
        bias = None
    else:
        bias = comfy.model_management.cast_to(
            s.bias,
            None,
            device,
            non_blocking=non_blocking,
            copy=bias_has_function,
            stream=offload_stream,
        )

    # Called unconditionally, even when offload_stream is None, so the
    # returned tensors are synchronised with the transfer before use.
    comfy.model_management.sync_stream(device, offload_stream)

    return weight, bias, offload_stream
159+
160+
74161@torch .compiler .disable ()
75162def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None , offloadable = False ):
76163 # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
@@ -84,40 +171,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
84171 if device is None :
85172 device = input .device
86173
87- if offloadable and (device != s .weight .device or
88- (s .bias is not None and device != s .bias .device )):
89- offload_stream = comfy .model_management .get_offload_stream (device )
90- else :
174+ bias_has_function = len (s .bias_function ) > 0
175+ weight_has_function = len (s .weight_function ) > 0
176+
177+ if hasattr (s , "weight_prefetch" ) or hasattr (s , "bias_prefetch" ):
178+ weight = getattr (s , "weight_prefetch" , None )
179+ bias = getattr (s , "bias_prefetch" , None )
91180 offload_stream = None
181+ else :
182+ weight , bias , offload_stream = move_bias_weight (s , device , offloadable = offloadable )
92183
93184 if offload_stream is not None :
94185 wf_context = offload_stream
95186 else :
96187 wf_context = contextlib .nullcontext ()
97188
98- non_blocking = comfy .model_management .device_supports_non_blocking (device )
99-
100- weight_has_function = len (s .weight_function ) > 0
101- bias_has_function = len (s .bias_function ) > 0
102-
103- weight = comfy .model_management .cast_to (s .weight , None , device , non_blocking = non_blocking , copy = weight_has_function , stream = offload_stream )
104-
105- bias = None
106- if s .bias is not None :
107- bias = comfy .model_management .cast_to (s .bias , bias_dtype , device , non_blocking = non_blocking , copy = bias_has_function , stream = offload_stream )
189+ if weight_has_function :
190+ weight = weight .to (dtype = dtype )
191+ for f in s .weight_function :
192+ weight = f (weight )
108193
109- if bias_has_function :
110- with wf_context :
111- for f in s .bias_function :
112- bias = f (bias )
194+ if s . bias is not None and bias_has_function :
195+ bias = bias . to ( dtype = bias_dtype )
196+ for f in s .bias_function :
197+ bias = f (bias )
113198
114- weight = weight .to (dtype = dtype )
115- if weight_has_function :
116- with wf_context :
117- for f in s .weight_function :
118- weight = f (weight )
199+ weight = weight .to (dtype = dtype )
200+ if bias is not None :
201+ bias = bias .to (dtype = bias_dtype )
119202
120- comfy .model_management .sync_stream (device , offload_stream )
121203 if offloadable :
122204 return weight , bias , offload_stream
123205 else :
0 commit comments