[RFC] Refactor attention and make attention mask an argument to the model
**Status**
The PR is not landable yet but serves as an RFC. If people are okay with
this design, this PR requires the following changes and verifications:
1. Change all models, including the experimental ones.
2. E2E loss verification (a functional check has been done, but loss verification is not done yet).
3. We should add a unit test for attention. But since we don't have GPU unit tests, this can be done in a separate PR.
**Summary**
This PR aims to refactor how TorchTitan builds the attention masks and passes them to the model. Before this PR, init_attention_masks() is called in the Trainer, but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks.
The previous design has several issues; one in particular is #1723.
Now that pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument to model.forward().
The new design:
1. Each model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, this API should return None, as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of ints that represents the mask. (A sketch of this API and the trainer-side call appears after this list.)
Justification: attention logic is technically part of the model, but it requires some information from the trainer/dataloader. So it's the model author's responsibility to provide an API that the trainer can call to get the masks.
2. `get_attention_masks()` will be called from the trainer, and the resulting masks are passed to model.forward().
Justification: this will allow us to fix #1723 with pytorch/pytorch#164111 and this PR.
3. Provide a single AttentionOp instead of two.
Justification: since the masking logic is moved outside, we don't need to do bookkeeping of masks in FlexAttentionWrapper. The remaining logic is simple enough that a single AttentionOp makes things cleaner. (A sketch of such an op appears below the note.)
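
To make the proposed contract concrete, here is a minimal sketch of items 1 and 2. It assumes `create_mask_fn` is compatible with `torch.nn.attention.flex_attention.create_block_mask` and uses illustrative names (`TransformerSketch`, `attn_type`, `doc_id`) rather than the actual TorchTitan symbols; the real signatures may differ.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask


class TransformerSketch(torch.nn.Module):
    """Illustrative model-side implementation of the proposed API."""

    attn_type = "flex"  # hypothetical switch between flex and sdpa

    def get_attention_masks(self, create_mask_fn, batch, eos_id):
        # SDPA path: varlen isn't supported yet, so no mask is returned.
        if self.attn_type == "sdpa":
            return None

        B, S = batch.shape
        # Document ids increase after each EOS token so that packed
        # documents cannot attend across their boundaries.
        doc_id = (batch == eos_id).cumsum(dim=-1)

        def mask_mod(b, h, q_idx, kv_idx):
            causal = q_idx >= kv_idx
            same_doc = doc_id[b, q_idx] == doc_id[b, kv_idx]
            return causal & same_doc

        # Returns a BlockMask that model.forward() hands to flex_attention.
        return create_mask_fn(mask_mod, B, None, S, S, device=batch.device)


# Trainer side (item 2): build the masks once per batch and pass them to
# forward() instead of stashing them on FlexAttentionWrapper.
model = TransformerSketch()
batch = torch.randint(1, 100, (2, 256))  # toy token ids
masks = model.get_attention_masks(create_block_mask, batch, eos_id=0)
# logits = model(batch, attention_masks=masks)  # forward() not sketched here
```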
Note: we still have two very thin op wrappers that are used for CP. I keep these two for CP education purposes, but this can certainly be confusing for Titan's users. I'm open to merging them into AttentionOp.
See the discussion in #1723.
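
As a rough illustration of item 3 (not the actual TorchTitan implementation), the single op could be a thin dispatch on whether a mask was provided:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention


class AttentionOp(torch.nn.Module):
    """Sketch only: one wrapper instead of separate SDPA/flex wrappers."""

    def forward(self, q, k, v, attention_mask=None):
        if attention_mask is None:
            # SDPA path: no varlen mask yet, plain causal attention.
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # FlexAttention path: consume the BlockMask from get_attention_masks().
        return flex_attention(q, k, v, block_mask=attention_mask)
```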
ghstack-source-id: 35aa425
Pull-Request-resolved: #1776