Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib

cast_to = comfy.model_management.cast_to #TODO: remove once no more references

Expand All @@ -38,20 +39,28 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
device = input.device

offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()

bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

if has_function:
for f in s.bias_function:
bias = f(bias)
with wf_context:
for f in s.bias_function:
bias = f(bias)

has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
for f in s.weight_function:
weight = f(weight)
with wf_context:
for f in s.weight_function:
weight = f(weight)

comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
Expand Down
17 changes: 17 additions & 0 deletions hook_breaker_ac10a0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Prevent custom nodes from hooking anything important
import comfy.model_management

HOOK_BREAK = [(comfy.model_management, "cast_to")]


SAVED_FUNCTIONS = []


def save_functions():
for f in HOOK_BREAK:
SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1])))


def restore_functions():
for f in SAVED_FUNCTIONS:
setattr(f[0], f[1], f[2])
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def execute_script(script_path):
import comfy.model_management
import comfyui_version
import app.logger

import hook_breaker_ac10a0

def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
Expand Down Expand Up @@ -215,6 +215,7 @@ def prompt_worker(q, server_instance):
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
need_gc = False
hook_breaker_ac10a0.restore_functions()


async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
Expand Down Expand Up @@ -268,7 +269,9 @@ def start_comfyui(asyncio_loop=None):
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)

hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
hook_breaker_ac10a0.restore_functions()

cuda_malloc_warning()

Expand Down
Loading