Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions comfy_extras/nodes_apg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch

def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal

class APG:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"

def patch(self, model, eta, norm_threshold, momentum):
running_avg = 0
prev_sigma = None

def pre_cfg_function(args):
nonlocal running_avg, prev_sigma

if len(args["conds_out"]) == 1: return args["conds_out"]

cond = args["conds_out"][0]
uncond = args["conds_out"][1]
sigma = args["sigma"][0]
cond_scale = args["cond_scale"]

if prev_sigma is not None and sigma > prev_sigma:
running_avg = 0
prev_sigma = sigma

guidance = cond - uncond

if momentum > 0:
if not torch.is_tensor(running_avg):
running_avg = guidance
else:
running_avg = momentum * running_avg + guidance
guidance = running_avg

if norm_threshold > 0:
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale = torch.minimum(
torch.ones_like(guidance_norm),
norm_threshold / guidance_norm
)
guidance = guidance * scale

guidance_parallel, guidance_orthogonal = project(guidance, cond)
modified_guidance = guidance_orthogonal + eta * guidance_parallel

modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale

return [modified_cond, uncond] + args["conds_out"][2:]

m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)

NODE_CLASS_MAPPINGS = {
"APG": APG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,7 @@ def init_builtin_extra_nodes():
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_fresca.py",
"nodes_apg.py",
"nodes_preview_any.py",
"nodes_ace.py",
"nodes_string.py",
Expand Down
8 changes: 8 additions & 0 deletions script_examples/basic_api_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@

def queue_prompt(prompt):
p = {"prompt": prompt}

# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data`` field of the payload.
# p["extra_data"] = {
# "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
# }
# See: https://docs.comfy.org/tutorials/api-nodes/overview
# Generate a key here: https://platform.comfy.org/login

data = json.dumps(p).encode('utf-8')
req = request.Request("http://127.0.0.1:8188/prompt", data=data)
request.urlopen(req)
Expand Down
Loading