Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ async def move_userdata(request):
return source

dest = get_user_data_path(request, check_exists=False, param="dest")
- if not isinstance(source, str):
+ if not isinstance(dest, str):
return dest

overwrite = request.query.get("overwrite", 'true') != "false"
Expand Down
22 changes: 18 additions & 4 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,9 @@ def get_offload_stream(device):
if NUM_STREAMS == 0:
return None

if torch.compiler.is_compiling():
return None

if device in STREAMS:
ss = STREAMS[device]
#Sync the oldest stream in the queue with the current
Expand All @@ -1052,15 +1055,19 @@ def get_offload_stream(device):
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
- ss.append(torch.cuda.Stream(device=device, priority=0))
+ s1 = torch.cuda.Stream(device=device, priority=0)
+ s1.as_context = torch.cuda.stream
+ ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
- ss.append(torch.xpu.Stream(device=device, priority=0))
+ s1 = torch.xpu.Stream(device=device, priority=0)
+ s1.as_context = torch.xpu.stream
+ ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
Expand All @@ -1078,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)


if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
Expand Down
2 changes: 2 additions & 0 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()

Expand Down
4 changes: 2 additions & 2 deletions comfy/quant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def data_ptr(self):
def is_pinned(self):
return self._qdata.is_pinned()

- def is_contiguous(self):
- return self._qdata.is_contiguous()
+ def is_contiguous(self, *arg, **kwargs):
+ return self._qdata.is_contiguous(*arg, **kwargs)

# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
Expand Down
Loading