From 8594a2ea8a2ef4bc039c1de1fde893b1c12eab0c Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sun, 19 Oct 2025 15:15:48 +0300 Subject: [PATCH] TAG: Tangential Amplifying Guidance https://huggingface.co/papers/2510.04533 --- comfy_extras/nodes_tag.py | 65 +++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 66 insertions(+) create mode 100644 comfy_extras/nodes_tag.py diff --git a/comfy_extras/nodes_tag.py b/comfy_extras/nodes_tag.py new file mode 100644 index 000000000000..1ddb10398daa --- /dev/null +++ b/comfy_extras/nodes_tag.py @@ -0,0 +1,65 @@ +# TAG: Tangential Amplifying Guidance - (arXiv: https://arxiv.org/pdf/2510.04533) + +from typing_extensions import override +import torch + +from comfy_api.latest import ComfyExtension, io + +class TAGGuidance(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TAG-Guidance", + display_name="Tangential Amplifying Guidance", + category="advanced/guidance", + description="TAG - Tangential Amplifying Guidance (2510.04533)\n\nLeverages an intermediate sample as a projection basis and amplifies the tangential components of the estimated scores with respect to this basis to correct the sampling trajectory. Improves diffusion sampling fidelity with minimal computational addition", + inputs=[ + io.Model.Input("model"), + io.Float.Input("t_guidance_scale", default=1.0, min=0.0, max=20.0, step=0.05), + io.Float.Input("r_guidance_scale", default=1.0, min=0.0, max=20.0, step=0.05), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + ) + + @classmethod + def execute(cls, model, t_guidance_scale, r_guidance_scale): + m = model.clone() + + def tag_guidance(args): + + post_latents = args['input'] + v_t_2d = post_latents / (post_latents.norm(p=2, dim=(1,2,3), keepdim=True) + 1e-8) + + latents = args['denoised'] + + delta_latents = latents - post_latents + delta_unit = (delta_latents * v_t_2d).sum(dim=(1,2,3), keepdim=True) + + normal_update_vector = delta_unit * v_t_2d + tangential_update_vector = delta_latents - normal_update_vector + + eta_v = t_guidance_scale + eta_n = r_guidance_scale + + latents = post_latents + \ + eta_v * tangential_update_vector + \ + eta_n * normal_update_vector + + return latents + + m.set_model_sampler_post_cfg_function(tag_guidance) + return io.NodeOutput(m) + + +class TagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TAGGuidance, + ] + + +async def comfy_entrypoint() -> TagExtension: + return TagExtension() diff --git a/nodes.py b/nodes.py index 7cfa8ca1411d..d8b9a5b3a99b 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes(): "nodes_string.py", "nodes_camera_trajectory.py", "nodes_edit_model.py", + "nodes_tag.py", "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py",