
fix: attach CP attention-mask hooks for dense (non-TE) context parallelism #1470

Open

hemildesai wants to merge 10 commits into main from hemil/cp-dense-fixes

Conversation

@hemildesai
Contributor

Summary

  • Add _attach_context_parallel_hooks to register forward pre-hooks on self_attn modules that strip attention_mask and set is_causal=True, fixing shape mismatches when dense (non-TE) context parallelism shards Q/K/V as DTensors
  • Call the hooks in TrainFinetuneRecipeForNextTokenPrediction when cp_size > 1 and TE attention is not used
  • Add unit tests for the new hook function and the attention_mask removal in make_cp_batch_and_ctx
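The hook mechanism described in the summary can be sketched as follows. This is a minimal illustration, not the repository's exact code: the helper name `attach_context_parallel_hooks` and the `self_attn` module naming come from the PR description, while the hook body here is an assumption about how "strip `attention_mask` and set `is_causal=True`" would be implemented with PyTorch's kwargs-aware pre-hooks.

```python
import torch.nn as nn

def attach_context_parallel_hooks(model: nn.Module) -> None:
    """Register forward pre-hooks that drop attention_mask and force is_causal=True."""
    def _strip_mask_pre_hook(module, args, kwargs):
        # The 4D attention_mask has the wrong shape once dense CP shards
        # Q/K/V as DTensors, so drop it and let SDPA mask causally itself.
        kwargs.pop("attention_mask", None)
        kwargs["is_causal"] = True
        return args, kwargs

    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            # with_kwargs=True lets the hook rewrite keyword arguments (torch >= 2.0).
            module.register_forward_pre_hook(_strip_mask_pre_hook, with_kwargs=True)
```

With this in place, any `self_attn` forward call sees `attention_mask=None` and `is_causal=True` regardless of what the caller passed.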

Test plan

  • Unit tests pass: pytest tests/unit_tests/distributed/test_cp_utils.py (12 tests, all passing)
  • Manual validation with dense CP training (non-TE backend)

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot bot commented Mar 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai
Contributor Author

/ok to test e3fb07e

@hemildesai
Contributor Author

/ok to test 5165660

Collaborator

can we move the changes in this file into model_init.py

hemildesai and others added 5 commits March 6, 2026 09:27
…elism

Strip the 4D attention_mask from the batch and register forward pre-hooks
on self_attn modules to set is_causal=True, so that SDPA handles causal
masking internally when using dense context parallelism without TE.

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Replace functools.partial(F.scaled_dot_product_attention, ...) with a
closure that resolves F.scaled_dot_product_attention at call time. This
ensures CP's runtime monkey-patch of the function is picked up by all
custom models instead of being bypassed by the early-bound reference.

Also make _attach_context_parallel_hooks public (renamed to
attach_context_parallel_hooks).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
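The early-vs-late binding issue this commit fixes can be shown in a few lines. `functools.partial` captures the function object at definition time, so a later monkey-patch of `F.scaled_dot_product_attention` (as CP does at runtime) is silently bypassed; a closure that looks the attribute up on each call picks the patch up. The `attn_func` below is a stand-in illustrating the pattern, not the repository's exact function.

```python
import functools
import torch.nn.functional as F

# Early-bound: the partial stores the current function object immediately,
# so replacing F.scaled_dot_product_attention later has no effect on it.
attn_func_partial = functools.partial(F.scaled_dot_product_attention, is_causal=True)

# Late-bound: the attribute is resolved on every call, so a runtime
# monkey-patch (F.scaled_dot_product_attention = patched) is honored.
def attn_func(query, key, value, **kwargs):
    kwargs.setdefault("is_causal", True)
    return F.scaled_dot_product_attention(query, key, value, **kwargs)
```

This is why the commit replaces the `partial` with a closure: CP's patched SDPA is only reachable through a call-time lookup.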
…ends

Extract SDPA backend selection into a resolve_sdpa_method() helper that
accepts string names from YAML config (e.g. ["flash_attention",
"efficient_attention"]) and converts them to SDPBackend enum members.
When no explicit config is provided, auto-selects based on CP and
activation checkpointing constraints.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Replace the assert that required all attention modules to be TE
DotProductAttention with a continue, so dense (SDPA) attention
modules are gracefully skipped. This allows MoE models to use
context parallelism with non-TE attention backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
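The assert-to-continue change can be sketched like this. The loop shape, helper names, and `attn_module` attribute are illustrative assumptions drawn from the commit messages; the point is that non-TE attention modules are now skipped instead of tripping an assertion.

```python
def is_te_dot_product_attention(m) -> bool:
    # Hypothetical check; the repository would test against
    # transformer_engine's DotProductAttention class directly.
    return type(m).__name__ == "DotProductAttention"

def apply_te_cp_settings(model, cp_group):
    for module in model.modules():
        attn = getattr(module, "attn_module", None)
        if attn is None:
            continue
        if not is_te_dot_product_attention(attn):
            # Previously: assert isinstance(attn, DotProductAttention).
            # Dense SDPA modules are now skipped gracefully, so MoE models
            # can mix CP with non-TE attention backends.
            continue
        attn.set_context_parallel_group(cp_group)  # TE-specific CP wiring
```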
@hemildesai force-pushed the hemil/cp-dense-fixes branch from 59deabe to 1d2838b on March 6, 2026 at 17:28
@hemildesai
Contributor Author

/ok to test 1d2838b

Move the resolve_sdpa_method helper from train_ft.py to
_transformers/model_init.py per review feedback. The config
resolution (reading sdpa_method from YAML and passing it to
build_model) remains in train_ft.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test 274666d

@akoumpa
Contributor

akoumpa commented Mar 7, 2026

/ok to test 9098ba2

hemildesai and others added 2 commits March 6, 2026 16:38
Move the attach_context_parallel_hooks call from train_ft.py into
apply_model_infrastructure in infrastructure.py, which already has
access to the device mesh. Add _uses_te_attention helper that inspects
the model's self_attn.attn_module instances to determine if TE
DotProductAttention is used, replacing the config-based check.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
When the model is an AutoPipeline, iterate over model.parts to inspect
self_attn modules instead of only the pipeline wrapper itself.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test 32e55a4

The SDPA attn_func changed from functools.partial to a closure,
so .keywords no longer exists. Mock F.scaled_dot_product_attention
and inspect call kwargs instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
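The test-fix approach described above can be illustrated with standard `unittest.mock` usage: since the closure no longer exposes `.keywords`, the test patches `F.scaled_dot_product_attention` and asserts on the kwargs of the recorded call. The `attn_func` below is a stand-in for the model's closure, not the repository's code.

```python
from unittest import mock
import torch.nn.functional as F

def attn_func(q, k, v, **kwargs):
    # Stand-in for the model's closure: resolves SDPA at call time.
    kwargs.setdefault("is_causal", True)
    return F.scaled_dot_product_attention(q, k, v, **kwargs)

def test_attn_func_sets_is_causal():
    with mock.patch.object(F, "scaled_dot_product_attention") as sdpa:
        attn_func("q", "k", "v")
        _, call_kwargs = sdpa.call_args  # (args, kwargs) of the recorded call
        assert call_kwargs["is_causal"] is True
```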
3 participants