
@vivienfanghuagood

Overview

This PR adds support for the AITER attention backend to ComfyUI, providing an alternative high-performance attention implementation that can significantly improve inference speed on AMD GPUs (e.g. MI300X, MI355).

Usage

To enable AITER attention, start ComfyUI with the `--use-aiter-attention` flag:

python main.py --use-aiter-attention
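
Before launching, you can sanity-check that the `aiter` package imports on your system (a quick check of my own, not part of the PR; `flash_attn_func` is the entry point referenced later in this thread):

```python
# Quick import check for aiter before starting ComfyUI with the flag.
try:
    import aiter
    print("aiter available:", hasattr(aiter, "flash_attn_func"))
except ImportError as exc:
    print("aiter not installed:", exc)
```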

Performance Improvements

Tested on the Qwen-Image model (KSampler, the main time-consuming stage):
Before: 1.27 iter/s
After: 1.48 iter/s
Speedup: ~16.5% improvement

@vivienfanghuagood
Author

@Kosinkadink @comfyanonymous Hi, could you help review this PR? It will be broadly useful on AMD GPUs. Thanks very much!

@asagi4
Contributor

asagi4 commented Oct 30, 2025

I gave this a quick test and the integration doesn't seem to break anything, but unfortunately I wasn't able to get the aiter library working on my system, so I can't say how useful this is; its JIT compilation failed with errors that seemed to involve assembly code. (Maybe it requires a newer CPU than I have, or it doesn't support the 7900 XTX? I don't know.)

If this is merged, a bit of documentation about what it is and what kind of setups are supported wouldn't hurt.

EDIT: looks like consumer GPUs are just not supported, so it seems this is a pretty niche thing.

```python
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.aiter_attention_enabled():
        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
```
Owner


"pip install aiter" doesn't install the right aiter.

Author


Thanks, I have changed it. Since AITER doesn't provide a wheel package yet, we now tell users to refer to the AITER repo for installation instructions; hopefully users can install it smoothly!
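
For reference, a minimal sketch of what the revised import guard might look like; `model_management.aiter_attention_enabled()` and the flag name come from the diff above, while the repo URL (https://github.com/ROCm/aiter) is my assumption:

```python
import logging

from comfy import model_management  # ComfyUI's module, as used in the diff above

AITER_ATTENTION_IS_AVAILABLE = False
try:
    import aiter  # noqa: F401
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.aiter_attention_enabled():
        # No wheel is published yet, so point users at the AITER repository
        # for build instructions (repo URL is an assumption on my part).
        logging.error(
            "\n\nTo use the `--use-aiter-attention` feature, the `aiter` package "
            "must be installed first.\nPlease follow the build instructions at "
            "https://github.com/ROCm/aiter"
        )
```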

@comfyanonymous
Owner

Getting this when trying the neta yume workflow in the ComfyUI templates:

got prompt
  0%|                                                                                             | 0/30 [00:00<?, ?it/s]Aiter Attention failed, using default SDPA: flash_attn_func() got an unexpected keyword argument 'cu_seqlens_q'
Aiter Attention failed, using default SDPA: flash_attn_func() got an unexpected keyword argument 'cu_seqlens_q'
Aiter Attention failed, using default SDPA: flash_attn_func() got an unexpected keyword argument 'cu_seqlens_q'
  0%|                                                                                             | 0/30 [00:00<?, ?it/s]
!!! Exception during processing !!! The expanded size of the tensor (24) must match the existing size (4096) at non-singleton dimension 3.  Target sizes: [1, 4096, 24, 24].  Tensor sizes: [1, 1, 1, 4096]
Traceback (most recent call last):
  File "/workspace/ComfyUI/comfy/ldm/modules/attention.py", line 667, in attention_aiter
    out = aiter_flash_attn_wrapper(
  File "/workspace/ComfyUI/comfy/ldm/modules/attention.py", line 638, in aiter_flash_attn_wrapper
    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
TypeError: flash_attn_func() got an unexpected keyword argument 'cu_seqlens_q'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/ComfyUI/execution.py", line 510, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
  File "/workspace/ComfyUI/execution.py", line 324, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
  File "/workspace/ComfyUI/execution.py", line 298, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "/workspace/ComfyUI/execution.py", line 286, in process_inputs
    result = f(**inputs)
  File "/workspace/ComfyUI/nodes.py", line 1525, in sample
    return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
  File "/workspace/ComfyUI/nodes.py", line 1492, in common_ksampler
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
  File "/workspace/ComfyUI/comfy/sample.py", line 60, in sample
    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "/workspace/ComfyUI/comfy/samplers.py", line 1163, in sample
    return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "/workspace/ComfyUI/comfy/samplers.py", line 1053, in sample
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
  File "/workspace/ComfyUI/comfy/samplers.py", line 1035, in sample
    output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/samplers.py", line 997, in outer_sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
  File "/workspace/ComfyUI/comfy/samplers.py", line 980, in inner_sample
    samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/samplers.py", line 752, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
  File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/k_diffusion/sampling.py", line 1429, in sample_res_multistep
    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
  File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/k_diffusion/sampling.py", line 1387, in res_multistep
    denoised = model(x, sigmas[i] * s_in, **extra_args)
  File "/workspace/ComfyUI/comfy/samplers.py", line 401, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
  File "/workspace/ComfyUI/comfy/samplers.py", line 953, in __call__
    return self.outer_predict_noise(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/samplers.py", line 960, in outer_predict_noise
    ).execute(x, timestep, model_options, seed)
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/samplers.py", line 963, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
  File "/workspace/ComfyUI/comfy/samplers.py", line 381, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
  File "/workspace/ComfyUI/comfy/samplers.py", line 206, in calc_cond_batch
    return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
  File "/workspace/ComfyUI/comfy/samplers.py", line 214, in _calc_cond_batch_outer
    return executor.execute(model, conds, x_in, timestep, model_options)
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/samplers.py", line 326, in _calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
  File "/workspace/ComfyUI/comfy/model_base.py", line 161, in apply_model
    return comfy.patcher_extension.WrapperExecutor.new_class_executor(
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/model_base.py", line 203, in _apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/ldm/lumina/model.py", line 599, in forward
    return comfy.patcher_extension.WrapperExecutor.new_class_executor(
  File "/workspace/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/ldm/lumina/model.py", line 625, in _forward
    x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
  File "/workspace/ComfyUI/comfy/ldm/lumina/model.py", line 580, in patchify_and_embed
    padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/ldm/lumina/model.py", line 291, in forward
    self.attention(
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/ldm/lumina/model.py", line 144, in forward
    output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
  File "/workspace/ComfyUI/comfy/ldm/modules/attention.py", line 139, in wrapper
    return func(*args, **kwargs)
  File "/workspace/ComfyUI/comfy/ldm/modules/attention.py", line 685, in attention_aiter
    out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
RuntimeError: The expanded size of the tensor (24) must match the existing size (4096) at non-singleton dimension 3.  Target sizes: [1, 4096, 24, 24].  Tensor sizes: [1, 1, 1, 4096]

```python
    return out


def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
```
Owner


Why are you doing this? Just put `aiter.flash_attn_func` in the `attention_aiter` function directly.

Author


> Why are you doing this? Just put `aiter.flash_attn_func` in the `attention_aiter` function directly.

DONE
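
For context, a rough sketch of the simplified shape after this change, assuming `aiter.flash_attn_func` mirrors the flash-attn calling convention with `[batch, seq_len, heads, head_dim]` inputs and the `dropout_p`/`softmax_scale` keywords seen in the traceback above (the real `attention_aiter` in `attention.py` handles more cases):

```python
import torch
import aiter  # AMD's AITER library; server-class ROCm GPUs only

def attention_aiter(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    heads: int) -> torch.Tensor:
    # Sketch only: split [B, S, heads * head_dim] inputs into the
    # [B, S, heads, head_dim] layout flash-attn-style kernels expect,
    # call aiter.flash_attn_func directly (no intermediate wrapper),
    # then merge the heads back.
    b, s, dim = q.shape
    head_dim = dim // heads
    q, k, v = (t.view(b, -1, heads, head_dim) for t in (q, k, v))
    out = aiter.flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None)
    return out.reshape(b, -1, heads * head_dim)
```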

@vivienfanghuagood
Author

> I gave this a quick test and the integration doesn't seem to break anything, but unfortunately I wasn't able to get the aiter library working on my system, so I can't say how useful this is; its JIT compilation failed with errors that seemed to involve assembly code. (Maybe it requires a newer CPU than I have, or it doesn't support the 7900 XTX? I don't know.)
>
> If this is merged, a bit of documentation about what it is and what kind of setups are supported wouldn't hurt.
>
> EDIT: looks like consumer GPUs are just not supported, so it seems this is a pretty niche thing.

Yes, AITER mainly targets server GPUs like the MI300X/MI355, but I think consumer GPUs will be supported in the near future. I will clarify these requirements in the code comments. Thanks for trying it!

@vivienfanghuagood
Author

> Getting this when trying the neta yume workflow in the ComfyUI templates:
>
> `Aiter Attention failed, using default SDPA: flash_attn_func() got an unexpected keyword argument 'cu_seqlens_q'`
> `RuntimeError: The expanded size of the tensor (24) must match the existing size (4096) at non-singleton dimension 3.  Target sizes: [1, 4096, 24, 24].  Tensor sizes: [1, 1, 1, 4096]`
>
> (full traceback above)

Thanks for pointing out the corner case where the attention mask is not None. I have fixed it:
[screenshot of the fix]
If you have more test cases, I will try them!
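
For readers following along, the RuntimeError above is consistent with the SDPA fallback receiving tensors in flash-attention layout `[B, S, H, D]`: SDPA then treats the sequence axis as the head axis, and a broadcastable `[B, 1, 1, S_k]` mask no longer lines up with the attention logits. A minimal sketch of a fallback that keeps the layouts straight (my reading of the fix, not the exact patch):

```python
import torch
import torch.nn.functional as F

def sdpa_fallback(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  mask: torch.Tensor | None = None) -> torch.Tensor:
    # q/k/v arrive in flash-attn layout [B, S, H, D]; SDPA expects
    # [B, H, S, D], where a boolean mask of shape [B, 1, 1, S_k]
    # broadcasts cleanly against logits of shape [B, H, S_q, S_k].
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask,
                                         dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2)  # back to [B, S, H, D]
```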
