ggml : fix FA mask dim 2 and 3 #14505
Conversation
Edit: nvm, I think my approach was not correct.
Force-pushed from 2a20a7e to 89ee2f1
@JohannesGaessler @jeffbolznv For now I will disable the broadcast of the mask. To summarize, the correct broadcast logic that we have to support is as follows:

Lines 1983 to 2003 in 89ee2f1
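The referenced ggml.h snippet is not reproduced above, so here is a rough sketch of the kind of shape and broadcast constraints being summarized, using the names from this thread (ne32/ne33 for the mask's dims 2 and 3). The exact rules, dimension names, and the helper below are assumptions for illustration, not the actual file contents.

```c
#include <stdbool.h>
#include <stdint.h>

// Hedged sketch of the broadcast constraints summarized above; NOT the
// verbatim ggml.h comment. Assumed layout (ggml order, ne[0] first):
//
//   q:    [n_embd, n_batch, n_head,    ne3 ]
//   k:    [n_embd, n_kv,    n_head_kv, ne3 ]
//   v:    [n_embd, n_kv,    n_head_kv, ne3 ]
//   mask: [n_kv,   n_batch, ne32,      ne33]
static bool fa_mask_broadcast_ok(
        int64_t n_head, int64_t n_head_kv, int64_t ne3, int64_t ne32, int64_t ne33) {
    return n_head % n_head_kv == 0 && // GQA: several q heads share one k/v head
           n_head % ne32      == 0 && // mask dim 2 broadcasts across the q heads
           ne3    % ne33      == 0;   // mask dim 3 broadcasts across dim 3 (sequences)
}
```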
In practice, the ne32 will probably always be equal to 1. … This way, the mask that we pass to … Merging this for now so I can continue working on top of this, and later on we'll hopefully add support for these cases.
Force-pushed from a65fa3a to b1b22ae
If we use dimension 3 to identify the sequence that a mask belongs to, do we still need broadcasting? Intuitively I would have thought that there would be a 1:1 mapping between sequences and attention masks. But in that case, wouldn't it be simpler to just run one instance of …?
It is correct that …

Technically, when we reach the attention, we could make views of the batch and loop over each sequence, but this has a few disadvantages: …

It also has some advantages, but overall I think the single-shot version is better. The multi-shot strategy that you mention is currently a backup plan that I think we should consider only if the single-shot version does not work for some reason.
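For illustration, a rough sketch of the multi-shot alternative mentioned above: one attention call per sequence, each operating on per-sequence views of q/k/v and the mask. This is not what the PR implements; the helper name and the assumption that dim 3 holds one sequence per slice are illustrative.

```c
#include "ggml.h"

// Hedged sketch of the "multi-shot" backup plan: run ggml_flash_attn_ext once
// per sequence on views that select a single slice of dim 3. Illustrative only.
static void flash_attn_per_sequence(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * mask,
        float scale, float max_bias,
        struct ggml_tensor ** out) { // out[s]: result for sequence s
    const int64_t n_seq = q->ne[3];
    for (int64_t s = 0; s < n_seq; ++s) {
        // same dims 0..2, a single slice of dim 3 at offset s
        struct ggml_tensor * q_s = ggml_view_4d(ctx, q,
                q->ne[0], q->ne[1], q->ne[2], 1,
                q->nb[1], q->nb[2], q->nb[3], s*q->nb[3]);
        struct ggml_tensor * k_s = ggml_view_4d(ctx, k,
                k->ne[0], k->ne[1], k->ne[2], 1,
                k->nb[1], k->nb[2], k->nb[3], s*k->nb[3]);
        struct ggml_tensor * v_s = ggml_view_4d(ctx, v,
                v->ne[0], v->ne[1], v->ne[2], 1,
                v->nb[1], v->nb[2], v->nb[3], s*v->nb[3]);
        struct ggml_tensor * m_s = ggml_view_4d(ctx, mask,
                mask->ne[0], mask->ne[1], mask->ne[2], 1,
                mask->nb[1], mask->nb[2], mask->nb[3], s*mask->nb[3]);

        out[s] = ggml_flash_attn_ext(ctx, q_s, k_s, v_s, m_s, scale, max_bias, 0.0f);
    }
}
```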
I started implementing this expecting it to be straightforward, but got a bit confused on handling grouped query attention. Seems like we can't group when mask->ne[2] != 1 because the different iq2 values all want different masks. Is this intended? I could disable the optimization for that case, just want to make sure that's expected and not a problem with the definition.
I could have missed something again, but is there a reason preventing us from indexing the mask at dimension 2? (here …)
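The inline reference in the question above was lost, so here is a rough sketch of what indexing the mask at dimension 2 (and 3) could look like, using a modulo-style broadcast and illustrative names; it is not the actual kernel code.

```c
#include "ggml.h"

// Hedged sketch: return the mask row for query position iq1, q head iq2 and
// dim-3 index iq3, broadcasting the mask over its dims 2 and 3.
// Names and the exact broadcast convention are assumptions.
static const ggml_fp16_t * fa_mask_row(
        const struct ggml_tensor * mask, int64_t iq1, int64_t iq2, int64_t iq3) {
    return (const ggml_fp16_t *) ((const char *) mask->data
        + iq1                *mask->nb[1]   // row for this query position
        + (iq2 % mask->ne[2])*mask->nb[2]   // broadcast over dim 2 (per-head masks)
        + (iq3 % mask->ne[3])*mask->nb[3]); // broadcast over dim 3 (per-sequence masks)
}
```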
Right, this is the problematic part. The GQA path wants to have a single matrix per CTA, but now this would be multiple matrices per CTA. You had previously said "In practice, the ne32 will probably always be equal to 1", so maybe it's fine to just disable it.
Yes, it's fine to disable it. Can you enable it for …?
Seems to work fine, latest commit enables it.
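A rough sketch of the gating this exchange converges on: keep the GQA fast path only when the mask has a single slice in dim 2. The function and variable names are illustrative, not the actual backend code.

```c
#include <stdbool.h>
#include "ggml.h"

// Hedged sketch: the GQA fast path shares one mask matrix across the q heads
// of a group, so it can only apply when the mask is not split along dim 2.
static bool fa_gqa_opt_applies(
        const struct ggml_tensor * q,     // [n_embd, n_batch, n_head,    ne3]
        const struct ggml_tensor * k,     // [n_embd, n_kv,    n_head_kv, ne3]
        const struct ggml_tensor * mask) {
    return q->ne[2] % k->ne[2] == 0 && (mask == NULL || mask->ne[2] == 1);
}
```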
* origin/master:
  * Fix conditional enabling following arch checks for ggml-sycl (ggml-org#14504)
  * convert : correct gemma 3n conversion (ggml-org#14450)
  * kv-cache : use ggml_set_rows (ggml-org#14285)
  * ggml : fix FA mask dim 2 and 3 (ggml-org#14505)
  * ggml : remove kompute backend (ggml-org#14501)
  * CUDA: add dynamic shared mem to softmax, refactor general usage (ggml-org#14497)
In #14500, @JohannesGaessler correctly noted that the FA did not utilize dim 3 of the mask. I overlooked this, and now, as I was updating #14363, I realized that we need to align the dimensions.
The fix is simple; I will try to update it myself across the Vulkan and CUDA backends later today.

Also a small fix for ggml_soft_max_ext(): it was incorrectly requiring the mask to be a 3D array. Added a test for this.
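For reference, a minimal sketch of the relaxed ggml_soft_max_ext usage described above, with a 4D mask whose dims 2 and 3 broadcast over the input. The shapes are illustrative and not necessarily those used by the PR's test.

```c
#include "ggml.h"

// Hedged sketch: after the fix, the mask passed to ggml_soft_max_ext may be 4D,
// with dims 2 and 3 broadcasting over the input. Shapes below are illustrative.
static struct ggml_tensor * soft_max_with_4d_mask(struct ggml_context * ctx) {
    struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 32, 8, 2);
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 32, 1, 2);

    // scale = 1.0f, max_bias = 0.0f (no ALiBi)
    return ggml_soft_max_ext(ctx, a, mask, 1.0f, 0.0f);
}
```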